fix gen args
commit 144801db09
parent 7e69e71a52
@@ -65,12 +65,13 @@ class HuggingfaceEngine(BaseEngine):
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)

-        do_sample = input_kwargs.pop("do_sample", None)
-        temperature = input_kwargs.pop("temperature", None)
-        top_p = input_kwargs.pop("top_p", None)
-        top_k = input_kwargs.pop("top_k", None)
-        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
+        do_sample = input_kwargs.pop("do_sample", generating_args["do_sample"])
+        temperature = input_kwargs.pop("temperature", generating_args["temperature"])
+        top_p = input_kwargs.pop("top_p", generating_args["top_p"])
+        top_k = input_kwargs.pop("top_k", generating_args["top_k"])
+        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty = input_kwargs.pop("repetition_penalty", generating_args["repetition_penalty"])
+        length_penalty = input_kwargs.pop("length_penalty", generating_args["length_penalty"])
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
         stop = input_kwargs.pop("stop", None)
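The hunk above is the heart of the fix: per-request overrides used to be popped with a None default and later merged with "or", which silently discards falsy values such as temperature=0.0 or do_sample=False. Popping with the engine default instead only falls back when the key is truly absent. A minimal standalone sketch of the difference, using made-up values in place of the real generating_args mapping:

    generating_args = {"do_sample": True, "temperature": 0.7}   # illustrative defaults
    input_kwargs = {"temperature": 0.0}                         # caller explicitly asks for 0.0

    # Old pattern: a falsy override (0.0, 0, False) is thrown away by "or".
    old = input_kwargs.get("temperature", None) or generating_args["temperature"]   # -> 0.7

    # New pattern: the default is used only when the key is missing.
    new = input_kwargs.pop("temperature", generating_args["temperature"])           # -> 0.0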
@@ -78,14 +79,16 @@ class HuggingfaceEngine(BaseEngine):
         if stop is not None:
             raise ValueError("Stop parameter is not supported in Huggingface engine yet.")

+        generating_args = generating_args.copy()
         generating_args.update(
             dict(
-                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
-                temperature=temperature or generating_args["temperature"],
-                top_p=top_p or generating_args["top_p"],
-                top_k=top_k or generating_args["top_k"],
-                num_return_sequences=num_return_sequences or 1,
-                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
+                do_sample=do_sample,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                num_return_sequences=num_return_sequences,
+                repetition_penalty=repetition_penalty,
+                length_penalty=length_penalty,
                 eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                 pad_token_id=tokenizer.pad_token_id,
             )
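The added generating_args = generating_args.copy() applies the per-request overrides to a copy, presumably so that one request's settings cannot leak into the shared defaults seen by later requests. A tiny illustration of the pattern, with hypothetical names:

    defaults = {"temperature": 0.7, "top_p": 0.9}

    def build_args(defaults, temperature):
        args = defaults.copy()                      # mutate a copy, not the shared dict
        args.update(dict(temperature=temperature))
        return args

    build_args(defaults, 0.1)
    assert defaults["temperature"] == 0.7           # defaults stay intact across calls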
@@ -94,6 +97,10 @@ class HuggingfaceEngine(BaseEngine):
         if isinstance(num_return_sequences, int) and num_return_sequences > 1:
             generating_args["do_sample"] = True

+        if not generating_args["do_sample"]:
+            generating_args.pop("temperature", None)
+            generating_args.pop("top_p", None)
+
         if max_length:
             generating_args.pop("max_new_tokens", None)
             generating_args["max_length"] = max_length
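The new block strips temperature and top_p when sampling is disabled; transformers typically warns when sampling-only flags accompany greedy or beam-search decoding, so dropping them keeps the generate call quiet. Roughly, assuming generating_args is a plain dict of generation kwargs:

    generating_args = {"do_sample": False, "temperature": 0.7, "top_p": 0.9, "num_beams": 1}

    if not generating_args["do_sample"]:
        generating_args.pop("temperature", None)   # meaningless without sampling
        generating_args.pop("top_p", None)

    # generating_args is now {"do_sample": False, "num_beams": 1}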
@@ -89,43 +89,34 @@ class VllmEngine(BaseEngine):
         )
         prompt_length = len(prompt_ids)

-        temperature = input_kwargs.pop("temperature", None)
-        top_p = input_kwargs.pop("top_p", None)
-        top_k = input_kwargs.pop("top_k", None)
-        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
+        use_beam_search = self.generating_args["num_beams"] > 1
+        temperature = input_kwargs.pop("temperature", self.generating_args["temperature"])
+        top_p = input_kwargs.pop("top_p", self.generating_args["top_p"])
+        top_k = input_kwargs.pop("top_k", self.generating_args["top_k"])
+        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty = input_kwargs.pop("repetition_penalty", self.generating_args["repetition_penalty"])
+        length_penalty = input_kwargs.pop("length_penalty", self.generating_args["length_penalty"])
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
         stop = input_kwargs.pop("stop", None)

-        generating_args = self.generating_args.copy()
-        generating_args.update(
-            dict(
-                temperature=temperature or generating_args["temperature"],
-                top_p=top_p or generating_args["top_p"],
-                top_k=top_k or generating_args["top_k"],
-                num_return_sequences=num_return_sequences or 1,
-                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
-            )
-        )
-
         if max_length:
-            generating_args["max_new_tokens"] = max_length - prompt_length
+            max_tokens = max_length - prompt_length

         if max_new_tokens:
-            generating_args["max_new_tokens"] = max_new_tokens
+            max_tokens = max_new_tokens

         sampling_params = SamplingParams(
-            n=generating_args["num_return_sequences"],
-            repetition_penalty=generating_args["repetition_penalty"],
-            temperature=generating_args["temperature"],
-            top_p=generating_args["top_p"],
-            top_k=generating_args["top_k"],
-            use_beam_search=generating_args["num_beams"] > 1,
-            length_penalty=generating_args["length_penalty"],
+            n=num_return_sequences,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            use_beam_search=use_beam_search,
+            length_penalty=length_penalty,
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            max_tokens=generating_args["max_new_tokens"],
+            max_tokens=max_tokens,
             skip_special_tokens=True,
         )
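On the vLLM side the intermediate generating_args dict is dropped entirely: the values popped from input_kwargs (or their defaults from self.generating_args) feed SamplingParams directly, and the token budget lives in a plain max_tokens variable. Note that max_tokens is only bound inside the two if branches, so the surrounding code presumably guarantees that max_length or max_new_tokens is always provided. A rough usage sketch of the resulting call, with placeholder values and assuming a vLLM release whose SamplingParams still accepts use_beam_search:

    from vllm import SamplingParams

    # Placeholder values standing in for the popped per-request arguments.
    sampling_params = SamplingParams(
        n=1,
        repetition_penalty=1.0,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        use_beam_search=False,
        length_penalty=1.0,
        stop=None,
        stop_token_ids=[2],          # e.g. the tokenizer's eos_token_id
        max_tokens=512,
        skip_special_tokens=True,
    )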