hiyouga 2024-01-21 00:03:09 +08:00
parent a9c18255aa
commit 55f707196e
5 changed files with 60 additions and 50 deletions

View File

@@ -29,7 +29,7 @@ def main():
             break
 
         if query.strip() == "clear":
-            history = []
+            messages = []
            torch_gc()
            print("History has been removed.")
            continue
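For context, the rename above reflects the switch from a list of (query, response) pairs to an OpenAI-style list of role/content dicts. A minimal sketch of the resulting CLI loop, with a placeholder standing in for the actual model call (the chat_model interface mentioned in the comment is an assumption for illustration):

messages = []
while True:
    query = input("\nUser: ")
    if query.strip() == "exit":
        break

    if query.strip() == "clear":
        messages = []  # drop the whole conversation, not just the last pair
        print("History has been removed.")
        continue

    messages.append({"role": "user", "content": query})
    response = "..."  # e.g. chat_model.chat(messages) in the real script (hypothetical call)
    messages.append({"role": "assistant", "content": response})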

View File

@@ -2,7 +2,7 @@ import asyncio
 import json
 import os
 from contextlib import asynccontextmanager
-from typing import List, Tuple
+from typing import Any, Dict, Sequence
 
 from pydantic import BaseModel
@@ -46,10 +46,17 @@ async def lifespan(app: "FastAPI"):  # collects GPU memory
     torch_gc()
 
 
-def to_json(data: BaseModel) -> str:
+def dictify(data: "BaseModel") -> Dict[str, Any]:
+    try:  # pydantic v2
+        return data.model_dump(exclude_unset=True)
+    except AttributeError:  # pydantic v1
+        return data.dict(exclude_unset=True)
+
+
+def jsonify(data: "BaseModel") -> str:
     try:  # pydantic v2
         return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
-    except Exception:  # pydantic v1
+    except AttributeError:  # pydantic v1
         return data.json(exclude_unset=True, ensure_ascii=False)
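A small usage sketch of the two helpers introduced above, runnable with only pydantic installed; the ExampleMessage model is hypothetical and exists only to show that the same call path works on pydantic v1 (.dict/.json) and v2 (.model_dump):

import json
from typing import Any, Dict

from pydantic import BaseModel


class ExampleMessage(BaseModel):  # hypothetical model, for illustration only
    role: str
    content: str


def dictify(data: "BaseModel") -> Dict[str, Any]:
    try:  # pydantic v2
        return data.model_dump(exclude_unset=True)
    except AttributeError:  # pydantic v1
        return data.dict(exclude_unset=True)


def jsonify(data: "BaseModel") -> str:
    try:  # pydantic v2
        return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
    except AttributeError:  # pydantic v1
        return data.json(exclude_unset=True, ensure_ascii=False)


msg = ExampleMessage(role="user", content="你好")
print(dictify(msg))  # {'role': 'user', 'content': '你好'}
print(jsonify(msg))  # {"role": "user", "content": "你好"} (ensure_ascii=False keeps the CJK text)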
@@ -79,36 +86,40 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
-        query = request.messages[-1].content
-        prev_messages = request.messages[:-1]
-        if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
-            system = prev_messages.pop(0).content
+        messages = [dictify(message) for message in request.messages]
+        if len(messages) and messages[0]["role"] == Role.SYSTEM:
+            system = messages.pop(0)["content"]
         else:
             system = None
 
-        history = []
-        if len(prev_messages) % 2 == 0:
-            for i in range(0, len(prev_messages), 2):
-                if prev_messages[i].role == Role.USER and prev_messages[i + 1].role == Role.ASSISTANT:
-                    history.append([prev_messages[i].content, prev_messages[i + 1].content])
-                else:
-                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
-        else:
+        if len(messages) % 2 == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
+        for i in range(len(messages)):
+            if messages[i]["role"] == Role.USER:
+                if i % 2 == 1:
+                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+            elif messages[i]["role"] == Role.ASSISTANT:
+                if i % 2 == 0:
+                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+            else:
+                raise NotImplementedError
+
+        tools = ""  # TODO: add tools
+
         async with semaphore:
             loop = asyncio.get_running_loop()
-            return await loop.run_in_executor(None, chat_completion, query, history, system, request)
+            return await loop.run_in_executor(None, chat_completion, messages, system, tools, request)
 
-    def chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+    def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
         if request.stream:
-            generate = stream_chat_completion(query, history, system, request)
+            generate = stream_chat_completion(messages, system, tools, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
 
         responses = chat_model.chat(
-            query,
-            history,
+            messages,
             system,
+            tools,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
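The validation loop above enforces the expected layout: an optional leading system message (already popped off), then strictly alternating user/assistant turns that end with a user turn, so the remaining message count must be odd. A standalone sketch of the same rule, using plain dicts and ValueError in place of HTTPException:

from typing import Dict, Sequence


def check_alternation(messages: Sequence[Dict[str, str]]) -> None:
    # Assumes any leading system message has already been removed.
    if len(messages) % 2 == 0:  # must end with a user turn, so the count is odd
        raise ValueError("Only supports u/a/u/a/u...")

    for i, message in enumerate(messages):
        expected = "user" if i % 2 == 0 else "assistant"
        if message["role"] != expected:
            raise ValueError("Invalid request")


check_alternation([{"role": "user", "content": "hi"}])  # ok
check_alternation([
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
    {"role": "user", "content": "and you?"},
])  # ok
# check_alternation([{"role": "assistant", "content": "hello"}])  # would raise ValueError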
@@ -138,18 +149,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
 
     def stream_chat_completion(
-        query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest
+        messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
     ):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield to_json(chunk)
+        yield jsonify(chunk)
 
         for new_text in chat_model.stream_chat(
-            query,
-            history,
+            messages,
             system,
+            tools,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
@@ -162,11 +173,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             index=0, delta=DeltaMessage(content=new_text), finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield to_json(chunk)
+        yield jsonify(chunk)
 
         choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield to_json(chunk)
+        yield jsonify(chunk)
         yield "[DONE]"
 
     @app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)

View File

@@ -3,6 +3,7 @@ from enum import Enum, unique
 from typing import List, Optional
 
 from pydantic import BaseModel, Field
+from typing_extensions import Literal
 
 
 @unique
@@ -20,14 +21,14 @@ class Finish(str, Enum):
 class ModelCard(BaseModel):
     id: str
-    object: Optional[str] = "model"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
-    owned_by: Optional[str] = "owner"
+    object: Literal["model"] = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: Literal["owner"] = "owner"
 
 
 class ModelList(BaseModel):
-    object: Optional[str] = "list"
-    data: Optional[List[ModelCard]] = []
+    object: Literal["list"] = "list"
+    data: List[ModelCard] = []
 
 
 class ChatMessage(BaseModel):
@@ -43,12 +44,12 @@ class DeltaMessage(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
-    do_sample: Optional[bool] = True
+    do_sample: bool = True
     temperature: Optional[float] = None
     top_p: Optional[float] = None
-    n: Optional[int] = 1
+    n: int = 1
     max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
+    stream: bool = False
 
 
 class ChatCompletionResponseChoice(BaseModel):
@@ -70,18 +71,18 @@ class ChatCompletionResponseUsage(BaseModel):
 class ChatCompletionResponse(BaseModel):
-    id: Optional[str] = "chatcmpl-default"
-    object: Optional[str] = "chat.completion"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    id: Literal["chatcmpl-default"] = "chatcmpl-default"
+    object: Literal["chat.completion"] = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseChoice]
     usage: ChatCompletionResponseUsage
 
 
 class ChatCompletionStreamResponse(BaseModel):
-    id: Optional[str] = "chatcmpl-default"
-    object: Optional[str] = "chat.completion.chunk"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    id: Literal["chatcmpl-default"] = "chatcmpl-default"
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
@@ -93,7 +94,7 @@ class ScoreEvaluationRequest(BaseModel):
 class ScoreEvaluationResponse(BaseModel):
-    id: Optional[str] = "scoreeval-default"
-    object: Optional[str] = "score.evaluation"
+    id: Literal["scoreeval-default"] = "scoreeval-default"
+    object: Literal["score.evaluation"] = "score.evaluation"
     model: str
     scores: List[float]
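Replacing the Optional[str] defaults with Literal[...] pins the constant fields: pydantic now rejects any value other than the advertised default instead of silently accepting arbitrary strings. A minimal standalone sketch of the difference (the Card* models below are illustrative, not the repository's):

from typing import Optional

from pydantic import BaseModel, ValidationError
from typing_extensions import Literal


class CardBefore(BaseModel):
    object: Optional[str] = "model"  # any string (or None) is accepted


class CardAfter(BaseModel):
    object: Literal["model"] = "model"  # only the exact value "model" is accepted


print(CardBefore(object="whatever").object)  # whatever
try:
    CardAfter(object="whatever")
except ValidationError as err:
    print("rejected:", err.errors()[0]["loc"])  # ('object',)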

View File

@@ -2,7 +2,7 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Set, Sequence, Tuple, Union
+from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union
 
 
 SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]

View File

@@ -144,8 +144,8 @@ class Template:
                 max_len=(cutoff_len - total_length),
                 reserved_label_len=reserved_label_len,
             )
-            encoded_messages[i] = encoded_messages[i][: max_source_len]
-            encoded_messages[i + 1] = encoded_messages[i + 1][: max_target_len]
+            encoded_messages[i] = encoded_messages[i][:max_source_len]
+            encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
             total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
             encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
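The two slices above cap each encoded (prompt, response) pair against the remaining token budget before it is appended. A self-contained sketch of that pattern, where a naive even split stands in for the per-pair limits that the Template computes from max_len and reserved_label_len:

from typing import List, Sequence, Tuple


def truncate_pairs(
    pairs: Sequence[Tuple[List[int], List[int]]], cutoff_len: int
) -> List[Tuple[List[int], List[int]]]:
    encoded_pairs = []
    total_length = 0
    for source, target in pairs:
        if total_length >= cutoff_len:
            break
        budget = cutoff_len - total_length
        max_source_len = budget // 2            # naive split, for illustration only
        max_target_len = budget - max_source_len
        source, target = source[:max_source_len], target[:max_target_len]
        total_length += len(source) + len(target)
        encoded_pairs.append((source, target))
    return encoded_pairs


print(truncate_pairs([([1] * 10, [2] * 10), ([3] * 10, [4] * 10)], cutoff_len=16))
# [([1, 1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2, 2])]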
@@ -416,7 +416,7 @@ register_template(
         "by the user such as English and 中文."
     ),
     stop_words=["<|im_end|>"],
-    efficient_eos=True,
+    efficient_eos=True,  # internlm2 tokenizer cannot set eos_token_id
 )
@@ -455,9 +455,7 @@ register_template(
 
 register_template(
     name="openchat",
-    format_user=StringFormatter(
-        slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]
-    ),
+    format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
     format_assistant=StringFormatter(slots=["{{content}}"]),
     format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
     force_system=True,
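For clarity, the reflowed format_user line keeps the same three slots; rendered for a single user turn they produce the OpenChat-style prompt shown below. The render_slots helper and the <|end_of_turn|> eos token are assumptions for illustration, not the repository's formatter internals:

from typing import Dict, List, Set, Union

Slot = Union[str, Set[str], Dict[str, str]]


def render_slots(slots: List[Slot], content: str, eos_token: str) -> str:
    # Hypothetical renderer: string slots substitute {{content}}, and a
    # {"eos_token"} set slot is replaced by the tokenizer's eos token.
    pieces = []
    for slot in slots:
        if isinstance(slot, str):
            pieces.append(slot.replace("{{content}}", content))
        elif isinstance(slot, set) and "eos_token" in slot:
            pieces.append(eos_token)
    return "".join(pieces)


slots = ["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]
print(render_slots(slots, content="Hello!", eos_token="<|end_of_turn|>"))
# GPT4 Correct User: Hello!<|end_of_turn|>GPT4 Correct Assistant: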