fix api
This commit is contained in:
parent
a9c18255aa
commit
55f707196e
|
@ -29,7 +29,7 @@ def main():
|
|||
break
|
||||
|
||||
if query.strip() == "clear":
|
||||
history = []
|
||||
messages = []
|
||||
torch_gc()
|
||||
print("History has been removed.")
|
||||
continue
|
||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
import json
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Tuple
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -46,10 +46,17 @@ async def lifespan(app: "FastAPI"): # collects GPU memory
|
|||
torch_gc()
|
||||
|
||||
|
||||
def to_json(data: BaseModel) -> str:
|
||||
def dictify(data: "BaseModel") -> Dict[str, Any]:
|
||||
try: # pydantic v2
|
||||
return data.model_dump(exclude_unset=True)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.dict(exclude_unset=True)
|
||||
|
||||
|
||||
def jsonify(data: "BaseModel") -> str:
|
||||
try: # pydantic v2
|
||||
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
||||
except Exception: # pydantic v1
|
||||
except AttributeError: # pydantic v1
|
||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
|
||||
|
@ -79,36 +86,40 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
query = request.messages[-1].content
|
||||
prev_messages = request.messages[:-1]
|
||||
if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
|
||||
system = prev_messages.pop(0).content
|
||||
messages = [dictify(message) for message in request.messages]
|
||||
if len(messages) and messages[0]["role"] == Role.SYSTEM:
|
||||
system = messages.pop(0)["content"]
|
||||
else:
|
||||
system = None
|
||||
|
||||
history = []
|
||||
if len(prev_messages) % 2 == 0:
|
||||
for i in range(0, len(prev_messages), 2):
|
||||
if prev_messages[i].role == Role.USER and prev_messages[i + 1].role == Role.ASSISTANT:
|
||||
history.append([prev_messages[i].content, prev_messages[i + 1].content])
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
else:
|
||||
if len(messages) % 2 == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
for i in range(len(messages)):
|
||||
if messages[i]["role"] == Role.USER:
|
||||
if i % 2 == 1:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
elif messages[i]["role"] == Role.ASSISTANT:
|
||||
if i % 2 == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
tools = "" # TODO: add tools
|
||||
|
||||
async with semaphore:
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, chat_completion, query, history, system, request)
|
||||
return await loop.run_in_executor(None, chat_completion, messages, system, tools, request)
|
||||
|
||||
def chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
|
||||
def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
|
||||
if request.stream:
|
||||
generate = stream_chat_completion(query, history, system, request)
|
||||
generate = stream_chat_completion(messages, system, tools, request)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
|
||||
responses = chat_model.chat(
|
||||
query,
|
||||
history,
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
|
@ -138,18 +149,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
||||
|
||||
def stream_chat_completion(
|
||||
query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest
|
||||
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
|
||||
):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield to_json(chunk)
|
||||
yield jsonify(chunk)
|
||||
|
||||
for new_text in chat_model.stream_chat(
|
||||
query,
|
||||
history,
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
|
@ -162,11 +173,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||
index=0, delta=DeltaMessage(content=new_text), finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield to_json(chunk)
|
||||
yield jsonify(chunk)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield to_json(chunk)
|
||||
yield jsonify(chunk)
|
||||
yield "[DONE]"
|
||||
|
||||
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
|
||||
|
|
|
@ -3,6 +3,7 @@ from enum import Enum, unique
|
|||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
@unique
|
||||
|
@ -20,14 +21,14 @@ class Finish(str, Enum):
|
|||
|
||||
class ModelCard(BaseModel):
|
||||
id: str
|
||||
object: Optional[str] = "model"
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: Optional[str] = "owner"
|
||||
object: Literal["model"] = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: Literal["owner"] = "owner"
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
object: Optional[str] = "list"
|
||||
data: Optional[List[ModelCard]] = []
|
||||
object: Literal["list"] = "list"
|
||||
data: List[ModelCard] = []
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
|
@ -43,12 +44,12 @@ class DeltaMessage(BaseModel):
|
|||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
do_sample: Optional[bool] = True
|
||||
do_sample: bool = True
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = 1
|
||||
n: int = 1
|
||||
max_tokens: Optional[int] = None
|
||||
stream: Optional[bool] = False
|
||||
stream: bool = False
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
|
@ -70,18 +71,18 @@ class ChatCompletionResponseUsage(BaseModel):
|
|||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: Optional[str] = "chatcmpl-default"
|
||||
object: Optional[str] = "chat.completion"
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
id: Literal["chatcmpl-default"] = "chatcmpl-default"
|
||||
object: Literal["chat.completion"] = "chat.completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseChoice]
|
||||
usage: ChatCompletionResponseUsage
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: Optional[str] = "chatcmpl-default"
|
||||
object: Optional[str] = "chat.completion.chunk"
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
id: Literal["chatcmpl-default"] = "chatcmpl-default"
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
|
||||
|
@ -93,7 +94,7 @@ class ScoreEvaluationRequest(BaseModel):
|
|||
|
||||
|
||||
class ScoreEvaluationResponse(BaseModel):
|
||||
id: Optional[str] = "scoreeval-default"
|
||||
object: Optional[str] = "score.evaluation"
|
||||
id: Literal["scoreeval-default"] = "scoreeval-default"
|
||||
object: Literal["score.evaluation"] = "score.evaluation"
|
||||
model: str
|
||||
scores: List[float]
|
||||
|
|
|
@ -2,7 +2,7 @@ import json
|
|||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Set, Sequence, Tuple, Union
|
||||
from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union
|
||||
|
||||
|
||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||
|
|
|
@ -144,8 +144,8 @@ class Template:
|
|||
max_len=(cutoff_len - total_length),
|
||||
reserved_label_len=reserved_label_len,
|
||||
)
|
||||
encoded_messages[i] = encoded_messages[i][: max_source_len]
|
||||
encoded_messages[i + 1] = encoded_messages[i + 1][: max_target_len]
|
||||
encoded_messages[i] = encoded_messages[i][:max_source_len]
|
||||
encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
|
||||
total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
|
||||
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
|
||||
|
||||
|
@ -416,7 +416,7 @@ register_template(
|
|||
"by the user such as English and 中文."
|
||||
),
|
||||
stop_words=["<|im_end|>"],
|
||||
efficient_eos=True,
|
||||
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
|
||||
)
|
||||
|
||||
|
||||
|
@ -455,9 +455,7 @@ register_template(
|
|||
|
||||
register_template(
|
||||
name="openchat",
|
||||
format_user=StringFormatter(
|
||||
slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]
|
||||
),
|
||||
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
|
|
Loading…
Reference in New Issue