fix gen args

hiyouga 2024-05-15 01:49:05 +08:00
parent 7e69e71a52
commit 144801db09
2 changed files with 36 additions and 38 deletions


@@ -65,12 +65,13 @@ class HuggingfaceEngine(BaseEngine):
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)

-        do_sample = input_kwargs.pop("do_sample", None)
-        temperature = input_kwargs.pop("temperature", None)
-        top_p = input_kwargs.pop("top_p", None)
-        top_k = input_kwargs.pop("top_k", None)
-        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
+        do_sample = input_kwargs.pop("do_sample", generating_args["do_sample"])
+        temperature = input_kwargs.pop("temperature", generating_args["temperature"])
+        top_p = input_kwargs.pop("top_p", generating_args["top_p"])
+        top_k = input_kwargs.pop("top_k", generating_args["top_k"])
+        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty = input_kwargs.pop("repetition_penalty", generating_args["repetition_penalty"])
+        length_penalty = input_kwargs.pop("length_penalty", generating_args["length_penalty"])
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
         stop = input_kwargs.pop("stop", None)
@@ -78,14 +79,16 @@ class HuggingfaceEngine(BaseEngine):
         if stop is not None:
             raise ValueError("Stop parameter is not supported in Huggingface engine yet.")

+        generating_args = generating_args.copy()
         generating_args.update(
             dict(
-                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
-                temperature=temperature or generating_args["temperature"],
-                top_p=top_p or generating_args["top_p"],
-                top_k=top_k or generating_args["top_k"],
-                num_return_sequences=num_return_sequences or 1,
-                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
+                do_sample=do_sample,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                num_return_sequences=num_return_sequences,
+                repetition_penalty=repetition_penalty,
+                length_penalty=length_penalty,
                 eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                 pad_token_id=tokenizer.pad_token_id,
             )
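The two hunks above change how per-request overrides are reconciled with the engine-level generation defaults: each value is popped from `input_kwargs` with the configured default as the fallback, and the defaults dict is copied before being updated so one request cannot mutate shared state. A minimal standalone sketch of the pattern (the values here are hypothetical, not the project's defaults):

```python
# Pop-with-default resolution, as opposed to the old `value or default` idiom,
# which silently discarded falsy overrides such as temperature=0.0.
defaults = {"do_sample": True, "temperature": 0.7, "top_p": 0.9}

def resolve(input_kwargs: dict) -> dict:
    temperature = input_kwargs.pop("temperature", defaults["temperature"])
    top_p = input_kwargs.pop("top_p", defaults["top_p"])
    do_sample = input_kwargs.pop("do_sample", defaults["do_sample"])

    args = defaults.copy()  # copy so requests never mutate the shared defaults
    args.update(do_sample=do_sample, temperature=temperature, top_p=top_p)
    return args

print(resolve({"temperature": 0.0}))  # keeps 0.0; `0.0 or 0.7` would yield 0.7
print(resolve({}))                    # falls back to the engine defaults
```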
@@ -94,6 +97,10 @@ class HuggingfaceEngine(BaseEngine):
         if isinstance(num_return_sequences, int) and num_return_sequences > 1:
             generating_args["do_sample"] = True

+        if not generating_args["do_sample"]:
+            generating_args.pop("temperature", None)
+            generating_args.pop("top_p", None)
+
         if max_length:
             generating_args.pop("max_new_tokens", None)
             generating_args["max_length"] = max_length
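The last hunk adds a guard for greedy decoding: with `do_sample=False`, transformers ignores the sampling knobs and recent versions warn when they are still set, so the commit drops them from the final kwargs. A sketch with hypothetical values:

```python
# Greedy-decoding guard: strip sampling parameters when sampling is disabled.
args = {"do_sample": False, "temperature": 0.7, "top_p": 0.9, "num_return_sequences": 1}

# Asking for several distinct sequences only makes sense when sampling.
if isinstance(args["num_return_sequences"], int) and args["num_return_sequences"] > 1:
    args["do_sample"] = True

if not args["do_sample"]:
    args.pop("temperature", None)  # meaningless (and warned about) under greedy search
    args.pop("top_p", None)

print(args)  # {'do_sample': False, 'num_return_sequences': 1}
```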


@@ -89,43 +89,34 @@ class VllmEngine(BaseEngine):
         )
         prompt_length = len(prompt_ids)

-        temperature = input_kwargs.pop("temperature", None)
-        top_p = input_kwargs.pop("top_p", None)
-        top_k = input_kwargs.pop("top_k", None)
-        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
+        use_beam_search = self.generating_args["num_beams"] > 1
+        temperature = input_kwargs.pop("temperature", self.generating_args["temperature"])
+        top_p = input_kwargs.pop("top_p", self.generating_args["top_p"])
+        top_k = input_kwargs.pop("top_k", self.generating_args["top_k"])
+        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty = input_kwargs.pop("repetition_penalty", self.generating_args["repetition_penalty"])
+        length_penalty = input_kwargs.pop("length_penalty", self.generating_args["length_penalty"])
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
         stop = input_kwargs.pop("stop", None)

-        generating_args = self.generating_args.copy()
-        generating_args.update(
-            dict(
-                temperature=temperature or generating_args["temperature"],
-                top_p=top_p or generating_args["top_p"],
-                top_k=top_k or generating_args["top_k"],
-                num_return_sequences=num_return_sequences or 1,
-                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
-            )
-        )
-
         if max_length:
-            generating_args["max_new_tokens"] = max_length - prompt_length
+            max_tokens = max_length - prompt_length

         if max_new_tokens:
-            generating_args["max_new_tokens"] = max_new_tokens
+            max_tokens = max_new_tokens

         sampling_params = SamplingParams(
-            n=generating_args["num_return_sequences"],
-            repetition_penalty=generating_args["repetition_penalty"],
-            temperature=generating_args["temperature"],
-            top_p=generating_args["top_p"],
-            top_k=generating_args["top_k"],
-            use_beam_search=generating_args["num_beams"] > 1,
-            length_penalty=generating_args["length_penalty"],
+            n=num_return_sequences,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            use_beam_search=use_beam_search,
+            length_penalty=length_penalty,
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            max_tokens=generating_args["max_new_tokens"],
+            max_tokens=max_tokens,
             skip_special_tokens=True,
         )
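In the vLLM engine the intermediate `generating_args` dict is dropped entirely: resolved values flow straight into `SamplingParams`, and the token budget is computed into a local `max_tokens`, where `max_length` is a total prompt-plus-completion budget and `max_new_tokens`, applied last, wins when both are given. A sketch of just that resolution, with hypothetical numbers (note the diff itself has no fallback when neither cap is supplied, so requests are presumably expected to carry one):

```python
# Resolving the completion-token budget as the rewritten VllmEngine does.
prompt_length = 120
max_length = 512       # total budget: prompt + completion
max_new_tokens = 256   # cap on the completion alone

max_tokens = None      # our fallback; the diff leaves max_tokens unassigned here
if max_length:
    max_tokens = max_length - prompt_length  # 392 tokens left after the prompt
if max_new_tokens:
    max_tokens = max_new_tokens              # applied last, so this cap wins

print(max_tokens)  # 256
```

One caveat when reproducing this outside the commit's era: `use_beam_search` was a `SamplingParams` field in vLLM versions contemporary with this change and was removed in later releases.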