alert pad_token source

This commit is contained in:
hiyouga 2023-08-15 00:07:56 +08:00
parent 9d0f6214b6
commit 80b4053602
2 changed files with 5 additions and 2 deletions

View File

@ -204,7 +204,10 @@ def get_template_and_fix_tokenizer(
logger.info("Add eos token: {}".format(tokenizer.eos_token))
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.unk_token_id is not None:
tokenizer.pad_token = tokenizer.unk_token
else:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))

View File

@ -154,7 +154,7 @@ def load_model_and_tokenizer(
model.generate = MethodType(PreTrainedModel.generate, model)
# Fix LM head (for ChatGLM2)
if not hasattr(model, "lm_head"):
if not hasattr(model, "lm_head") and hasattr(model, "transformer"):
setattr(model, "lm_head", model.transformer.output_layer)
# Register auto class to save the custom code files.