From 84af10cec912309b82fc16e20879eec1c2e3729a Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 1 Nov 2023 23:02:16 +0800 Subject: [PATCH] update gradio, support multiple resp in api --- requirements.txt | 2 +- src/api_demo.py | 2 +- src/llmtuner/api/app.py | 36 ++++++++++++++++++--------- src/llmtuner/api/protocol.py | 3 --- src/llmtuner/chat/stream_chat.py | 22 ++++++++++------ src/llmtuner/tuner/core/loader.py | 7 +++--- src/llmtuner/tuner/ppo/trainer.py | 10 ++++---- src/llmtuner/webui/components/data.py | 10 +++----- src/llmtuner/webui/css.py | 2 +- src/llmtuner/webui/interface.py | 2 +- 10 files changed, 54 insertions(+), 42 deletions(-) diff --git a/requirements.txt b/requirements.txt index 61d5d279..587678a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ fire jieba rouge-chinese nltk -gradio==3.38.0 +gradio==3.50.2 uvicorn pydantic==1.10.11 fastapi==0.95.1 diff --git a/src/api_demo.py b/src/api_demo.py index 777f9dcf..720089fd 100644 --- a/src/api_demo.py +++ b/src/api_demo.py @@ -6,8 +6,8 @@ from llmtuner import ChatModel, create_app def main(): chat_model = ChatModel() app = create_app(chat_model) - uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) print("Visit http://localhost:8000/docs for API document.") + uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) if __name__ == "__main__": diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index fdde591c..b60333d5 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -1,9 +1,11 @@ +import json import uvicorn -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, status from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager from sse_starlette import EventSourceResponse from typing import List, Tuple +from pydantic import BaseModel from llmtuner.extras.misc import torch_gc from llmtuner.chat import ChatModel @@ -29,6 +31,13 @@ async def lifespan(app: FastAPI): # collects GPU memory torch_gc() +def to_json(data: BaseModel) -> str: + try: + return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) + except: + return data.json(exclude_unset=True, ensure_ascii=False) + + def create_app(chat_model: ChatModel) -> FastAPI: app = FastAPI(lifespan=lifespan) @@ -45,10 +54,10 @@ def create_app(chat_model: ChatModel) -> FastAPI: model_card = ModelCard(id="gpt-3.5-turbo") return ModelList(data=[model_card]) - @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) + @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK) async def create_chat_completion(request: ChatCompletionRequest): if len(request.messages) < 1 or request.messages[-1].role != Role.USER: - raise HTTPException(status_code=400, detail="Invalid request") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") query = request.messages[-1].content prev_messages = request.messages[:-1] @@ -62,6 +71,8 @@ def create_app(chat_model: ChatModel) -> FastAPI: for i in range(0, len(prev_messages), 2): if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT: history.append([prev_messages[i].content, prev_messages[i+1].content]) + else: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") if request.stream: generate = predict(query, history, system, request) @@ -72,7 +83,8 @@ def create_app(chat_model: ChatModel) -> FastAPI: do_sample=request.do_sample, 
temperature=request.temperature, top_p=request.top_p, - max_new_tokens=request.max_tokens + max_new_tokens=request.max_tokens, + num_return_sequences=request.n ) usage = ChatCompletionResponseUsage( @@ -81,13 +93,13 @@ def create_app(chat_model: ChatModel) -> FastAPI: total_tokens=prompt_length+response_length ) - choice_data = ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role=Role.ASSISTANT, content=response), + choices = [ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role=Role.ASSISTANT, content=choice), finish_reason=Finish.STOP - ) + ) for i, choice in enumerate(response)] - return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage) + return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest): choice_data = ChatCompletionResponseStreamChoice( @@ -96,7 +108,7 @@ def create_app(chat_model: ChatModel) -> FastAPI: finish_reason=None ) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) - yield chunk.json(exclude_unset=True, ensure_ascii=False) + yield to_json(chunk) for new_text in chat_model.stream_chat( query, history, system, @@ -114,7 +126,7 @@ def create_app(chat_model: ChatModel) -> FastAPI: finish_reason=None ) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) - yield chunk.json(exclude_unset=True, ensure_ascii=False) + yield to_json(chunk) choice_data = ChatCompletionResponseStreamChoice( index=0, @@ -122,7 +134,7 @@ def create_app(chat_model: ChatModel) -> FastAPI: finish_reason=Finish.STOP ) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) - yield chunk.json(exclude_unset=True, ensure_ascii=False) + yield to_json(chunk) yield "[DONE]" return app diff --git a/src/llmtuner/api/protocol.py b/src/llmtuner/api/protocol.py index 1412af5f..6b99da40 100644 --- a/src/llmtuner/api/protocol.py +++ b/src/llmtuner/api/protocol.py @@ -20,9 +20,6 @@ class ModelCard(BaseModel): object: Optional[str] = "model" created: Optional[int] = Field(default_factory=lambda: int(time.time())) owned_by: Optional[str] = "owner" - root: Optional[str] = None - parent: Optional[str] = None - permission: Optional[list] = [] class ModelList(BaseModel): diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index af785dd3..938e9bcc 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -26,17 +26,17 @@ class ChatModel: **input_kwargs ) -> Tuple[Dict[str, Any], int]: system = system or self.system_prompt - prompt, _ = self.template.encode_oneturn( tokenizer=self.tokenizer, query=query, resp="", history=history, system=system ) + prompt_length = len(prompt) input_ids = torch.tensor([prompt], device=self.model.device) - prompt_length = len(input_ids[0]) do_sample = input_kwargs.pop("do_sample", None) temperature = input_kwargs.pop("temperature", None) top_p = input_kwargs.pop("top_p", None) top_k = input_kwargs.pop("top_k", None) + num_return_sequences = input_kwargs.pop("num_return_sequences", None) repetition_penalty = input_kwargs.pop("repetition_penalty", None) max_length = input_kwargs.pop("max_length", None) max_new_tokens = input_kwargs.pop("max_new_tokens", None) @@ -47,11 +47,15 @@ class ChatModel: temperature=temperature or generating_args["temperature"], top_p=top_p or generating_args["top_p"], top_k=top_k or generating_args["top_k"], + 
num_return_sequences=num_return_sequences or 1, repetition_penalty=repetition_penalty or generating_args["repetition_penalty"], eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, pad_token_id=self.tokenizer.pad_token_id )) + if int(num_return_sequences) > 1: + generating_args["do_sample"] = True + if max_length: generating_args.pop("max_new_tokens", None) generating_args["max_length"] = max_length @@ -75,12 +79,16 @@ class ChatModel: history: Optional[List[Tuple[str, str]]] = None, system: Optional[str] = None, **input_kwargs - ) -> Tuple[str, Tuple[int, int]]: + ) -> Tuple[List[str], Tuple[int, int]]: gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs) - generation_output = self.model.generate(**gen_kwargs) - outputs = generation_output.tolist()[0][prompt_length:] - response = self.tokenizer.decode(outputs, skip_special_tokens=True) - response_length = len(outputs) + generate_output = self.model.generate(**gen_kwargs) + response_ids = generate_output[:, prompt_length:] + response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) + response_length = 0 + for i in range(len(response_ids)): + eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero() + response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i]) + return response, (prompt_length, response_length) @torch.inference_mode() diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index a590c94b..6389c705 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -83,7 +83,7 @@ def load_model_and_tokenizer( config = AutoConfig.from_pretrained(model_to_load, **config_kwargs) - # Fix tokenizer (for ChatGLM2) + # Fix tokenizer (for ChatGLM2 and ChatGLM3) if getattr(config, "model_type", None) == "chatglm": tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) @@ -193,7 +193,7 @@ def load_model_and_tokenizer( if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__): model.generate = MethodType(PreTrainedModel.generate, model) - # Fix LM head (for ChatGLM2) + # Fix LM head (for ChatGLM2 and ChatGLM3) if getattr(config, "model_type", None) == "chatglm": setattr(model, "lm_head", model.transformer.output_layer) @@ -212,8 +212,7 @@ def load_model_and_tokenizer( # Prepare model with valuehead for RLHF if stage == "rm" or stage == "ppo": - model = AutoModelForCausalLMWithValueHead.from_pretrained(model) - model._keys_to_ignore_on_save = None + model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model) reset_logging() if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model logger.warning("Only the last checkpoint containing valuehead will be loaded.") diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 657d658c..baa36404 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -180,15 +180,15 @@ class CustomPPOTrainer(PPOTrainer, Trainer): query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu() queries, responses = [], [] for i in range(len(query)): - query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0] + query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item() response_index = (response[i] != 
self.tokenizer.pad_token_id).nonzero() if len(response_index) == 0: response_length = 1 # allow empty response elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: - response_length = response_index[-1] + 2 # save the EOS token + response_length = response_index[-1].item() + 2 # save the EOS token else: - response_length = response_index[-1] + 1 + response_length = response_index[-1].item() + 1 queries.append(query[i, query_length:]) # remove padding from left responses.append(response[i, :response_length]) # remove padding from right @@ -216,7 +216,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): rewards = [] for i in range(values.size(0)): - end_index = batch["attention_mask"][i].nonzero()[-1] # use the score on the EOS token + end_index = batch["attention_mask"][i].nonzero()[-1].item() # use the score on the EOS token rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type replace_model(unwrapped_model, target="default") @@ -266,7 +266,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): for j in range(len(query_batch)): start = len(query_batch[j]) - 1 if attention_mask[j, 0] == 0: # offset left padding - start += attention_mask[j, :].nonzero()[0] + start += attention_mask[j, :].nonzero()[0].item() end = start + len(response_batch[j]) if response_masks is not None: diff --git a/src/llmtuner/webui/components/data.py b/src/llmtuner/webui/components/data.py index af19cc41..745f7648 100644 --- a/src/llmtuner/webui/components/data.py +++ b/src/llmtuner/webui/components/data.py @@ -7,13 +7,9 @@ if TYPE_CHECKING: def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]: - with gr.Box(visible=False, elem_classes="modal-box") as preview_box: - with gr.Row(): - preview_count = gr.Number(interactive=False) - - with gr.Row(): - preview_samples = gr.JSON(interactive=False) - + with gr.Column(visible=False, elem_classes="modal-box") as preview_box: + preview_count = gr.Number(interactive=False) + preview_samples = gr.JSON(interactive=False) close_btn = gr.Button() close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False) diff --git a/src/llmtuner/webui/css.py b/src/llmtuner/webui/css.py index 5d370c1f..6dab6ffa 100644 --- a/src/llmtuner/webui/css.py +++ b/src/llmtuner/webui/css.py @@ -6,10 +6,10 @@ CSS = r""" transform: translate(-50%, -50%); /* center horizontally */ max-width: 1000px; max-height: 750px; - overflow-y: scroll !important; background-color: var(--input-background-fill); border: 2px solid black !important; z-index: 1000; + padding: 10px; } .dark .modal-box { diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py index eaf9a9fd..b9292891 100644 --- a/src/llmtuner/webui/interface.py +++ b/src/llmtuner/webui/interface.py @@ -14,7 +14,7 @@ from llmtuner.webui.css import CSS from llmtuner.webui.engine import Engine -require_version("gradio==3.38.0", "To fix: pip install gradio==3.38.0") +require_version("gradio==3.50.2", "To fix: pip install gradio==3.50.2") def create_ui() -> gr.Blocks:
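
Usage note: with `num_return_sequences=request.n` forwarded into `chat_model.chat(...)`, a single /v1/chat/completions request can now return several sampled completions, one per entry in `choices`. A minimal client sketch against the demo server started by src/api_demo.py follows; the host/port come from api_demo.py, and treating "n" and the other OpenAI-style field names as accepted request fields is an assumption based on how the handler reads `request.n` above.

    import requests

    # Hypothetical client call; assumes the demo server from src/api_demo.py
    # is listening on http://localhost:8000.
    payload = {
        "model": "gpt-3.5-turbo",  # the only model card advertised by /v1/models
        "messages": [{"role": "user", "content": "Give me two different greetings."}],
        "n": 2,            # two sampled completions; the server forces do_sample=True when n > 1
        "stream": False,   # multiple responses are only produced on the non-streaming branch
    }
    resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
    for choice in resp.json()["choices"]:
        print(choice["index"], choice["message"]["content"])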
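The return type of ChatModel.chat() changes from a single string to a list of strings, and the reported response_length is now summed over all returned sequences, counting tokens up to each sequence's first EOS. A sketch of calling it directly with more than one sequence; constructing ChatModel() with no arguments (i.e. letting it read the usual CLI arguments such as --model_name_or_path) is an assumption, only the chat() call shape comes from this patch.

    from llmtuner import ChatModel

    chat_model = ChatModel()  # assumed: model and template supplied via the usual CLI args
    responses, (prompt_length, response_length) = chat_model.chat(
        "Write a haiku about autumn.",
        history=[],
        num_return_sequences=2,  # more than one sequence switches sampling on (see stream_chat.py above)
    )
    for text in responses:       # responses is now a List[str]
        print(text)
    print("prompt tokens:", prompt_length, "response tokens (summed):", response_length)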
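In the PPO trainer, the indices produced by tensor.nonzero() are now converted with .item() so that the subsequent length arithmetic and slicing offsets operate on plain Python ints instead of one-element tensors. A small standalone illustration of the difference, using toy values rather than code from the patch:

    import torch

    response = torch.tensor([0, 0, 42, 43, 2])   # toy token ids; 0 stands in for pad_token_id
    response_index = (response != 0).nonzero()   # tensor([[2], [3], [4]])

    as_tensor = response_index[-1] + 1           # tensor([5]) -- still a one-element tensor
    as_int = response_index[-1].item() + 1       # 5 -- a plain int, as the trainer now uses
    print(as_tensor, as_int)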