forked from p04798526/LLaMA-Factory-Mirror
update api
This commit is contained in:
parent
d21cc71750
commit
f030b09924
|
@ -49,6 +49,7 @@ class ChatCompletionRequest(BaseModel):
|
|||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
max_length: Optional[int] = None
|
||||
max_new_tokens: Optional[int] = None
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
@ -100,9 +101,14 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|||
"input_ids": inputs["input_ids"],
|
||||
"temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
|
||||
"top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
|
||||
"max_new_tokens": request.max_new_tokens if request.max_new_tokens else gen_kwargs["max_new_tokens"],
|
||||
"logits_processor": get_logits_processor()
|
||||
})
|
||||
if request.max_length:
|
||||
gen_kwargs.pop("max_new_tokens", None)
|
||||
gen_kwargs["max_length"] = request.max_length
|
||||
if request.max_new_tokens:
|
||||
gen_kwargs.pop("max_length", None)
|
||||
gen_kwargs["max_new_tokens"] = request.max_new_tokens
|
||||
|
||||
if request.stream:
|
||||
generate = predict(gen_kwargs, request.model)
|
||||
|
|
|
@ -171,8 +171,8 @@ def load_pretrained(
|
|||
padding_side="left",
|
||||
**config_kwargs
|
||||
)
|
||||
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
|
||||
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id == 64000 else tokenizer.pad_token_id # for baichuan model (older version)
|
||||
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
|
||||
tokenizer.pad_token_id = 0 # set as the <unk> token
|
||||
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
is_mergeable = True
|
||||
|
|
|
@ -277,6 +277,10 @@ class GeneratingArguments:
|
|||
default=1,
|
||||
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
||||
)
|
||||
max_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
|
||||
)
|
||||
max_new_tokens: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
|
||||
|
@ -291,4 +295,7 @@ class GeneratingArguments:
|
|||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
args = asdict(self)
|
||||
if args.get("max_new_tokens", None):
|
||||
args.pop("max_length", None)
|
||||
return args
|
||||
|
|
Loading…
Reference in New Issue