update gradio, support multiple resp in api
This commit is contained in:
parent
d8cf8cfdeb
commit
84af10cec9
|
@ -12,7 +12,7 @@ fire
|
|||
jieba
|
||||
rouge-chinese
|
||||
nltk
|
||||
gradio==3.38.0
|
||||
gradio==3.50.2
|
||||
uvicorn
|
||||
pydantic==1.10.11
|
||||
fastapi==0.95.1
|
||||
|
|
|
@ -6,8 +6,8 @@ from llmtuner import ChatModel, create_app
|
|||
def main():
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
||||
print("Visit http://localhost:8000/docs for API document.")
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import json
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from sse_starlette import EventSourceResponse
|
||||
from typing import List, Tuple
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.chat import ChatModel
|
||||
|
@ -29,6 +31,13 @@ async def lifespan(app: FastAPI): # collects GPU memory
|
|||
torch_gc()
|
||||
|
||||
|
||||
def to_json(data: BaseModel) -> str:
|
||||
try:
|
||||
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
||||
except:
|
||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
|
||||
def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
@ -45,10 +54,10 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
|||
model_card = ModelCard(id="gpt-3.5-turbo")
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
|
||||
raise HTTPException(status_code=400, detail="Invalid request")
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
query = request.messages[-1].content
|
||||
prev_messages = request.messages[:-1]
|
||||
|
@ -62,6 +71,8 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
|||
for i in range(0, len(prev_messages), 2):
|
||||
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
|
||||
history.append([prev_messages[i].content, prev_messages[i+1].content])
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
if request.stream:
|
||||
generate = predict(query, history, system, request)
|
||||
|
@ -72,7 +83,8 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
|||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens
|
||||
max_new_tokens=request.max_tokens,
|
||||
num_return_sequences=request.n
|
||||
)
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
|
@ -81,13 +93,13 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
|||
total_tokens=prompt_length+response_length
|
||||
)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role=Role.ASSISTANT, content=response),
|
||||
choices = [ChatCompletionResponseChoice(
|
||||
index=i,
|
||||
message=ChatMessage(role=Role.ASSISTANT, content=choice),
|
||||
finish_reason=Finish.STOP
|
||||
)
|
||||
) for i, choice in enumerate(response)]
|
||||
|
||||
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
|
||||
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
||||
|
||||
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
|
@ -96,7 +108,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
|||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield to_json(chunk)
|
||||
|
||||
for new_text in chat_model.stream_chat(
|
||||
query, history, system,
|
||||
|
@ -114,7 +126,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
|||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield to_json(chunk)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
|
@ -122,7 +134,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
|||
finish_reason=Finish.STOP
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield to_json(chunk)
|
||||
yield "[DONE]"
|
||||
|
||||
return app
|
||||
|
|
|
@ -20,9 +20,6 @@ class ModelCard(BaseModel):
|
|||
object: Optional[str] = "model"
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: Optional[str] = "owner"
|
||||
root: Optional[str] = None
|
||||
parent: Optional[str] = None
|
||||
permission: Optional[list] = []
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
|
|
|
@ -26,17 +26,17 @@ class ChatModel:
|
|||
**input_kwargs
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
system = system or self.system_prompt
|
||||
|
||||
prompt, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
|
||||
)
|
||||
prompt_length = len(prompt)
|
||||
input_ids = torch.tensor([prompt], device=self.model.device)
|
||||
prompt_length = len(input_ids[0])
|
||||
|
||||
do_sample = input_kwargs.pop("do_sample", None)
|
||||
temperature = input_kwargs.pop("temperature", None)
|
||||
top_p = input_kwargs.pop("top_p", None)
|
||||
top_k = input_kwargs.pop("top_k", None)
|
||||
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
||||
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
||||
max_length = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||
|
@ -47,11 +47,15 @@ class ChatModel:
|
|||
temperature=temperature or generating_args["temperature"],
|
||||
top_p=top_p or generating_args["top_p"],
|
||||
top_k=top_k or generating_args["top_k"],
|
||||
num_return_sequences=num_return_sequences or 1,
|
||||
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
pad_token_id=self.tokenizer.pad_token_id
|
||||
))
|
||||
|
||||
if int(num_return_sequences) > 1:
|
||||
generating_args["do_sample"] = True
|
||||
|
||||
if max_length:
|
||||
generating_args.pop("max_new_tokens", None)
|
||||
generating_args["max_length"] = max_length
|
||||
|
@ -75,12 +79,16 @@ class ChatModel:
|
|||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[str, Tuple[int, int]]:
|
||||
) -> Tuple[List[str], Tuple[int, int]]:
|
||||
gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
|
||||
generation_output = self.model.generate(**gen_kwargs)
|
||||
outputs = generation_output.tolist()[0][prompt_length:]
|
||||
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
|
||||
response_length = len(outputs)
|
||||
generate_output = self.model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
response_length = 0
|
||||
for i in range(len(response_ids)):
|
||||
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
|
||||
response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i])
|
||||
|
||||
return response, (prompt_length, response_length)
|
||||
|
||||
@torch.inference_mode()
|
||||
|
|
|
@ -83,7 +83,7 @@ def load_model_and_tokenizer(
|
|||
|
||||
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
|
||||
|
||||
# Fix tokenizer (for ChatGLM2)
|
||||
# Fix tokenizer (for ChatGLM2 and ChatGLM3)
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
|
||||
|
@ -193,7 +193,7 @@ def load_model_and_tokenizer(
|
|||
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
||||
# Fix LM head (for ChatGLM2)
|
||||
# Fix LM head (for ChatGLM2 and ChatGLM3)
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
setattr(model, "lm_head", model.transformer.output_layer)
|
||||
|
||||
|
@ -212,8 +212,7 @@ def load_model_and_tokenizer(
|
|||
|
||||
# Prepare model with valuehead for RLHF
|
||||
if stage == "rm" or stage == "ppo":
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
model._keys_to_ignore_on_save = None
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
reset_logging()
|
||||
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
|
||||
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
|
||||
|
|
|
@ -180,15 +180,15 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
|
||||
queries, responses = [], []
|
||||
for i in range(len(query)):
|
||||
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
|
||||
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
|
||||
|
||||
if len(response_index) == 0:
|
||||
response_length = 1 # allow empty response
|
||||
elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
||||
response_length = response_index[-1] + 2 # save the EOS token
|
||||
response_length = response_index[-1].item() + 2 # save the EOS token
|
||||
else:
|
||||
response_length = response_index[-1] + 1
|
||||
response_length = response_index[-1].item() + 1
|
||||
|
||||
queries.append(query[i, query_length:]) # remove padding from left
|
||||
responses.append(response[i, :response_length]) # remove padding from right
|
||||
|
@ -216,7 +216,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
|
||||
rewards = []
|
||||
for i in range(values.size(0)):
|
||||
end_index = batch["attention_mask"][i].nonzero()[-1] # use the score on the EOS token
|
||||
end_index = batch["attention_mask"][i].nonzero()[-1].item() # use the score on the EOS token
|
||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||
|
||||
replace_model(unwrapped_model, target="default")
|
||||
|
@ -266,7 +266,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
for j in range(len(query_batch)):
|
||||
start = len(query_batch[j]) - 1
|
||||
if attention_mask[j, 0] == 0: # offset left padding
|
||||
start += attention_mask[j, :].nonzero()[0]
|
||||
start += attention_mask[j, :].nonzero()[0].item()
|
||||
end = start + len(response_batch[j])
|
||||
|
||||
if response_masks is not None:
|
||||
|
|
|
@ -7,13 +7,9 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
|
||||
with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
|
||||
with gr.Row():
|
||||
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
|
||||
preview_count = gr.Number(interactive=False)
|
||||
|
||||
with gr.Row():
|
||||
preview_samples = gr.JSON(interactive=False)
|
||||
|
||||
close_btn = gr.Button()
|
||||
|
||||
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
|
||||
|
|
|
@ -6,10 +6,10 @@ CSS = r"""
|
|||
transform: translate(-50%, -50%); /* center horizontally */
|
||||
max-width: 1000px;
|
||||
max-height: 750px;
|
||||
overflow-y: scroll !important;
|
||||
background-color: var(--input-background-fill);
|
||||
border: 2px solid black !important;
|
||||
z-index: 1000;
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
.dark .modal-box {
|
||||
|
|
|
@ -14,7 +14,7 @@ from llmtuner.webui.css import CSS
|
|||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
require_version("gradio==3.38.0", "To fix: pip install gradio==3.38.0")
|
||||
require_version("gradio==3.50.2", "To fix: pip install gradio==3.50.2")
|
||||
|
||||
|
||||
def create_ui() -> gr.Blocks:
|
||||
|
|
Loading…
Reference in New Issue