hiyouga 2024-01-21 00:03:09 +08:00
parent a9c18255aa
commit 55f707196e
5 changed files with 60 additions and 50 deletions

View File

@@ -29,7 +29,7 @@ def main():
             break
 
         if query.strip() == "clear":
-            history = []
+            messages = []
             torch_gc()
             print("History has been removed.")
             continue
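The clear branch now resets the local message list and then frees accelerator memory. torch_gc() is the repo's own helper; as a hedged sketch only, such a helper typically looks something like this (the actual implementation in this repo may differ):

    import gc

    import torch

    def torch_gc() -> None:
        # Release objects with no remaining references, then return
        # cached CUDA blocks to the driver so other processes can use them.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()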

View File

@ -2,7 +2,7 @@ import asyncio
import json
import os
from contextlib import asynccontextmanager
from typing import List, Tuple
from typing import Any, Dict, Sequence
from pydantic import BaseModel
@@ -46,10 +46,17 @@ async def lifespan(app: "FastAPI"): # collects GPU memory
     torch_gc()
 
 
-def to_json(data: BaseModel) -> str:
+def dictify(data: "BaseModel") -> Dict[str, Any]:
+    try: # pydantic v2
+        return data.model_dump(exclude_unset=True)
+    except AttributeError: # pydantic v1
+        return data.dict(exclude_unset=True)
+
+
+def jsonify(data: "BaseModel") -> str:
     try: # pydantic v2
         return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
-    except Exception: # pydantic v1
+    except AttributeError: # pydantic v1
         return data.json(exclude_unset=True, ensure_ascii=False)
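The old to_json caught a bare Exception; the new helpers catch only AttributeError, which is precisely what distinguishes pydantic v2 (model_dump) from v1 (dict). A quick usage sketch, with Greeting as a throwaway model introduced purely for illustration:

    from pydantic import BaseModel

    class Greeting(BaseModel):
        text: str
        mood: str = "neutral"

    g = Greeting(text="你好")
    dictify(g)  # {'text': '你好'} -- exclude_unset drops the untouched default
    jsonify(g)  # '{"text": "你好"}' -- ensure_ascii=False keeps UTF-8 intact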
@@ -79,36 +86,40 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
-        query = request.messages[-1].content
-        prev_messages = request.messages[:-1]
-        if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
-            system = prev_messages.pop(0).content
+        messages = [dictify(message) for message in request.messages]
+        if len(messages) and messages[0]["role"] == Role.SYSTEM:
+            system = messages.pop(0)["content"]
         else:
             system = None
 
-        history = []
-        if len(prev_messages) % 2 == 0:
-            for i in range(0, len(prev_messages), 2):
-                if prev_messages[i].role == Role.USER and prev_messages[i + 1].role == Role.ASSISTANT:
-                    history.append([prev_messages[i].content, prev_messages[i + 1].content])
-                else:
-                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
-        else:
+        if len(messages) % 2 == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
+        for i in range(len(messages)):
+            if messages[i]["role"] == Role.USER:
+                if i % 2 == 1:
+                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+            elif messages[i]["role"] == Role.ASSISTANT:
+                if i % 2 == 0:
+                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+            else:
+                raise NotImplementedError
+
+        tools = "" # TODO: add tools
+
         async with semaphore:
             loop = asyncio.get_running_loop()
-            return await loop.run_in_executor(None, chat_completion, query, history, system, request)
+            return await loop.run_in_executor(None, chat_completion, messages, system, tools, request)
 
-    def chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+    def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
         if request.stream:
-            generate = stream_chat_completion(query, history, system, request)
+            generate = stream_chat_completion(messages, system, tools, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
 
         responses = chat_model.chat(
-            query,
-            history,
+            messages,
             system,
+            tools,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
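The replacement check enforces an odd-length user/assistant alternation (the system message, if any, was already popped off) instead of pairing messages into a history list. The same rule as a standalone predicate, for illustration only; plain strings stand in for the Role enum values here:

    from typing import Dict, Sequence

    def is_valid_alternation(messages: Sequence[Dict[str, str]]) -> bool:
        # Odd length, "user" on even indices, "assistant" on odd indices,
        # mirroring the HTTP 400 checks above.
        if len(messages) % 2 == 0:
            return False
        expected = ("user", "assistant")
        return all(m["role"] == expected[i % 2] for i, m in enumerate(messages))

    assert is_valid_alternation([{"role": "user", "content": "hi"}])
    assert not is_valid_alternation([{"role": "assistant", "content": "hi"}])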
@@ -138,18 +149,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
 
     def stream_chat_completion(
-        query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest
+        messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
     ):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield to_json(chunk)
+        yield jsonify(chunk)
 
         for new_text in chat_model.stream_chat(
-            query,
-            history,
+            messages,
             system,
+            tools,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
@@ -162,11 +173,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
                 index=0, delta=DeltaMessage(content=new_text), finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-            yield to_json(chunk)
+            yield jsonify(chunk)
 
         choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield to_json(chunk)
+        yield jsonify(chunk)
         yield "[DONE]"
 
     @app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
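Each streamed chunk is emitted as one server-sent event, followed by a literal [DONE] sentinel. A client-side sketch, assuming the handler is mounted at the usual OpenAI-style /v1/chat/completions path (not shown in this hunk) and using requests purely for illustration:

    import requests

    with requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "default",
            "stream": True,
            "messages": [{"role": "user", "content": "hello"}],
        },
        stream=True,
    ) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data:"):
                continue  # skip keep-alives and blank event separators
            payload = line[len("data:"):].strip()
            if payload == "[DONE]":
                break
            print(payload)  # one ChatCompletionStreamResponse JSON per event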

View File

@@ -3,6 +3,7 @@
 from enum import Enum, unique
 from typing import List, Optional
 
 from pydantic import BaseModel, Field
+from typing_extensions import Literal
 
 @unique
@@ -20,14 +21,14 @@ class Finish(str, Enum):
 
 class ModelCard(BaseModel):
     id: str
-    object: Optional[str] = "model"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
-    owned_by: Optional[str] = "owner"
+    object: Literal["model"] = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: Literal["owner"] = "owner"
 
 
 class ModelList(BaseModel):
-    object: Optional[str] = "list"
-    data: Optional[List[ModelCard]] = []
+    object: Literal["list"] = "list"
+    data: List[ModelCard] = []
 
 
 class ChatMessage(BaseModel):
@@ -43,12 +44,12 @@ class DeltaMessage(BaseModel):
 
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
-    do_sample: Optional[bool] = True
+    do_sample: bool = True
     temperature: Optional[float] = None
     top_p: Optional[float] = None
-    n: Optional[int] = 1
+    n: int = 1
     max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
+    stream: bool = False
 
 
 class ChatCompletionResponseChoice(BaseModel):
@@ -70,18 +71,18 @@ class ChatCompletionResponseUsage(BaseModel):
 
 
 class ChatCompletionResponse(BaseModel):
-    id: Optional[str] = "chatcmpl-default"
-    object: Optional[str] = "chat.completion"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    id: Literal["chatcmpl-default"] = "chatcmpl-default"
+    object: Literal["chat.completion"] = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseChoice]
     usage: ChatCompletionResponseUsage
 
 
 class ChatCompletionStreamResponse(BaseModel):
-    id: Optional[str] = "chatcmpl-default"
-    object: Optional[str] = "chat.completion.chunk"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    id: Literal["chatcmpl-default"] = "chatcmpl-default"
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
@@ -93,7 +94,7 @@ class ScoreEvaluationRequest(BaseModel):
 
 class ScoreEvaluationResponse(BaseModel):
-    id: Optional[str] = "scoreeval-default"
-    object: Optional[str] = "score.evaluation"
+    id: Literal["scoreeval-default"] = "scoreeval-default"
+    object: Literal["score.evaluation"] = "score.evaluation"
     model: str
     scores: List[float]
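Replacing Optional[str] defaults with Literal[...] pins each constant field to its single legal value and drops None from the schema, so a mistyped value now fails validation instead of being silently accepted. A minimal demonstration with a throwaway model:

    from pydantic import BaseModel, ValidationError
    from typing_extensions import Literal

    class Card(BaseModel):
        object: Literal["model"] = "model"

    Card()                # ok, object == "model"
    Card(object="model")  # ok, the only accepted value
    try:
        Card(object="widget")
    except ValidationError:
        print("rejected")  # raised under both pydantic v1 and v2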

View File

@@ -2,7 +2,7 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Set, Sequence, Tuple, Union
+from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union
 
 SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
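The typing change is a pure reorder (Sequence before Set, alphabetical). For context, a SLOTS value can mix the three element kinds the alias allows; the template definitions later in this diff use the first two, and the dict form below is only my illustration of the Dict[str, str] arm:

    example_slots: SLOTS = [
        "GPT4 Correct User: {{content}}",  # plain string with a placeholder
        {"eos_token"},                     # a set naming a special token to insert
        {"token": "<|im_end|>"},           # illustrative dict form, not from this diff
    ]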

View File

@@ -144,8 +144,8 @@ class Template:
                 max_len=(cutoff_len - total_length),
                 reserved_label_len=reserved_label_len,
             )
-            encoded_messages[i] = encoded_messages[i][: max_source_len]
-            encoded_messages[i + 1] = encoded_messages[i + 1][: max_target_len]
+            encoded_messages[i] = encoded_messages[i][:max_source_len]
+            encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
             total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
             encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
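The whitespace fix aside, the surrounding loop encodes messages in (prompt, response) pairs, asks infer_max_len to split the remaining cutoff_len budget, truncates each side, and accumulates a running total. A self-contained sketch of that pairing logic, with a guessed budget split standing in for the repo's actual infer_max_len:

    from typing import List, Tuple

    def pair_and_truncate(
        encoded: List[List[int]], cutoff_len: int, reserved_label_len: int
    ) -> List[Tuple[List[int], List[int]]]:
        pairs, total_length = [], 0
        for i in range(0, len(encoded) - 1, 2):
            budget = cutoff_len - total_length
            if budget <= 0:
                break
            # Hypothetical split: reserve room for labels, give the rest
            # to the source side (the real infer_max_len may differ).
            max_target_len = max(reserved_label_len, budget // 2)
            max_source_len = budget - max_target_len
            source = encoded[i][:max_source_len]
            target = encoded[i + 1][:max_target_len]
            total_length += len(source) + len(target)
            pairs.append((source, target))
        return pairs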
@@ -416,7 +416,7 @@ register_template(
         "by the user such as English and 中文."
     ),
     stop_words=["<|im_end|>"],
-    efficient_eos=True,
+    efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
 )
@@ -455,9 +455,7 @@
 
 register_template(
     name="openchat",
-    format_user=StringFormatter(
-        slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]
-    ),
+    format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
     format_assistant=StringFormatter(slots=["{{content}}"]),
     format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
     force_system=True,
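Collapsing the StringFormatter call onto one line changes nothing semantically. Rendered with content filled in, the user slot produces a prompt shaped roughly like the following, assuming the tokenizer's eos_token is <|end_of_turn|> as in the published OpenChat models:

    GPT4 Correct User: How are you?<|end_of_turn|>GPT4 Correct Assistant: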