finish agent

hiyouga 2024-01-21 01:47:33 +08:00
parent 55f707196e
commit 3e982cc714
8 changed files with 105 additions and 41 deletions

View File

@@ -7,18 +7,20 @@ from typing import Any, Dict, Sequence
 from pydantic import BaseModel
 from ..chat import ChatModel
+from ..data import Role as DataRole
 from ..extras.misc import torch_gc
 from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
 from .protocol import (
+    ChatCompletionMessage,
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionResponseUsage,
     ChatCompletionStreamResponse,
-    ChatMessage,
-    DeltaMessage,
     Finish,
+    Function,
+    FunctionCall,
     ModelCard,
     ModelList,
     Role,
@@ -84,7 +86,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
        if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
-           raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+           raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
        messages = [dictify(message) for message in request.messages]
        if len(messages) and messages[0]["role"] == Role.SYSTEM:
@@ -96,16 +98,21 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
        for i in range(len(messages)):
-           if messages[i]["role"] == Role.USER:
-               if i % 2 == 1:
-                   raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
-           elif messages[i]["role"] == Role.ASSISTANT:
-               if i % 2 == 0:
-                   raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
-           else:
-               raise NotImplementedError
+           if i % 2 == 0 and messages[i]["role"] not in [Role.USER, Role.TOOL]:
+               raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+           elif i % 2 == 1 and messages[i]["role"] not in [Role.ASSISTANT, Role.FUNCTION]:
+               raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+           elif messages[i]["role"] == Role.TOOL:
+               messages[i]["role"] = DataRole.OBSERVATION
-       tools = ""  # TODO: add tools
+       tool_list = request.tools
+       if len(tool_list):
+           try:
+               tools = json.dumps([tool_list[0]["function"]], ensure_ascii=False)
+           except Exception:
+               raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
+       else:
+           tools = ""
        async with semaphore:
            loop = asyncio.get_running_loop()
@@ -130,12 +137,24 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
        prompt_length, response_length = 0, 0
        choices = []
        for i, response in enumerate(responses):
-           choices.append(
-               ChatCompletionResponseChoice(
-                   index=i,
-                   message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
-                   finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH,
-               )
-           )
+           if tools:
+               result = chat_model.template.format_tools.extract(response.response_text)
+           else:
+               result = response.response_text
+           if isinstance(result, tuple):
+               name, arguments = result
+               function = Function(name=name, arguments=arguments)
+               response_message = ChatCompletionMessage(
+                   role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
+               )
+               finish_reason = Finish.TOOL
+           else:
+               response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
+               finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
+           choices.append(
+               ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
+           )
            prompt_length = response.prompt_length
            response_length += response.response_length
@@ -152,7 +171,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
        messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
    ):
        choice_data = ChatCompletionResponseStreamChoice(
-           index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
+           index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
        )
        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
        yield jsonify(chunk)
@@ -170,12 +189,14 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
                continue
            choice_data = ChatCompletionResponseStreamChoice(
-               index=0, delta=DeltaMessage(content=new_text), finish_reason=None
+               index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None
            )
            chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
            yield jsonify(chunk)
-       choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
+       choice_data = ChatCompletionResponseStreamChoice(
+           index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
+       )
        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
        yield jsonify(chunk)
        yield "[DONE]"

View File

@@ -11,12 +11,15 @@ class Role(str, Enum):
     USER = "user"
     ASSISTANT = "assistant"
     SYSTEM = "system"
+    FUNCTION = "function"
+    TOOL = "tool"

 @unique
 class Finish(str, Enum):
     STOP = "stop"
     LENGTH = "length"
+    TOOL = "tool_calls"
class ModelCard(BaseModel):
@@ -31,19 +34,32 @@ class ModelList(BaseModel):
     data: List[ModelCard] = []

+class Function(BaseModel):
+    name: str
+    arguments: str

+class FunctionCall(BaseModel):
+    id: Literal["call_default"] = "call_default"
+    type: Literal["function"] = "function"
+    function: Function

 class ChatMessage(BaseModel):
     role: Role
     content: str

-class DeltaMessage(BaseModel):
+class ChatCompletionMessage(BaseModel):
     role: Optional[Role] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[FunctionCall]] = None

 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
+    tools: Optional[list] = []
     do_sample: bool = True
     temperature: Optional[float] = None
     top_p: Optional[float] = None
@@ -54,13 +70,13 @@ class ChatCompletionRequest(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
-    message: ChatMessage
+    message: ChatCompletionMessage
     finish_reason: Finish

 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
-    delta: DeltaMessage
+    delta: ChatCompletionMessage
     finish_reason: Optional[Finish] = None
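
The new models nest as Function inside FunctionCall inside ChatCompletionMessage. A minimal sketch of composing and serializing one, assuming the src/llmtuner package layout and pydantic v2 (on pydantic v1, replace model_dump(...) with .dict(...)):

import json
from llmtuner.api.protocol import ChatCompletionMessage, Function, FunctionCall, Role

function = Function(name="get_weather", arguments=json.dumps({"city": "Beijing"}))
message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)])

# Unset fields such as content are dropped; id and type fall back to their Literal defaults.
print(json.dumps(message.model_dump(mode="json", exclude_none=True), ensure_ascii=False))

The arguments field stays a JSON-encoded string, matching the OpenAI wire format, so clients must json.loads it themselves.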

View File

@@ -37,9 +37,9 @@ class ChatModel:
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> Tuple[Dict[str, Any], int]:
-       new_messages = messages + [{"role": "assistant", "content": ""}]
+       paired_messages = messages + [{"role": "assistant", "content": ""}]
        prompt, _ = self.template.encode_oneturn(
-           tokenizer=self.tokenizer, messages=new_messages, system=system, tools=tools
+           tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
        )
        prompt_length = len(prompt)
        input_ids = torch.tensor([prompt], device=self.model.device)
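
The rename documents what the extra element is for: templates encode (user, assistant) pairs, so a blank assistant turn is appended to mark where generation begins. A sketch with a hypothetical one-turn history:

messages = [{"role": "user", "content": "Ping"}]
# pairs the trailing user turn with an empty reply; encode_oneturn() splits the
# encoding into prompt ids and response ids at exactly this boundary
paired_messages = messages + [{"role": "assistant", "content": ""}]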

View File

@@ -74,6 +74,9 @@ class Formatter(ABC):
     def apply(self, **kwargs) -> SLOTS:
         ...

+    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
+        raise NotImplementedError

 @dataclass
 class EmptyFormatter(Formatter):
View File

@@ -1,9 +1,11 @@
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
+import json
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple

 import gradio as gr
 from gradio.components import Component  # cannot use TYPE_CHECKING here

 from ..chat import ChatModel
+from ..data import Role
 from ..extras.misc import torch_gc
 from ..hparams import GeneratingArguments
 from .common import get_save_dir
@@ -105,22 +107,37 @@ class WebChatModel(ChatModel):
        self,
        chatbot: List[Tuple[str, str]],
        query: str,
-       history: List[Tuple[str, str]],
+       messages: Sequence[Tuple[str, str]],
        system: str,
        tools: str,
        max_new_tokens: int,
        top_p: float,
        temperature: float,
-   ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
+   ) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
        chatbot.append([query, ""])
+       query_messages = messages + [{"role": Role.USER, "content": query}]
        response = ""
        for new_text in self.stream_chat(
-           query, history, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+           query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
        ):
            response += new_text
-           new_history = history + [(query, response)]
-           chatbot[-1] = [query, self.postprocess(response)]
-           yield chatbot, new_history
+           if tools:
+               result = self.template.format_tools.extract(response)
+           else:
+               result = response
+           if isinstance(result, tuple):
+               name, arguments = result
+               arguments = json.loads(arguments)
+               tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
+               output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
+               bot_text = "```json\n" + tool_call + "\n```"
+           else:
+               output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
+               bot_text = result
+           chatbot[-1] = [query, self.postprocess(bot_text)]
+           yield chatbot, output_messages

    def postprocess(self, response: str) -> str:
        blocks = response.split("```")
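
When extract yields a tuple, the arguments string is parsed and re-serialized into one JSON blob that is stored as a Role.FUNCTION message and rendered as a fenced json block in the chat window. An illustrative resulting state (the values are made up, and the role strings stand in for the data-side Role enum):

output_messages = [
    {"role": "user", "content": "Weather in Beijing?"},
    {"role": "function", "content": '{"name": "get_weather", "arguments": {"city": "Beijing"}}'},
]
# chatbot[-1][1] shows the same blob wrapped in a ```json ... ``` block for display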

View File

@@ -17,7 +17,7 @@ def create_chat_box(
 ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
     with gr.Box(visible=visible) as chat_box:
         chatbot = gr.Chatbot()
-        history = gr.State([])
+        messages = gr.State([])
         with gr.Row():
             with gr.Column(scale=4):
                 system = gr.Textbox(show_label=False)
@@ -32,21 +32,21 @@ def create_chat_box(
                top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
                temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)

-   tools.input(check_json_schema, [tools])
+   tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])

    submit_btn.click(
        engine.chatter.predict,
-       [chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
-       [chatbot, history],
+       [chatbot, query, messages, system, tools, max_new_tokens, top_p, temperature],
+       [chatbot, messages],
        show_progress=True,
    ).then(lambda: gr.update(value=""), outputs=[query])

-   clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
+   clear_btn.click(lambda: ([], []), outputs=[chatbot, messages], show_progress=True)

    return (
        chat_box,
        chatbot,
-       history,
+       messages,
        dict(
            system=system,
            tools=tools,
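
The history State becomes messages and now carries role-tagged dicts instead of (query, response) tuples; predict's second output is written back into the same State on every yield. A toy, self-contained sketch of that round-trip, with an echo function standing in for the model:

import gradio as gr

def predict(chatbot, query, messages):
    chatbot = chatbot + [[query, query.upper()]]
    messages = messages + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": query.upper()},
    ]
    return chatbot, messages  # the second output is written back into gr.State

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    messages = gr.State([])
    query = gr.Textbox()
    query.submit(predict, [chatbot, query, messages], [chatbot, messages])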

View File

@@ -208,6 +208,8 @@ ALERTS = {
        "zh": "展示模式不支持训练,请先复制到私人空间。",
    },
    "err_device_count": {"en": "Multiple GPUs are not supported yet.", "zh": "尚不支持多 GPU 训练。"},
+   "err_tool_name": {"en": "Tool name not found.", "zh": "工具名称未找到。"},
+   "err_json_schema": {"en": "Invalid JSON schema.", "zh": "Json 格式错误。"},
    "info_aborting": {"en": "Aborted, wait for terminating...", "zh": "训练中断,正在等待线程结束……"},
    "info_aborted": {"en": "Ready.", "zh": "准备就绪。"},
    "info_finished": {"en": "Finished.", "zh": "训练完毕。"},

View File

@@ -8,6 +8,7 @@ import gradio as gr

 from ..extras.packages import is_matplotlib_available
 from ..extras.ploting import smooth
 from .common import get_save_dir
+from .locales import ALERTS

 if TYPE_CHECKING:
@@ -40,11 +41,15 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
    return gr.update(interactive=True)

-def check_json_schema(text: str) -> None:
+def check_json_schema(text: str, lang: str) -> None:
    try:
-       json.loads(text)
+       tools = json.loads(text)
+       for tool in tools:
+           assert "name" in tool
+   except AssertionError:
+       gr.Warning(ALERTS["err_tool_name"][lang])
    except json.JSONDecodeError:
-       gr.Warning("Invalid JSON schema")
+       gr.Warning(ALERTS["err_json_schema"][lang])

def gen_cmd(args: Dict[str, Any]) -> str:
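
check_json_schema now expects the tools textbox to hold a JSON list of tool definitions, each with a "name" key, and emits a localized gr.Warning otherwise. A quick behavior sketch (the tool content is illustrative):

check_json_schema('[{"name": "get_weather", "description": "Query weather"}]', "en")  # silent
check_json_schema('[{"description": "no name field"}]', "en")  # warns: Tool name not found.
check_json_schema("not json at all", "en")  # warns: Invalid JSON schema.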