fix gen args
This commit is contained in:
parent
7e69e71a52
commit
144801db09
|
@ -65,12 +65,13 @@ class HuggingfaceEngine(BaseEngine):
|
|||
prompt_length = len(prompt_ids)
|
||||
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||
|
||||
do_sample = input_kwargs.pop("do_sample", None)
|
||||
temperature = input_kwargs.pop("temperature", None)
|
||||
top_p = input_kwargs.pop("top_p", None)
|
||||
top_k = input_kwargs.pop("top_k", None)
|
||||
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
||||
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
||||
do_sample = input_kwargs.pop("do_sample", generating_args["do_sample"])
|
||||
temperature = input_kwargs.pop("temperature", generating_args["temperature"])
|
||||
top_p = input_kwargs.pop("top_p", generating_args["top_p"])
|
||||
top_k = input_kwargs.pop("top_k", generating_args["top_k"])
|
||||
num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
|
||||
repetition_penalty = input_kwargs.pop("repetition_penalty", generating_args["repetition_penalty"])
|
||||
length_penalty = input_kwargs.pop("length_penalty", generating_args["length_penalty"])
|
||||
max_length = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||
stop = input_kwargs.pop("stop", None)
|
||||
|
@ -78,14 +79,16 @@ class HuggingfaceEngine(BaseEngine):
|
|||
if stop is not None:
|
||||
raise ValueError("Stop parameter is not supported in Huggingface engine yet.")
|
||||
|
||||
generating_args = generating_args.copy()
|
||||
generating_args.update(
|
||||
dict(
|
||||
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
||||
temperature=temperature or generating_args["temperature"],
|
||||
top_p=top_p or generating_args["top_p"],
|
||||
top_k=top_k or generating_args["top_k"],
|
||||
num_return_sequences=num_return_sequences or 1,
|
||||
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
||||
do_sample=do_sample,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
num_return_sequences=num_return_sequences,
|
||||
repetition_penalty=repetition_penalty,
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
)
|
||||
|
@ -94,6 +97,10 @@ class HuggingfaceEngine(BaseEngine):
|
|||
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
|
||||
generating_args["do_sample"] = True
|
||||
|
||||
if not generating_args["do_sample"]:
|
||||
generating_args.pop("temperature", None)
|
||||
generating_args.pop("top_p", None)
|
||||
|
||||
if max_length:
|
||||
generating_args.pop("max_new_tokens", None)
|
||||
generating_args["max_length"] = max_length
|
||||
|
|
|
@ -89,43 +89,34 @@ class VllmEngine(BaseEngine):
|
|||
)
|
||||
prompt_length = len(prompt_ids)
|
||||
|
||||
temperature = input_kwargs.pop("temperature", None)
|
||||
top_p = input_kwargs.pop("top_p", None)
|
||||
top_k = input_kwargs.pop("top_k", None)
|
||||
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
||||
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
||||
use_beam_search = self.generating_args["num_beams"] > 1
|
||||
temperature = input_kwargs.pop("temperature", self.generating_args["temperature"])
|
||||
top_p = input_kwargs.pop("top_p", self.generating_args["top_p"])
|
||||
top_k = input_kwargs.pop("top_k", self.generating_args["top_k"])
|
||||
num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
|
||||
repetition_penalty = input_kwargs.pop("repetition_penalty", self.generating_args["repetition_penalty"])
|
||||
length_penalty = input_kwargs.pop("length_penalty", self.generating_args["length_penalty"])
|
||||
max_length = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||
stop = input_kwargs.pop("stop", None)
|
||||
|
||||
generating_args = self.generating_args.copy()
|
||||
generating_args.update(
|
||||
dict(
|
||||
temperature=temperature or generating_args["temperature"],
|
||||
top_p=top_p or generating_args["top_p"],
|
||||
top_k=top_k or generating_args["top_k"],
|
||||
num_return_sequences=num_return_sequences or 1,
|
||||
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
||||
)
|
||||
)
|
||||
|
||||
if max_length:
|
||||
generating_args["max_new_tokens"] = max_length - prompt_length
|
||||
max_tokens = max_length - prompt_length
|
||||
|
||||
if max_new_tokens:
|
||||
generating_args["max_new_tokens"] = max_new_tokens
|
||||
max_tokens = max_new_tokens
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
n=generating_args["num_return_sequences"],
|
||||
repetition_penalty=generating_args["repetition_penalty"],
|
||||
temperature=generating_args["temperature"],
|
||||
top_p=generating_args["top_p"],
|
||||
top_k=generating_args["top_k"],
|
||||
use_beam_search=generating_args["num_beams"] > 1,
|
||||
length_penalty=generating_args["length_penalty"],
|
||||
n=num_return_sequences,
|
||||
repetition_penalty=repetition_penalty,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
use_beam_search=use_beam_search,
|
||||
length_penalty=length_penalty,
|
||||
stop=stop,
|
||||
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
max_tokens=generating_args["max_new_tokens"],
|
||||
max_tokens=max_tokens,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue