forked from p04798526/LLaMA-Factory-Mirror
update api
This commit is contained in:
parent
d21cc71750
commit
f030b09924
|
@ -49,6 +49,7 @@ class ChatCompletionRequest(BaseModel):
|
||||||
messages: List[ChatMessage]
|
messages: List[ChatMessage]
|
||||||
temperature: Optional[float] = None
|
temperature: Optional[float] = None
|
||||||
top_p: Optional[float] = None
|
top_p: Optional[float] = None
|
||||||
|
max_length: Optional[int] = None
|
||||||
max_new_tokens: Optional[int] = None
|
max_new_tokens: Optional[int] = None
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
|
|
||||||
|
@ -100,9 +101,14 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
||||||
"input_ids": inputs["input_ids"],
|
"input_ids": inputs["input_ids"],
|
||||||
"temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
|
"temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
|
||||||
"top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
|
"top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
|
||||||
"max_new_tokens": request.max_new_tokens if request.max_new_tokens else gen_kwargs["max_new_tokens"],
|
|
||||||
"logits_processor": get_logits_processor()
|
"logits_processor": get_logits_processor()
|
||||||
})
|
})
|
||||||
|
if request.max_length:
|
||||||
|
gen_kwargs.pop("max_new_tokens", None)
|
||||||
|
gen_kwargs["max_length"] = request.max_length
|
||||||
|
if request.max_new_tokens:
|
||||||
|
gen_kwargs.pop("max_length", None)
|
||||||
|
gen_kwargs["max_new_tokens"] = request.max_new_tokens
|
||||||
|
|
||||||
if request.stream:
|
if request.stream:
|
||||||
generate = predict(gen_kwargs, request.model)
|
generate = predict(gen_kwargs, request.model)
|
||||||
|
|
|
@ -171,8 +171,8 @@ def load_pretrained(
|
||||||
padding_side="left",
|
padding_side="left",
|
||||||
**config_kwargs
|
**config_kwargs
|
||||||
)
|
)
|
||||||
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
|
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
|
||||||
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id == 64000 else tokenizer.pad_token_id # for baichuan model (older version)
|
tokenizer.pad_token_id = 0 # set as the <unk> token
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||||
is_mergeable = True
|
is_mergeable = True
|
||||||
|
|
|
@ -277,6 +277,10 @@ class GeneratingArguments:
|
||||||
default=1,
|
default=1,
|
||||||
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
||||||
)
|
)
|
||||||
|
max_length: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
|
||||||
|
)
|
||||||
max_new_tokens: Optional[int] = field(
|
max_new_tokens: Optional[int] = field(
|
||||||
default=512,
|
default=512,
|
||||||
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
|
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
|
||||||
|
@ -291,4 +295,7 @@ class GeneratingArguments:
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
return asdict(self)
|
args = asdict(self)
|
||||||
|
if args.get("max_new_tokens", None):
|
||||||
|
args.pop("max_length", None)
|
||||||
|
return args
|
||||||
|
|
Loading…
Reference in New Issue