update gradio, support multiple resp in api

hiyouga 2023-11-01 23:02:16 +08:00
parent d8cf8cfdeb
commit 84af10cec9
10 changed files with 54 additions and 42 deletions

View File

@@ -12,7 +12,7 @@ fire
 jieba
 rouge-chinese
 nltk
-gradio==3.38.0
+gradio==3.50.2
 uvicorn
 pydantic==1.10.11
 fastapi==0.95.1
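A quick way to confirm the new pin took effect is to check the installed distribution at runtime. A minimal sketch (mine, not part of the commit; assumes Python 3.8+ for importlib.metadata):

    # Sanity check: the installed Gradio should match the pinned 3.50.2.
    from importlib.metadata import version

    installed = version("gradio")
    if installed != "3.50.2":
        raise RuntimeError(f"Expected gradio==3.50.2, found {installed}")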

View File

@@ -6,8 +6,8 @@ from llmtuner import ChatModel, create_app
 def main():
     chat_model = ChatModel()
     app = create_app(chat_model)
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
     print("Visit http://localhost:8000/docs for API document.")
+    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
 if __name__ == "__main__":
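The two lines swap places because uvicorn.run() blocks until the server shuts down, so a hint printed after it never appeared while the server was actually serving. A standalone sketch of the corrected order (plain FastAPI and uvicorn, no project code):

    # The docs hint must precede the blocking call to be of any use.
    import uvicorn
    from fastapi import FastAPI

    app = FastAPI()

    if __name__ == "__main__":
        print("Visit http://localhost:8000/docs for API document.")  # shown immediately
        uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)       # blocks until shutdown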

View File

@@ -1,9 +1,11 @@
+import json
 import uvicorn
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, status
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
 from sse_starlette import EventSourceResponse
 from typing import List, Tuple
+from pydantic import BaseModel
 from llmtuner.extras.misc import torch_gc
 from llmtuner.chat import ChatModel
@@ -29,6 +31,13 @@ async def lifespan(app: FastAPI): # collects GPU memory
     torch_gc()
+def to_json(data: BaseModel) -> str:
+    try:
+        return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
+    except:
+        return data.json(exclude_unset=True, ensure_ascii=False)
 def create_app(chat_model: ChatModel) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
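The new to_json helper papers over the Pydantic v1/v2 split: model_dump() exists only on Pydantic v2 models, so the except branch falls back to the v1 .json() API (with the project's pydantic==1.10.11 pin, that fallback is the path actually taken). A small standalone sketch of the same idea, catching the specific error instead of everything:

    # Version-tolerant serialization for a Pydantic model (works on v1 and v2).
    import json
    from pydantic import BaseModel

    class Message(BaseModel):
        role: str
        content: str

    def to_json(data: BaseModel) -> str:
        try:
            return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)  # Pydantic v2
        except AttributeError:
            return data.json(exclude_unset=True, ensure_ascii=False)                    # Pydantic v1

    print(to_json(Message(role="user", content="hello")))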
@@ -45,10 +54,10 @@ def create_app(chat_model: ChatModel) -> FastAPI:
         model_card = ModelCard(id="gpt-3.5-turbo")
         return ModelList(data=[model_card])
-    @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+    @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
     async def create_chat_completion(request: ChatCompletionRequest):
         if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
-            raise HTTPException(status_code=400, detail="Invalid request")
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
         query = request.messages[-1].content
         prev_messages = request.messages[:-1]
@@ -62,6 +71,8 @@ def create_app(chat_model: ChatModel) -> FastAPI:
             for i in range(0, len(prev_messages), 2):
                 if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
                     history.append([prev_messages[i].content, prev_messages[i+1].content])
+                else:
+                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
         if request.stream:
             generate = predict(query, history, system, request)
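In other words, the prior turns must strictly alternate user/assistant and the final message must come from the user; a malformed history is now rejected with HTTP 400 instead of being silently skipped. Illustrative payloads (contents invented, field names from the protocol models):

    # Accepted: alternating user/assistant history, ending on a user turn.
    valid_messages = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello! How can I help?"},
        {"role": "user", "content": "Tell me a joke."},
    ]

    # Rejected with 400: two consecutive user turns break the u/a/u/a pattern.
    invalid_messages = [
        {"role": "user", "content": "Hi"},
        {"role": "user", "content": "Are you there?"},
        {"role": "user", "content": "Hello?"},
    ]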
@@ -72,7 +83,8 @@ def create_app(chat_model: ChatModel) -> FastAPI:
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
-            max_new_tokens=request.max_tokens
+            max_new_tokens=request.max_tokens,
+            num_return_sequences=request.n
         )
         usage = ChatCompletionResponseUsage(
@@ -81,13 +93,13 @@ def create_app(chat_model: ChatModel) -> FastAPI:
             total_tokens=prompt_length+response_length
         )
-        choice_data = ChatCompletionResponseChoice(
-            index=0,
-            message=ChatMessage(role=Role.ASSISTANT, content=response),
+        choices = [ChatCompletionResponseChoice(
+            index=i,
+            message=ChatMessage(role=Role.ASSISTANT, content=choice),
             finish_reason=Finish.STOP
-        )
-        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
+        ) for i, choice in enumerate(response)]
+        return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
     async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
@@ -96,7 +108,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
             finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield chunk.json(exclude_unset=True, ensure_ascii=False)
+        yield to_json(chunk)
         for new_text in chat_model.stream_chat(
             query, history, system,
@@ -114,7 +126,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
             finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield chunk.json(exclude_unset=True, ensure_ascii=False)
+        yield to_json(chunk)
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
@@ -122,7 +134,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
             finish_reason=Finish.STOP
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield chunk.json(exclude_unset=True, ensure_ascii=False)
+        yield to_json(chunk)
         yield "[DONE]"
     return app
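With num_return_sequences wired to request.n, one non-streaming call can now return several choices, mirroring the OpenAI n parameter (the streaming path still emits a single choice). A hedged client sketch; the server address, prompt, and timeout are placeholders, and the field names (messages, n, do_sample, max_tokens) come from the request model used above:

    # Hypothetical client call against a locally running api_demo server.
    import requests

    payload = {
        "model": "gpt-3.5-turbo",   # the placeholder id returned by GET /v1/models
        "messages": [{"role": "user", "content": "Write a haiku about autumn."}],
        "n": 3,                     # ask for three completions in one request
        "do_sample": True,
        "max_tokens": 128,
    }
    resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=120)
    resp.raise_for_status()
    for choice in resp.json()["choices"]:
        print(choice["index"], choice["message"]["content"])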

View File

@@ -20,9 +20,6 @@ class ModelCard(BaseModel):
     object: Optional[str] = "model"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     owned_by: Optional[str] = "owner"
-    root: Optional[str] = None
-    parent: Optional[str] = None
-    permission: Optional[list] = []
 class ModelList(BaseModel):
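After the removal, the card returned by GET /v1/models carries only the fields OpenAI-style clients commonly read. Roughly the resulting payload shape (values invented; the list wrapper's object field is assumed to follow the usual OpenAI "list" convention):

    # Illustrative /v1/models response body once the unused fields are gone.
    model_list = {
        "object": "list",
        "data": [
            {
                "id": "gpt-3.5-turbo",
                "object": "model",
                "created": 1698848536,   # example Unix timestamp
                "owned_by": "owner",
            }
        ],
    }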

View File

@@ -26,17 +26,17 @@ class ChatModel:
         **input_kwargs
     ) -> Tuple[Dict[str, Any], int]:
         system = system or self.system_prompt
         prompt, _ = self.template.encode_oneturn(
             tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
         )
-        prompt_length = len(prompt)
         input_ids = torch.tensor([prompt], device=self.model.device)
+        prompt_length = len(input_ids[0])
         do_sample = input_kwargs.pop("do_sample", None)
         temperature = input_kwargs.pop("temperature", None)
         top_p = input_kwargs.pop("top_p", None)
         top_k = input_kwargs.pop("top_k", None)
+        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
         repetition_penalty = input_kwargs.pop("repetition_penalty", None)
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
@@ -47,11 +47,15 @@ class ChatModel:
             temperature=temperature or generating_args["temperature"],
             top_p=top_p or generating_args["top_p"],
             top_k=top_k or generating_args["top_k"],
+            num_return_sequences=num_return_sequences or 1,
             repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
             eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             pad_token_id=self.tokenizer.pad_token_id
         ))
+        if int(num_return_sequences) > 1:
+            generating_args["do_sample"] = True
         if max_length:
             generating_args.pop("max_new_tokens", None)
             generating_args["max_length"] = max_length
@@ -75,12 +79,16 @@ class ChatModel:
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
         **input_kwargs
-    ) -> Tuple[str, Tuple[int, int]]:
+    ) -> Tuple[List[str], Tuple[int, int]]:
         gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
-        generation_output = self.model.generate(**gen_kwargs)
-        outputs = generation_output.tolist()[0][prompt_length:]
-        response = self.tokenizer.decode(outputs, skip_special_tokens=True)
-        response_length = len(outputs)
+        generate_output = self.model.generate(**gen_kwargs)
+        response_ids = generate_output[:, prompt_length:]
+        response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        response_length = 0
+        for i in range(len(response_ids)):
+            eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
+            response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i])
         return response, (prompt_length, response_length)
     @torch.inference_mode()
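Since chat() now batch-decodes every generated sequence, its return type changes from one string to a list of strings, and response_length sums each sequence's tokens up to its first EOS. A hedged usage sketch (ChatModel() reads its model arguments from the command line, as in the project's CLI examples):

    # Hypothetical caller of the updated ChatModel.chat() API.
    from llmtuner import ChatModel

    chat_model = ChatModel()  # e.g. launched with --model_name_or_path ... --template ...
    responses, (prompt_length, response_length) = chat_model.chat(
        "Give me three different taglines for a coffee shop.",
        history=[],
        num_return_sequences=3,   # forwarded to generate(); sampling is forced on
    )
    for idx, text in enumerate(responses):
        print(idx, text)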

View File

@@ -83,7 +83,7 @@ def load_model_and_tokenizer(
     config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
-    # Fix tokenizer (for ChatGLM2)
+    # Fix tokenizer (for ChatGLM2 and ChatGLM3)
     if getattr(config, "model_type", None) == "chatglm":
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
@@ -193,7 +193,7 @@ def load_model_and_tokenizer(
     if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
-    # Fix LM head (for ChatGLM2)
+    # Fix LM head (for ChatGLM2 and ChatGLM3)
     if getattr(config, "model_type", None) == "chatglm":
         setattr(model, "lm_head", model.transformer.output_layer)
@@ -212,8 +212,7 @@ def load_model_and_tokenizer(
     # Prepare model with valuehead for RLHF
     if stage == "rm" or stage == "ppo":
-        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-        model._keys_to_ignore_on_save = None
+        model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         reset_logging()
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
             logger.warning("Only the last checkpoint containing valuehead will be loaded.")

View File

@@ -180,15 +180,15 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
         queries, responses = [], []
         for i in range(len(query)):
-            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
+            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
             response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
             if len(response_index) == 0:
                 response_length = 1 # allow empty response
             elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
-                response_length = response_index[-1] + 2 # save the EOS token
+                response_length = response_index[-1].item() + 2 # save the EOS token
             else:
-                response_length = response_index[-1] + 1
+                response_length = response_index[-1].item() + 1
             queries.append(query[i, query_length:]) # remove padding from left
             responses.append(response[i, :response_length]) # remove padding from right
@@ -216,7 +216,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         rewards = []
         for i in range(values.size(0)):
-            end_index = batch["attention_mask"][i].nonzero()[-1] # use the score on the EOS token
+            end_index = batch["attention_mask"][i].nonzero()[-1].item() # use the score on the EOS token
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
         replace_model(unwrapped_model, target="default")
@@ -266,7 +266,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         for j in range(len(query_batch)):
             start = len(query_batch[j]) - 1
             if attention_mask[j, 0] == 0: # offset left padding
-                start += attention_mask[j, :].nonzero()[0]
+                start += attention_mask[j, :].nonzero()[0].item()
             end = start + len(response_batch[j])
             if response_masks is not None:
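All of the added .item() calls do the same thing: nonzero() returns index tensors, and converting a one-element result to a plain Python int keeps the later arithmetic and slicing in ordinary integer land. A tiny standalone illustration:

    # Why .item() follows nonzero() in the trainer code above.
    import torch

    attention_mask = torch.tensor([0, 0, 1, 1, 1])

    first = attention_mask.nonzero()[0]              # tensor([2]) -- still a tensor
    first_int = attention_mask.nonzero()[0].item()   # 2 -- a plain Python int

    print(type(first), type(first_int))              # <class 'torch.Tensor'> <class 'int'>
    print(attention_mask[first_int:])                # tensor([1, 1, 1])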

View File

@@ -7,13 +7,9 @@ if TYPE_CHECKING:
 def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
-    with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
-        with gr.Row():
-            preview_count = gr.Number(interactive=False)
-        with gr.Row():
-            preview_samples = gr.JSON(interactive=False)
+    with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
+        preview_count = gr.Number(interactive=False)
+        preview_samples = gr.JSON(interactive=False)
         close_btn = gr.Button()
     close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
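With the Gradio upgrade, the preview modal is rebuilt on a hidden gr.Column (gr.Box is deprecated in newer Gradio releases), styled by the modal-box CSS class and toggled with gr.update(visible=...). A self-contained sketch of the same pattern (labels and layout invented for illustration, Gradio 3.50.x API):

    # A show/hide "modal" built from a hidden Column.
    import gradio as gr

    with gr.Blocks() as demo:
        open_btn = gr.Button("Preview dataset")
        with gr.Column(visible=False, elem_classes="modal-box") as modal:
            count = gr.Number(label="Count", interactive=False)
            samples = gr.JSON(label="Samples")
            close_btn = gr.Button("Close")

        open_btn.click(lambda: gr.update(visible=True), outputs=[modal], queue=False)
        close_btn.click(lambda: gr.update(visible=False), outputs=[modal], queue=False)

    # demo.launch()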

View File

@@ -6,10 +6,10 @@ CSS = r"""
   transform: translate(-50%, -50%); /* center horizontally */
   max-width: 1000px;
   max-height: 750px;
+  overflow-y: scroll !important;
   background-color: var(--input-background-fill);
   border: 2px solid black !important;
   z-index: 1000;
+  padding: 10px;
 }
 .dark .modal-box {

View File

@@ -14,7 +14,7 @@ from llmtuner.webui.css import CSS
 from llmtuner.webui.engine import Engine
-require_version("gradio==3.38.0", "To fix: pip install gradio==3.38.0")
+require_version("gradio==3.50.2", "To fix: pip install gradio==3.50.2")
 def create_ui() -> gr.Blocks: