From 144801db09ec7f183ab455d7a88c76de7639333d Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 15 May 2024 01:49:05 +0800
Subject: [PATCH] fix gen args

---
 src/llmtuner/chat/hf_engine.py   | 31 ++++++++++++++---------
 src/llmtuner/chat/vllm_engine.py | 43 +++++++++++++-------------------
 2 files changed, 36 insertions(+), 38 deletions(-)

diff --git a/src/llmtuner/chat/hf_engine.py b/src/llmtuner/chat/hf_engine.py
index 97160d57..5cb8bfe4 100644
--- a/src/llmtuner/chat/hf_engine.py
+++ b/src/llmtuner/chat/hf_engine.py
@@ -65,12 +65,13 @@ class HuggingfaceEngine(BaseEngine):
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)
 
-        do_sample = input_kwargs.pop("do_sample", None)
-        temperature = input_kwargs.pop("temperature", None)
-        top_p = input_kwargs.pop("top_p", None)
-        top_k = input_kwargs.pop("top_k", None)
-        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
+        do_sample = input_kwargs.pop("do_sample", generating_args["do_sample"])
+        temperature = input_kwargs.pop("temperature", generating_args["temperature"])
+        top_p = input_kwargs.pop("top_p", generating_args["top_p"])
+        top_k = input_kwargs.pop("top_k", generating_args["top_k"])
+        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty = input_kwargs.pop("repetition_penalty", generating_args["repetition_penalty"])
+        length_penalty = input_kwargs.pop("length_penalty", generating_args["length_penalty"])
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
         stop = input_kwargs.pop("stop", None)
@@ -78,14 +79,16 @@ class HuggingfaceEngine(BaseEngine):
         if stop is not None:
             raise ValueError("Stop parameter is not supported in Huggingface engine yet.")
 
+        generating_args = generating_args.copy()
         generating_args.update(
             dict(
-                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
-                temperature=temperature or generating_args["temperature"],
-                top_p=top_p or generating_args["top_p"],
-                top_k=top_k or generating_args["top_k"],
-                num_return_sequences=num_return_sequences or 1,
-                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
+                do_sample=do_sample,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                num_return_sequences=num_return_sequences,
+                repetition_penalty=repetition_penalty,
+                length_penalty=length_penalty,
                 eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                 pad_token_id=tokenizer.pad_token_id,
             )
@@ -94,6 +97,10 @@ class HuggingfaceEngine(BaseEngine):
         if isinstance(num_return_sequences, int) and num_return_sequences > 1:
             generating_args["do_sample"] = True
 
+        if not generating_args["do_sample"]:
+            generating_args.pop("temperature", None)
+            generating_args.pop("top_p", None)
+
         if max_length:
             generating_args.pop("max_new_tokens", None)
             generating_args["max_length"] = max_length
diff --git a/src/llmtuner/chat/vllm_engine.py b/src/llmtuner/chat/vllm_engine.py
index d50e41aa..faf8c9fe 100644
--- a/src/llmtuner/chat/vllm_engine.py
+++ b/src/llmtuner/chat/vllm_engine.py
@@ -89,43 +89,34 @@ class VllmEngine(BaseEngine):
         )
         prompt_length = len(prompt_ids)
 
-        temperature = input_kwargs.pop("temperature", None)
-        top_p = input_kwargs.pop("top_p", None)
-        top_k = input_kwargs.pop("top_k", None)
-        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
+        use_beam_search = self.generating_args["num_beams"] > 1
+        temperature = input_kwargs.pop("temperature", self.generating_args["temperature"])
+        top_p = input_kwargs.pop("top_p", self.generating_args["top_p"])
+        top_k = input_kwargs.pop("top_k", self.generating_args["top_k"])
+        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty = input_kwargs.pop("repetition_penalty", self.generating_args["repetition_penalty"])
+        length_penalty = input_kwargs.pop("length_penalty", self.generating_args["length_penalty"])
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
         stop = input_kwargs.pop("stop", None)
 
-        generating_args = self.generating_args.copy()
-        generating_args.update(
-            dict(
-                temperature=temperature or generating_args["temperature"],
-                top_p=top_p or generating_args["top_p"],
-                top_k=top_k or generating_args["top_k"],
-                num_return_sequences=num_return_sequences or 1,
-                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
-            )
-        )
-
         if max_length:
-            generating_args["max_new_tokens"] = max_length - prompt_length
+            max_tokens = max_length - prompt_length
 
         if max_new_tokens:
-            generating_args["max_new_tokens"] = max_new_tokens
+            max_tokens = max_new_tokens
 
         sampling_params = SamplingParams(
-            n=generating_args["num_return_sequences"],
-            repetition_penalty=generating_args["repetition_penalty"],
-            temperature=generating_args["temperature"],
-            top_p=generating_args["top_p"],
-            top_k=generating_args["top_k"],
-            use_beam_search=generating_args["num_beams"] > 1,
-            length_penalty=generating_args["length_penalty"],
+            n=num_return_sequences,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            use_beam_search=use_beam_search,
+            length_penalty=length_penalty,
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            max_tokens=generating_args["max_new_tokens"],
+            max_tokens=max_tokens,
             skip_special_tokens=True,
         )
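
Note on the fix: the substantive change in both engines is replacing the `value or default` merge with `input_kwargs.pop(key, default)`. With `or`, legitimate falsy overrides (temperature=0.0, top_k=0) are indistinguishable from "not provided" and silently fall back to the defaults. A minimal, self-contained sketch of the difference; the `defaults` dict below is a stand-in for the engines' generating_args, not the project's actual config:

    # Stand-in for the engine's generating_args defaults (hypothetical values).
    defaults = {"temperature": 0.7}

    def old_resolve(input_kwargs):
        # Old pattern: `or` treats any falsy override as missing, so an
        # explicit temperature=0.0 is silently replaced by the default 0.7.
        temperature = input_kwargs.pop("temperature", None)
        return temperature or defaults["temperature"]

    def new_resolve(input_kwargs):
        # New pattern: the default applies only when the key is absent,
        # so temperature=0.0 survives intact.
        return input_kwargs.pop("temperature", defaults["temperature"])

    print(old_resolve({"temperature": 0.0}))  # 0.7, override lost
    print(new_resolve({"temperature": 0.0}))  # 0.0, override kept

The same concern presumably motivates the new `generating_args = generating_args.copy()` in hf_engine.py (per-request overrides no longer mutate the shared defaults) and the removal of `temperature`/`top_p` when `do_sample` is false, which transformers would otherwise flag with a warning.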
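On the vLLM side, the patch also stops routing the length cap through `generating_args` and computes it directly. vLLM's `SamplingParams.max_tokens` counts only generated tokens, so a caller-supplied total `max_length` must have the prompt length subtracted first. A tiny illustration of that budget arithmetic (the numbers are made up):

    # max_tokens counts generated tokens only, so a total sequence budget
    # must exclude the prompt (illustrative numbers, not real defaults).
    prompt_length = 120
    max_length = 512        # caller's cap on prompt + completion
    max_new_tokens = None   # caller's cap on the completion alone, if given

    max_tokens = max_length - prompt_length   # 392 tokens left to generate
    if max_new_tokens:
        max_tokens = max_new_tokens           # explicit cap wins, as in the patch

    print(max_tokens)  # 392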