diff --git a/src/cli_demo.py b/src/cli_demo.py
index 96007f1a..ba828f51 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -29,7 +29,7 @@ def main():
             break
 
         if query.strip() == "clear":
-            history = []
+            messages = []
             torch_gc()
             print("History has been removed.")
             continue
diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index d50e9137..973620af 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -2,7 +2,7 @@ import asyncio
 import json
 import os
 from contextlib import asynccontextmanager
-from typing import List, Tuple
+from typing import Any, Dict, Sequence
 
 from pydantic import BaseModel
 
@@ -46,10 +46,17 @@ async def lifespan(app: "FastAPI"):  # collects GPU memory
     torch_gc()
 
 
-def to_json(data: BaseModel) -> str:
+def dictify(data: "BaseModel") -> Dict[str, Any]:
+    try:  # pydantic v2
+        return data.model_dump(exclude_unset=True)
+    except AttributeError:  # pydantic v1
+        return data.dict(exclude_unset=True)
+
+
+def jsonify(data: "BaseModel") -> str:
     try:  # pydantic v2
         return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
-    except Exception:  # pydantic v1
+    except AttributeError:  # pydantic v1
         return data.json(exclude_unset=True, ensure_ascii=False)
 
 
@@ -79,36 +86,40 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
-        query = request.messages[-1].content
-        prev_messages = request.messages[:-1]
-        if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
-            system = prev_messages.pop(0).content
+        messages = [dictify(message) for message in request.messages]
+        if len(messages) and messages[0]["role"] == Role.SYSTEM:
+            system = messages.pop(0)["content"]
         else:
             system = None
 
-        history = []
-        if len(prev_messages) % 2 == 0:
-            for i in range(0, len(prev_messages), 2):
-                if prev_messages[i].role == Role.USER and prev_messages[i + 1].role == Role.ASSISTANT:
-                    history.append([prev_messages[i].content, prev_messages[i + 1].content])
-                else:
-                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
-        else:
+        if len(messages) % 2 == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
+        for i in range(len(messages)):
+            if messages[i]["role"] == Role.USER:
+                if i % 2 == 1:
+                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+            elif messages[i]["role"] == Role.ASSISTANT:
+                if i % 2 == 0:
+                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+            else:
+                raise NotImplementedError
+
+        tools = ""  # TODO: add tools
+
         async with semaphore:
             loop = asyncio.get_running_loop()
-            return await loop.run_in_executor(None, chat_completion, query, history, system, request)
+            return await loop.run_in_executor(None, chat_completion, messages, system, tools, request)
 
-    def chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+    def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
         if request.stream:
-            generate = stream_chat_completion(query, history, system, request)
+            generate = stream_chat_completion(messages, system, tools, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
 
         responses = chat_model.chat(
-            query,
-            history,
+            messages,
             system,
+            tools,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
@@ -138,18 +149,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
 
     def stream_chat_completion(
-        query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest
+        messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
     ):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield to_json(chunk)
+        yield jsonify(chunk)
 
         for new_text in chat_model.stream_chat(
-            query,
-            history,
+            messages,
             system,
+            tools,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
@@ -162,11 +173,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
                 index=0, delta=DeltaMessage(content=new_text), finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-            yield to_json(chunk)
+            yield jsonify(chunk)
 
         choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield to_json(chunk)
+        yield jsonify(chunk)
         yield "[DONE]"
 
     @app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
diff --git a/src/llmtuner/api/protocol.py b/src/llmtuner/api/protocol.py
index 862f8b77..94c9acce 100644
--- a/src/llmtuner/api/protocol.py
+++ b/src/llmtuner/api/protocol.py
@@ -3,6 +3,7 @@ from enum import Enum, unique
 from typing import List, Optional
 
 from pydantic import BaseModel, Field
+from typing_extensions import Literal
 
 
 @unique
@@ -20,14 +21,14 @@ class Finish(str, Enum):
 
 class ModelCard(BaseModel):
     id: str
-    object: Optional[str] = "model"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
-    owned_by: Optional[str] = "owner"
+    object: Literal["model"] = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: Literal["owner"] = "owner"
 
 
 class ModelList(BaseModel):
-    object: Optional[str] = "list"
-    data: Optional[List[ModelCard]] = []
+    object: Literal["list"] = "list"
+    data: List[ModelCard] = []
 
 
 class ChatMessage(BaseModel):
@@ -43,12 +44,12 @@ class DeltaMessage(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
-    do_sample: Optional[bool] = True
+    do_sample: bool = True
     temperature: Optional[float] = None
     top_p: Optional[float] = None
-    n: Optional[int] = 1
+    n: int = 1
     max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
+    stream: bool = False
 
 
 class ChatCompletionResponseChoice(BaseModel):
@@ -70,18 +71,18 @@ class ChatCompletionResponseUsage(BaseModel):
 
 
 class ChatCompletionResponse(BaseModel):
-    id: Optional[str] = "chatcmpl-default"
-    object: Optional[str] = "chat.completion"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    id: Literal["chatcmpl-default"] = "chatcmpl-default"
+    object: Literal["chat.completion"] = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseChoice]
     usage: ChatCompletionResponseUsage
 
 
 class ChatCompletionStreamResponse(BaseModel):
-    id: Optional[str] = "chatcmpl-default"
-    object: Optional[str] = "chat.completion.chunk"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    id: Literal["chatcmpl-default"] = "chatcmpl-default"
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
 
@@ -93,7 +94,7 @@ class ScoreEvaluationRequest(BaseModel):
 
 
 class ScoreEvaluationResponse(BaseModel):
-    id: Optional[str] = "scoreeval-default"
-    object: Optional[str] = "score.evaluation"
+    id: Literal["scoreeval-default"] = "scoreeval-default"
+    object: Literal["score.evaluation"] = "score.evaluation"
     model: str
     scores: List[float]
diff --git a/src/llmtuner/data/formatter.py b/src/llmtuner/data/formatter.py
index 078539c2..934cb904 100644
--- a/src/llmtuner/data/formatter.py
+++ b/src/llmtuner/data/formatter.py
@@ -2,7 +2,7 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Set, Sequence, Tuple, Union
+from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union
 
 
 SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index a000a7f9..85516f98 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -144,8 +144,8 @@ class Template:
                 max_len=(cutoff_len - total_length),
                 reserved_label_len=reserved_label_len,
             )
-            encoded_messages[i] = encoded_messages[i][: max_source_len]
-            encoded_messages[i + 1] = encoded_messages[i + 1][: max_target_len]
+            encoded_messages[i] = encoded_messages[i][:max_source_len]
+            encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
             total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
             encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
 
@@ -416,7 +416,7 @@ register_template(
         "by the user such as English and 中文."
     ),
    stop_words=["<|im_end|>"],
-    efficient_eos=True,
+    efficient_eos=True,  # internlm2 tokenizer cannot set eos_token_id
 )
 
 
@@ -455,9 +455,7 @@
 
 register_template(
     name="openchat",
-    format_user=StringFormatter(
-        slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]
-    ),
+    format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
     format_assistant=StringFormatter(slots=["{{content}}"]),
     format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
     force_system=True,
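Note on the api/app.py hunk above: the handler now flattens request.messages into a list of dicts (via dictify) and enforces a strict system?/user/assistant/.../user ordering, instead of pairing turns into a query/history structure. Below is a minimal standalone sketch of that alternation check, kept separate from the diff; it uses plain role strings and ValueError in place of the project's Role enum and HTTPException, which are substitutions made only for illustration:

    from typing import Dict, Sequence

    USER, ASSISTANT = "user", "assistant"  # stand-ins for Role.USER / Role.ASSISTANT


    def validate_alternation(messages: Sequence[Dict[str, str]]) -> None:
        # After any leading system message has been popped, a valid conversation
        # has odd length and alternates user (even indices) / assistant (odd indices).
        if len(messages) % 2 == 0:
            raise ValueError("Only supports u/a/u/a/u...")
        for i, message in enumerate(messages):
            if message["role"] == USER and i % 2 == 1:
                raise ValueError("Invalid request")
            if message["role"] == ASSISTANT and i % 2 == 0:
                raise ValueError("Invalid request")


    validate_alternation([{"role": USER, "content": "Hello"}])  # ok: single user turn
    try:
        validate_alternation([{"role": USER, "content": "Hi"}, {"role": ASSISTANT, "content": "Hi!"}])
    except ValueError as err:
        print(err)  # even number of turns -> "Only supports u/a/u/a/u..."

The dictify/jsonify pair added in the same hunk serves a related purpose: it probes for the pydantic v2 model_dump API and falls back to the v1 dict/json methods on AttributeError, so the serialization path works under either major version.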