update gradio, support multiple resp in api

This commit is contained in:
hiyouga 2023-11-01 23:02:16 +08:00
parent d8cf8cfdeb
commit 84af10cec9
10 changed files with 54 additions and 42 deletions

View File

@ -12,7 +12,7 @@ fire
jieba
rouge-chinese
nltk
gradio==3.38.0
gradio==3.50.2
uvicorn
pydantic==1.10.11
fastapi==0.95.1

View File

@ -6,8 +6,8 @@ from llmtuner import ChatModel, create_app
def main():
chat_model = ChatModel()
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
print("Visit http://localhost:8000/docs for API document.")
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
if __name__ == "__main__":

View File

@ -1,9 +1,11 @@
import json
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from sse_starlette import EventSourceResponse
from typing import List, Tuple
from pydantic import BaseModel
from llmtuner.extras.misc import torch_gc
from llmtuner.chat import ChatModel
@ -29,6 +31,13 @@ async def lifespan(app: FastAPI): # collects GPU memory
torch_gc()
def to_json(data: BaseModel) -> str:
try:
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except:
return data.json(exclude_unset=True, ensure_ascii=False)
def create_app(chat_model: ChatModel) -> FastAPI:
app = FastAPI(lifespan=lifespan)
@ -45,10 +54,10 @@ def create_app(chat_model: ChatModel) -> FastAPI:
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
async def create_chat_completion(request: ChatCompletionRequest):
if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
raise HTTPException(status_code=400, detail="Invalid request")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
@ -62,6 +71,8 @@ def create_app(chat_model: ChatModel) -> FastAPI:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
history.append([prev_messages[i].content, prev_messages[i+1].content])
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
if request.stream:
generate = predict(query, history, system, request)
@ -72,7 +83,8 @@ def create_app(chat_model: ChatModel) -> FastAPI:
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens
max_new_tokens=request.max_tokens,
num_return_sequences=request.n
)
usage = ChatCompletionResponseUsage(
@ -81,13 +93,13 @@ def create_app(chat_model: ChatModel) -> FastAPI:
total_tokens=prompt_length+response_length
)
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role=Role.ASSISTANT, content=response),
choices = [ChatCompletionResponseChoice(
index=i,
message=ChatMessage(role=Role.ASSISTANT, content=choice),
finish_reason=Finish.STOP
)
) for i, choice in enumerate(response)]
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
choice_data = ChatCompletionResponseStreamChoice(
@ -96,7 +108,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield chunk.json(exclude_unset=True, ensure_ascii=False)
yield to_json(chunk)
for new_text in chat_model.stream_chat(
query, history, system,
@ -114,7 +126,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield chunk.json(exclude_unset=True, ensure_ascii=False)
yield to_json(chunk)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
@ -122,7 +134,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
finish_reason=Finish.STOP
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield chunk.json(exclude_unset=True, ensure_ascii=False)
yield to_json(chunk)
yield "[DONE]"
return app

View File

@ -20,9 +20,6 @@ class ModelCard(BaseModel):
object: Optional[str] = "model"
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
owned_by: Optional[str] = "owner"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = []
class ModelList(BaseModel):

View File

@ -26,17 +26,17 @@ class ChatModel:
**input_kwargs
) -> Tuple[Dict[str, Any], int]:
system = system or self.system_prompt
prompt, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
)
prompt_length = len(prompt)
input_ids = torch.tensor([prompt], device=self.model.device)
prompt_length = len(input_ids[0])
do_sample = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
@ -47,11 +47,15 @@ class ChatModel:
temperature=temperature or generating_args["temperature"],
top_p=top_p or generating_args["top_p"],
top_k=top_k or generating_args["top_k"],
num_return_sequences=num_return_sequences or 1,
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
pad_token_id=self.tokenizer.pad_token_id
))
if int(num_return_sequences) > 1:
generating_args["do_sample"] = True
if max_length:
generating_args.pop("max_new_tokens", None)
generating_args["max_length"] = max_length
@ -75,12 +79,16 @@ class ChatModel:
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
**input_kwargs
) -> Tuple[str, Tuple[int, int]]:
) -> Tuple[List[str], Tuple[int, int]]:
gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
generation_output = self.model.generate(**gen_kwargs)
outputs = generation_output.tolist()[0][prompt_length:]
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
response_length = len(outputs)
generate_output = self.model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
response_length = 0
for i in range(len(response_ids)):
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i])
return response, (prompt_length, response_length)
@torch.inference_mode()

View File

@ -83,7 +83,7 @@ def load_model_and_tokenizer(
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
# Fix tokenizer (for ChatGLM2)
# Fix tokenizer (for ChatGLM2 and ChatGLM3)
if getattr(config, "model_type", None) == "chatglm":
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
@ -193,7 +193,7 @@ def load_model_and_tokenizer(
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
# Fix LM head (for ChatGLM2)
# Fix LM head (for ChatGLM2 and ChatGLM3)
if getattr(config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
@ -212,8 +212,7 @@ def load_model_and_tokenizer(
# Prepare model with valuehead for RLHF
if stage == "rm" or stage == "ppo":
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
model._keys_to_ignore_on_save = None
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded.")

View File

@ -180,15 +180,15 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
queries, responses = [], []
for i in range(len(query)):
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
if len(response_index) == 0:
response_length = 1 # allow empty response
elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
response_length = response_index[-1] + 2 # save the EOS token
response_length = response_index[-1].item() + 2 # save the EOS token
else:
response_length = response_index[-1] + 1
response_length = response_index[-1].item() + 1
queries.append(query[i, query_length:]) # remove padding from left
responses.append(response[i, :response_length]) # remove padding from right
@ -216,7 +216,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
rewards = []
for i in range(values.size(0)):
end_index = batch["attention_mask"][i].nonzero()[-1] # use the score on the EOS token
end_index = batch["attention_mask"][i].nonzero()[-1].item() # use the score on the EOS token
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
replace_model(unwrapped_model, target="default")
@ -266,7 +266,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
for j in range(len(query_batch)):
start = len(query_batch[j]) - 1
if attention_mask[j, 0] == 0: # offset left padding
start += attention_mask[j, :].nonzero()[0]
start += attention_mask[j, :].nonzero()[0].item()
end = start + len(response_batch[j])
if response_masks is not None:

View File

@ -7,13 +7,9 @@ if TYPE_CHECKING:
def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
preview_count = gr.Number(interactive=False)
with gr.Row():
preview_samples = gr.JSON(interactive=False)
close_btn = gr.Button()
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)

View File

@ -6,10 +6,10 @@ CSS = r"""
transform: translate(-50%, -50%); /* center horizontally */
max-width: 1000px;
max-height: 750px;
overflow-y: scroll !important;
background-color: var(--input-background-fill);
border: 2px solid black !important;
z-index: 1000;
padding: 10px;
}
.dark .modal-box {

View File

@ -14,7 +14,7 @@ from llmtuner.webui.css import CSS
from llmtuner.webui.engine import Engine
require_version("gradio==3.38.0", "To fix: pip install gradio==3.38.0")
require_version("gradio==3.50.2", "To fix: pip install gradio==3.50.2")
def create_ui() -> gr.Blocks: