forked from p04798526/LLaMA-Factory-Mirror
finish agent
This commit is contained in:
parent
55f707196e
commit
3e982cc714
|
@ -7,18 +7,20 @@ from typing import Any, Dict, Sequence
|
|||
from pydantic import BaseModel
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.misc import torch_gc
|
||||
from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
|
||||
from .protocol import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionResponseUsage,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatMessage,
|
||||
DeltaMessage,
|
||||
Finish,
|
||||
Function,
|
||||
FunctionCall,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
Role,
|
||||
|
@ -84,7 +86,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||
|
||||
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||
|
||||
messages = [dictify(message) for message in request.messages]
|
||||
if len(messages) and messages[0]["role"] == Role.SYSTEM:
|
||||
|
@ -96,16 +98,21 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
for i in range(len(messages)):
|
||||
if messages[i]["role"] == Role.USER:
|
||||
if i % 2 == 1:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
elif messages[i]["role"] == Role.ASSISTANT:
|
||||
if i % 2 == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if i % 2 == 0 and messages[i]["role"] not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif i % 2 == 1 and messages[i]["role"] not in [Role.ASSISTANT, Role.FUNCTION]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif messages[i]["role"] == Role.TOOL:
|
||||
messages[i]["role"] = DataRole.OBSERVATION
|
||||
|
||||
tools = "" # TODO: add tools
|
||||
tool_list = request.tools
|
||||
if len(tool_list):
|
||||
try:
|
||||
tools = json.dumps([tool_list[0]["function"]], ensure_ascii=False)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||
else:
|
||||
tools = ""
|
||||
|
||||
async with semaphore:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
@ -130,12 +137,24 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||
prompt_length, response_length = 0, 0
|
||||
choices = []
|
||||
for i, response in enumerate(responses):
|
||||
choices.append(
|
||||
ChatCompletionResponseChoice(
|
||||
index=i,
|
||||
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
|
||||
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH,
|
||||
if tools:
|
||||
result = chat_model.template.format_tools.extract(response.response_text)
|
||||
else:
|
||||
result = response.response_text
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
function = Function(name=name, arguments=arguments)
|
||||
response_message = ChatCompletionMessage(
|
||||
role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
|
||||
)
|
||||
finish_reason = Finish.TOOL
|
||||
else:
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
|
||||
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
||||
|
||||
choices.append(
|
||||
ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
|
||||
)
|
||||
prompt_length = response.prompt_length
|
||||
response_length += response.response_length
|
||||
|
@ -152,7 +171,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
|
||||
):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
|
||||
index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield jsonify(chunk)
|
||||
|
@ -170,12 +189,14 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||
continue
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0, delta=DeltaMessage(content=new_text), finish_reason=None
|
||||
index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield jsonify(chunk)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield jsonify(chunk)
|
||||
yield "[DONE]"
|
||||
|
|
|
@ -11,12 +11,15 @@ class Role(str, Enum):
|
|||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
SYSTEM = "system"
|
||||
FUNCTION = "function"
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
@unique
|
||||
class Finish(str, Enum):
|
||||
STOP = "stop"
|
||||
LENGTH = "length"
|
||||
TOOL = "tool_calls"
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
|
@ -31,19 +34,32 @@ class ModelList(BaseModel):
|
|||
data: List[ModelCard] = []
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
id: Literal["call_default"] = "call_default"
|
||||
type: Literal["function"] = "function"
|
||||
function: Function
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Role
|
||||
content: str
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
class ChatCompletionMessage(BaseModel):
|
||||
role: Optional[Role] = None
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[List[FunctionCall]] = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
tools: Optional[list] = []
|
||||
do_sample: bool = True
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
|
@ -54,13 +70,13 @@ class ChatCompletionRequest(BaseModel):
|
|||
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
index: int
|
||||
message: ChatMessage
|
||||
message: ChatCompletionMessage
|
||||
finish_reason: Finish
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
delta: ChatCompletionMessage
|
||||
finish_reason: Optional[Finish] = None
|
||||
|
||||
|
||||
|
|
|
@ -37,9 +37,9 @@ class ChatModel:
|
|||
tools: Optional[str] = None,
|
||||
**input_kwargs,
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
new_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
prompt, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, messages=new_messages, system=system, tools=tools
|
||||
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
|
||||
)
|
||||
prompt_length = len(prompt)
|
||||
input_ids = torch.tensor([prompt], device=self.model.device)
|
||||
|
|
|
@ -74,6 +74,9 @@ class Formatter(ABC):
|
|||
def apply(self, **kwargs) -> SLOTS:
|
||||
...
|
||||
|
||||
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmptyFormatter(Formatter):
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..data import Role
|
||||
from ..extras.misc import torch_gc
|
||||
from ..hparams import GeneratingArguments
|
||||
from .common import get_save_dir
|
||||
|
@ -105,22 +107,37 @@ class WebChatModel(ChatModel):
|
|||
self,
|
||||
chatbot: List[Tuple[str, str]],
|
||||
query: str,
|
||||
history: List[Tuple[str, str]],
|
||||
messages: Sequence[Tuple[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
|
||||
) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
|
||||
chatbot.append([query, ""])
|
||||
query_messages = messages + [{"role": Role.USER, "content": query}]
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
query, history, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
):
|
||||
response += new_text
|
||||
new_history = history + [(query, response)]
|
||||
chatbot[-1] = [query, self.postprocess(response)]
|
||||
yield chatbot, new_history
|
||||
if tools:
|
||||
result = self.template.format_tools.extract(response)
|
||||
else:
|
||||
result = response
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
arguments = json.loads(arguments)
|
||||
tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
|
||||
output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
|
||||
bot_text = "```json\n" + tool_call + "\n```"
|
||||
else:
|
||||
output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
|
||||
bot_text = result
|
||||
|
||||
chatbot[-1] = [query, self.postprocess(bot_text)]
|
||||
yield chatbot, output_messages
|
||||
|
||||
def postprocess(self, response: str) -> str:
|
||||
blocks = response.split("```")
|
||||
|
|
|
@ -17,7 +17,7 @@ def create_chat_box(
|
|||
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
|
||||
with gr.Box(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot()
|
||||
history = gr.State([])
|
||||
messages = gr.State([])
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
system = gr.Textbox(show_label=False)
|
||||
|
@ -32,21 +32,21 @@ def create_chat_box(
|
|||
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
|
||||
|
||||
tools.input(check_json_schema, [tools])
|
||||
tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])
|
||||
|
||||
submit_btn.click(
|
||||
engine.chatter.predict,
|
||||
[chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
|
||||
[chatbot, history],
|
||||
[chatbot, query, messages, system, tools, max_new_tokens, top_p, temperature],
|
||||
[chatbot, messages],
|
||||
show_progress=True,
|
||||
).then(lambda: gr.update(value=""), outputs=[query])
|
||||
|
||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
|
||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages], show_progress=True)
|
||||
|
||||
return (
|
||||
chat_box,
|
||||
chatbot,
|
||||
history,
|
||||
messages,
|
||||
dict(
|
||||
system=system,
|
||||
tools=tools,
|
||||
|
|
|
@ -208,6 +208,8 @@ ALERTS = {
|
|||
"zh": "展示模式不支持训练,请先复制到私人空间。",
|
||||
},
|
||||
"err_device_count": {"en": "Multiple GPUs are not supported yet.", "zh": "尚不支持多 GPU 训练。"},
|
||||
"err_tool_name": {"en": "Tool name not found.", "zh": "工具名称未找到。"},
|
||||
"err_json_schema": {"en": "Invalid JSON schema.", "zh": "Json 格式错误。"},
|
||||
"info_aborting": {"en": "Aborted, wait for terminating...", "zh": "训练中断,正在等待线程结束……"},
|
||||
"info_aborted": {"en": "Ready.", "zh": "准备就绪。"},
|
||||
"info_finished": {"en": "Finished.", "zh": "训练完毕。"},
|
||||
|
|
|
@ -8,6 +8,7 @@ import gradio as gr
|
|||
from ..extras.packages import is_matplotlib_available
|
||||
from ..extras.ploting import smooth
|
||||
from .common import get_save_dir
|
||||
from .locales import ALERTS
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -40,11 +41,15 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
|||
return gr.update(interactive=True)
|
||||
|
||||
|
||||
def check_json_schema(text: str) -> None:
|
||||
def check_json_schema(text: str, lang: str) -> None:
|
||||
try:
|
||||
json.loads(text)
|
||||
tools = json.loads(text)
|
||||
for tool in tools:
|
||||
assert "name" in tool
|
||||
except AssertionError:
|
||||
gr.Warning(ALERTS["err_tool_name"][lang])
|
||||
except json.JSONDecodeError:
|
||||
gr.Warning("Invalid JSON schema")
|
||||
gr.Warning(ALERTS["err_json_schema"][lang])
|
||||
|
||||
|
||||
def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
|
|
Loading…
Reference in New Issue