fix generation

hiyouga 2023-08-16 22:39:54 +08:00
parent 7407d9daa1
commit d9e62711a3
3 changed files with 6 additions and 0 deletions


@@ -49,6 +49,8 @@ class ChatModel:
             top_p=top_p or gen_kwargs["top_p"],
             top_k=top_k or gen_kwargs["top_k"],
             repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
+            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
             logits_processor=get_logits_processor(),
             stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
         ))
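
For context, a minimal standalone sketch (not this repository's code; the gpt2 checkpoint is only a placeholder) of what pinning these two ids does in a plain transformers generate call:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Many causal-LM tokenizers ship without a pad token; fall back to EOS so
# generate() pads with a known id instead of guessing and warning.
pad_token_id = tokenizer.pad_token_id
if pad_token_id is None:
    pad_token_id = tokenizer.eos_token_id

inputs = tokenizer("Hello, world!", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    max_new_tokens=16,
    eos_token_id=tokenizer.eos_token_id,  # stop at end-of-sequence
    pad_token_id=pad_token_id,            # pad deterministically
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))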


@@ -74,6 +74,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         # Keyword arguments for `model.generate`
         gen_kwargs = self.generating_args.to_dict()
+        gen_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
+        gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
         gen_kwargs["logits_processor"] = get_logits_processor()
         gen_kwargs["stopping_criteria"] = get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)


@@ -52,6 +52,8 @@ def run_sft(
     # Keyword arguments for `model.generate`
     gen_kwargs = generating_args.to_dict()
+    gen_kwargs["eos_token_id"] = tokenizer.eos_token_id
+    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
     gen_kwargs["logits_processor"] = get_logits_processor()
     gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids)