commit f030b09924
parent d21cc71750
Author: hiyouga
Date:   2023-06-26 13:39:57 +08:00

    update api

3 changed files with 17 additions and 4 deletions

@@ -49,6 +49,7 @@ class ChatCompletionRequest(BaseModel):
     messages: List[ChatMessage]
     temperature: Optional[float] = None
     top_p: Optional[float] = None
+    max_length: Optional[int] = None
     max_new_tokens: Optional[int] = None
     stream: Optional[bool] = False
@@ -100,9 +101,14 @@ async def create_chat_completion(request: ChatCompletionRequest):
         "input_ids": inputs["input_ids"],
         "temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
         "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
-        "max_new_tokens": request.max_new_tokens if request.max_new_tokens else gen_kwargs["max_new_tokens"],
         "logits_processor": get_logits_processor()
     })
+    if request.max_length:
+        gen_kwargs.pop("max_new_tokens", None)
+        gen_kwargs["max_length"] = request.max_length
+    if request.max_new_tokens:
+        gen_kwargs.pop("max_length", None)
+        gen_kwargs["max_new_tokens"] = request.max_new_tokens
     if request.stream:
         generate = predict(gen_kwargs, request.model)
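
The two new if blocks make the request fields mutually exclusive in gen_kwargs: whichever field the client sends evicts the other, and max_new_tokens wins when both are present because its branch runs last. Below is a minimal sketch of that resolution logic, using plain dicts in place of the real request object and generation defaults (the function name is illustrative, not part of the commit):

    # Resolve max_length vs. max_new_tokens the way the endpoint does.
    def resolve_length_kwargs(request: dict, gen_kwargs: dict) -> dict:
        if request.get("max_length"):
            gen_kwargs.pop("max_new_tokens", None)
            gen_kwargs["max_length"] = request["max_length"]
        if request.get("max_new_tokens"):
            gen_kwargs.pop("max_length", None)
            gen_kwargs["max_new_tokens"] = request["max_new_tokens"]
        return gen_kwargs

    # A client that sends both fields ends up with only max_new_tokens.
    print(resolve_length_kwargs({"max_length": 1024, "max_new_tokens": 256}, {"max_new_tokens": 512}))
    # -> {'max_new_tokens': 256}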

@@ -171,8 +171,8 @@ def load_pretrained(
         padding_side="left",
         **config_kwargs
     )
-    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
-    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id == 64000 else tokenizer.pad_token_id # for baichuan model (older version)
+    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
+        tokenizer.pad_token_id = 0 # set as the <unk> token
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     is_mergeable = True
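
This rewrite collapses two self-assignments into one guard: a tokenizer with no pad token and the legacy Baichuan pad id 64000 both fall back to token id 0, the <unk> token. A standalone sketch of the same guard, with a stand-in object instead of a real AutoTokenizer:

    class FakeTokenizer:
        # Stand-in for a tokenizer that was loaded without a pad token.
        pad_token_id = None

    tokenizer = FakeTokenizer()
    # Same branch as in load_pretrained: None and the old Baichuan
    # pad id (64000) are both remapped to 0 (<unk>).
    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000:
        tokenizer.pad_token_id = 0
    assert tokenizer.pad_token_id == 0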

@@ -277,6 +277,10 @@ class GeneratingArguments:
         default=1,
         metadata={"help": "Number of beams for beam search. 1 means no beam search."}
     )
+    max_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
+    )
     max_new_tokens: Optional[int] = field(
         default=512,
         metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
@@ -291,4 +295,7 @@ class GeneratingArguments:
     )

     def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
+        args = asdict(self)
+        if args.get("max_new_tokens", None):
+            args.pop("max_length", None)
+        return args
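
Since max_new_tokens defaults to 512, the truthiness check in to_dict drops max_length from the returned kwargs unless max_new_tokens is explicitly cleared, mirroring the precedence in the API handler. A reduced dataclass sketch of that behavior (only the two length fields, not the full GeneratingArguments):

    from dataclasses import asdict, dataclass
    from typing import Any, Dict, Optional

    @dataclass
    class LengthArgs:
        # Stand-in for GeneratingArguments, keeping only the length fields.
        max_length: Optional[int] = None
        max_new_tokens: Optional[int] = 512

        def to_dict(self) -> Dict[str, Any]:
            args = asdict(self)
            if args.get("max_new_tokens", None):
                args.pop("max_length", None)  # max_new_tokens takes precedence
            return args

    print(LengthArgs(max_length=1024).to_dict())
    # -> {'max_new_tokens': 512}
    print(LengthArgs(max_length=1024, max_new_tokens=None).to_dict())
    # -> {'max_length': 1024, 'max_new_tokens': None}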