fix api
commit 55f707196e (parent a9c18255aa)
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -29,7 +29,7 @@ def main():
             break

         if query.strip() == "clear":
-            history = []
+            messages = []
             torch_gc()
             print("History has been removed.")
             continue
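Note: the CLI demo switches from the pair-based `history` list to the same OpenAI-style `messages` list that the API below now passes around, so clearing the conversation resets `messages`.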
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -2,7 +2,7 @@ import asyncio
 import json
 import os
 from contextlib import asynccontextmanager
-from typing import List, Tuple
+from typing import Any, Dict, Sequence

 from pydantic import BaseModel

@@ -46,10 +46,17 @@ async def lifespan(app: "FastAPI"):  # collects GPU memory
     torch_gc()


-def to_json(data: BaseModel) -> str:
+def dictify(data: "BaseModel") -> Dict[str, Any]:
+    try:  # pydantic v2
+        return data.model_dump(exclude_unset=True)
+    except AttributeError:  # pydantic v1
+        return data.dict(exclude_unset=True)
+
+
+def jsonify(data: "BaseModel") -> str:
     try:  # pydantic v2
         return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
-    except Exception:  # pydantic v1
+    except AttributeError:  # pydantic v1
         return data.json(exclude_unset=True, ensure_ascii=False)


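Note: splitting the old `to_json` into `dictify` and `jsonify` gives one version shim per concern, and narrowing `except Exception` to `except AttributeError` means only the missing-method case falls back (pydantic v1 has no `model_dump`), rather than masking genuine serialization errors. A standalone sketch of the same fallback pattern, with illustrative names not taken from this commit:

from typing import Any, Dict

from pydantic import BaseModel


class Point(BaseModel):
    x: int = 0
    y: int = 0


def to_dict(model: BaseModel) -> Dict[str, Any]:
    try:  # pydantic v2 exposes model_dump()
        return model.model_dump(exclude_unset=True)
    except AttributeError:  # pydantic v1 has no model_dump(), only .dict()
        return model.dict(exclude_unset=True)


print(to_dict(Point(x=1)))  # {'x': 1} -- unset fields are dropped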
@@ -79,36 +86,40 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

-        query = request.messages[-1].content
-        prev_messages = request.messages[:-1]
-        if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
-            system = prev_messages.pop(0).content
+        messages = [dictify(message) for message in request.messages]
+        if len(messages) and messages[0]["role"] == Role.SYSTEM:
+            system = messages.pop(0)["content"]
         else:
             system = None

-        history = []
-        if len(prev_messages) % 2 == 0:
-            for i in range(0, len(prev_messages), 2):
-                if prev_messages[i].role == Role.USER and prev_messages[i + 1].role == Role.ASSISTANT:
-                    history.append([prev_messages[i].content, prev_messages[i + 1].content])
-                else:
-                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
-        else:
+        if len(messages) % 2 == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

+        for i in range(len(messages)):
+            if messages[i]["role"] == Role.USER:
+                if i % 2 == 1:
+                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+            elif messages[i]["role"] == Role.ASSISTANT:
+                if i % 2 == 0:
+                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+            else:
+                raise NotImplementedError
+
+        tools = ""  # TODO: add tools
+
         async with semaphore:
             loop = asyncio.get_running_loop()
-            return await loop.run_in_executor(None, chat_completion, query, history, system, request)
+            return await loop.run_in_executor(None, chat_completion, messages, system, tools, request)

-    def chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+    def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
         if request.stream:
-            generate = stream_chat_completion(query, history, system, request)
+            generate = stream_chat_completion(messages, system, tools, request)
             return EventSourceResponse(generate, media_type="text/event-stream")

         responses = chat_model.chat(
-            query,
-            history,
+            messages,
             system,
+            tools,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
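Note: the endpoint now validates roles positionally. After an optional leading system message is popped, the list must alternate user/assistant and end with a user turn, so its length must be odd (hence the `% 2 == 0` rejection above). A hypothetical helper, not part of this commit, capturing the same rule:

from typing import Dict, Sequence


def is_valid_history(messages: Sequence[Dict[str, str]]) -> bool:
    # After an optional leading "system" message has been popped, the
    # list must read u/a/u/a/.../u: odd length, "user" on even indices,
    # "assistant" on odd indices.
    if len(messages) % 2 == 0:
        return False
    expected = ("user", "assistant")
    return all(msg["role"] == expected[i % 2] for i, msg in enumerate(messages))


assert is_valid_history([{"role": "user", "content": "hi"}])
assert not is_valid_history(
    [{"role": "user", "content": "hi"}, {"role": "user", "content": "again"}]
)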
@@ -138,18 +149,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)

     def stream_chat_completion(
-        query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest
+        messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
     ):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield to_json(chunk)
+        yield jsonify(chunk)

         for new_text in chat_model.stream_chat(
-            query,
-            history,
+            messages,
             system,
+            tools,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
@@ -162,11 +173,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
                 index=0, delta=DeltaMessage(content=new_text), finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-            yield to_json(chunk)
+            yield jsonify(chunk)

         choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield to_json(chunk)
+        yield jsonify(chunk)
         yield "[DONE]"

     @app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
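Note: the stream path follows the OpenAI wire format: each yielded JSON chunk becomes one `data:` SSE event, terminated by a literal `[DONE]`. A rough client-side sketch; the host, port, and model name are assumptions for illustration, not from this commit:

import json

import requests  # third-party HTTP client

payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
}
with requests.post(
    "http://localhost:8000/v1/chat/completions", json=payload, stream=True
) as resp:
    for raw in resp.iter_lines():
        if not raw:
            continue  # skip SSE keep-alive blank lines
        data = raw.decode("utf-8").removeprefix("data: ")
        if data == "[DONE]":
            break
        delta = json.loads(data)["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)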
--- a/src/llmtuner/api/protocol.py
+++ b/src/llmtuner/api/protocol.py
@@ -3,6 +3,7 @@ from enum import Enum, unique
 from typing import List, Optional

 from pydantic import BaseModel, Field
+from typing_extensions import Literal


 @unique
@@ -20,14 +21,14 @@ class Finish(str, Enum):

 class ModelCard(BaseModel):
     id: str
-    object: Optional[str] = "model"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
-    owned_by: Optional[str] = "owner"
+    object: Literal["model"] = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: Literal["owner"] = "owner"


 class ModelList(BaseModel):
-    object: Optional[str] = "list"
-    data: Optional[List[ModelCard]] = []
+    object: Literal["list"] = "list"
+    data: List[ModelCard] = []


 class ChatMessage(BaseModel):
@@ -43,12 +44,12 @@ class DeltaMessage(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
-    do_sample: Optional[bool] = True
+    do_sample: bool = True
     temperature: Optional[float] = None
     top_p: Optional[float] = None
-    n: Optional[int] = 1
+    n: int = 1
     max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
+    stream: bool = False


 class ChatCompletionResponseChoice(BaseModel):
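Note: dropping `Optional` on `do_sample`, `n`, and `stream` is more than cosmetic: `Optional[bool] = True` lets a client send an explicit `null`, which the backend cannot act on, whereas a plain `bool` keeps the same implicit default but rejects `None` at validation time.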
@@ -70,18 +71,18 @@ class ChatCompletionResponseUsage(BaseModel):


 class ChatCompletionResponse(BaseModel):
-    id: Optional[str] = "chatcmpl-default"
-    object: Optional[str] = "chat.completion"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    id: Literal["chatcmpl-default"] = "chatcmpl-default"
+    object: Literal["chat.completion"] = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseChoice]
     usage: ChatCompletionResponseUsage


 class ChatCompletionStreamResponse(BaseModel):
-    id: Optional[str] = "chatcmpl-default"
-    object: Optional[str] = "chat.completion.chunk"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    id: Literal["chatcmpl-default"] = "chatcmpl-default"
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]

@@ -93,7 +94,7 @@ class ScoreEvaluationRequest(BaseModel):


 class ScoreEvaluationResponse(BaseModel):
-    id: Optional[str] = "scoreeval-default"
-    object: Optional[str] = "score.evaluation"
+    id: Literal["scoreeval-default"] = "scoreeval-default"
+    object: Literal["score.evaluation"] = "score.evaluation"
     model: str
     scores: List[float]
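Note: compared with `Optional[str]` plus a string default, a `Literal` type both documents the constant and makes pydantic reject any other value. A minimal sketch of the difference; the model names here are illustrative:

from pydantic import BaseModel
from typing_extensions import Literal


class LooseCard(BaseModel):
    object: str = "model"


class StrictCard(BaseModel):
    object: Literal["model"] = "model"


LooseCard(object="banana")       # accepted silently
try:
    StrictCard(object="banana")  # Literal rejects anything but "model"
except Exception as exc:
    print(type(exc).__name__)    # ValidationError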
--- a/src/llmtuner/data/formatter.py
+++ b/src/llmtuner/data/formatter.py
@@ -2,7 +2,7 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Set, Sequence, Tuple, Union
+from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union


 SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -144,8 +144,8 @@ class Template:
                 max_len=(cutoff_len - total_length),
                 reserved_label_len=reserved_label_len,
             )
-            encoded_messages[i] = encoded_messages[i][: max_source_len]
-            encoded_messages[i + 1] = encoded_messages[i + 1][: max_target_len]
+            encoded_messages[i] = encoded_messages[i][:max_source_len]
+            encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
             total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
             encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))

@@ -416,7 +416,7 @@ register_template(
         "by the user such as English and 中文."
     ),
     stop_words=["<|im_end|>"],
-    efficient_eos=True,
+    efficient_eos=True,  # internlm2 tokenizer cannot set eos_token_id
 )


@@ -455,9 +455,7 @@ register_template(

 register_template(
     name="openchat",
-    format_user=StringFormatter(
-        slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]
-    ),
+    format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
     format_assistant=StringFormatter(slots=["{{content}}"]),
     format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
     force_system=True,
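Note: collapsing `format_user` to one line is behavior-preserving. In this template DSL, plain-string slots have `{{content}}` substituted, while one-element sets such as `{"eos_token"}` stand for the tokenizer's special tokens (openchat's EOS is `<|end_of_turn|>`). A simplified illustration of slot rendering, not the project's actual `Formatter`:

from typing import Dict, List, Set, Union

SLOTS = List[Union[str, Set[str], Dict[str, str]]]


def render(slots: SLOTS, content: str, special_tokens: Dict[str, str]) -> str:
    pieces = []
    for slot in slots:
        if isinstance(slot, str):  # plain strings get {{content}} filled in
            pieces.append(slot.replace("{{content}}", content))
        elif isinstance(slot, set):  # e.g. {"eos_token"} -> a special token
            pieces.append(special_tokens[next(iter(slot))])
    return "".join(pieces)


slots = ["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]
print(render(slots, "Hi", {"eos_token": "<|end_of_turn|>"}))
# -> GPT4 Correct User: Hi<|end_of_turn|>GPT4 Correct Assistant: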