From d9e62711a3349d7c6fd3512fb25c709bdfbb311a Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 16 Aug 2023 22:39:54 +0800 Subject: [PATCH] fix generation --- src/llmtuner/chat/stream_chat.py | 2 ++ src/llmtuner/tuner/ppo/trainer.py | 2 ++ src/llmtuner/tuner/sft/workflow.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index bf602dd5..00b25b45 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -49,6 +49,8 @@ class ChatModel: top_p=top_p or gen_kwargs["top_p"], top_k=top_k or gen_kwargs["top_k"], repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"], + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, logits_processor=get_logits_processor(), stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids) )) diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index cc73854f..fa4170f6 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -74,6 +74,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): # Keyword arguments for `model.generate` gen_kwargs = self.generating_args.to_dict() + gen_kwargs["eos_token_id"] = self.tokenizer.eos_token_id + gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id gen_kwargs["logits_processor"] = get_logits_processor() gen_kwargs["stopping_criteria"] = get_stopping_criteria(self.tokenizer.additional_special_tokens_ids) diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py index 10f7aafb..f2d72fc6 100644 --- a/src/llmtuner/tuner/sft/workflow.py +++ b/src/llmtuner/tuner/sft/workflow.py @@ -52,6 +52,8 @@ def run_sft( # Keyword arguments for `model.generate` gen_kwargs = generating_args.to_dict() + gen_kwargs["eos_token_id"] = tokenizer.eos_token_id + gen_kwargs["pad_token_id"] = tokenizer.pad_token_id gen_kwargs["logits_processor"] = get_logits_processor() gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids)