diff --git a/README.md b/README.md
index d0738eaa..e4716873 100644
--- a/README.md
+++ b/README.md
@@ -264,8 +264,8 @@ huggingface-cli login
 | ------------ | ------- | --------- |
 | python       | 3.8     | 3.10      |
 | torch        | 1.13.1  | 2.2.0     |
-| transformers | 4.37.2  | 4.39.1    |
-| datasets     | 2.14.3  | 2.17.1    |
+| transformers | 4.37.2  | 4.39.2    |
+| datasets     | 2.14.3  | 2.18.0    |
 | accelerate   | 0.27.2  | 0.28.0    |
 | peft         | 0.9.0   | 0.10.0    |
 | trl          | 0.8.1   | 0.8.1     |
diff --git a/README_zh.md b/README_zh.md
index 460784b9..b13c0f19 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -264,8 +264,8 @@ huggingface-cli login
 | ------------ | ------- | --------- |
 | python       | 3.8     | 3.10      |
 | torch        | 1.13.1  | 2.2.0     |
-| transformers | 4.37.2  | 4.39.1    |
-| datasets     | 2.14.3  | 2.17.1    |
+| transformers | 4.37.2  | 4.39.2    |
+| datasets     | 2.14.3  | 2.18.0    |
 | accelerate   | 0.27.2  | 0.28.0    |
 | peft         | 0.9.0   | 0.10.0    |
 | trl          | 0.8.1   | 0.8.1     |
diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index c5a18bc7..3f06fef1 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -108,12 +108,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
                 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
 
-            input_messages.append({"role": role_mapping[message.role], "content": message.content})
+            if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
+                name = message.tool_calls[0].function.name
+                arguments = message.tool_calls[0].function.arguments
+                content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
+                input_messages.append({"role": role_mapping[Role.FUNCTION], "content": content})
+            else:
+                input_messages.append({"role": role_mapping[message.role], "content": message.content})
 
         tool_list = request.tools
         if isinstance(tool_list, list) and len(tool_list):
             try:
-                tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
+                tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
             except Exception:
                 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
         else:
diff --git a/src/llmtuner/api/protocol.py b/src/llmtuner/api/protocol.py
index 3e39fe0b..ece2132b 100644
--- a/src/llmtuner/api/protocol.py
+++ b/src/llmtuner/api/protocol.py
@@ -1,6 +1,6 @@
 import time
 from enum import Enum, unique
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -39,6 +39,17 @@ class Function(BaseModel):
     arguments: str
 
 
+class FunctionDefinition(BaseModel):
+    name: str
+    description: str
+    parameters: Dict[str, Any]
+
+
+class FunctionAvailable(BaseModel):
+    type: Literal["function", "code_interpreter"] = "function"
+    function: Optional[FunctionDefinition] = None
+
+
 class FunctionCall(BaseModel):
     id: Literal["call_default"] = "call_default"
     type: Literal["function"] = "function"
@@ -47,7 +58,8 @@ class FunctionCall(BaseModel):
 
 class ChatMessage(BaseModel):
     role: Role
-    content: str
+    content: Optional[str] = None
+    tool_calls: Optional[List[FunctionCall]] = None
 
 
 class ChatCompletionMessage(BaseModel):
@@ -59,7 +71,7 @@ class ChatCompletionMessage(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
-    tools: list = []
+    tools: Optional[List[FunctionAvailable]] = None
     do_sample: bool = True
     temperature: Optional[float] = None
     top_p: Optional[float] = None
diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/extras/patches/llama_patch.py
index 7fb9f9d6..f0b65d65 100644
--- a/src/llmtuner/extras/patches/llama_patch.py
+++ b/src/llmtuner/extras/patches/llama_patch.py
@@ -193,6 +193,6 @@ def llama_flash_attn_forward(
 
 
 def apply_llama_patch() -> None:
-    require_version("transformers==4.39.1", "To fix: pip install transformers==4.39.1")
+    require_version("transformers==4.39.2", "To fix: pip install transformers==4.39.2")
     LlamaAttention.forward = llama_torch_attn_forward
     LlamaFlashAttention2.forward = llama_flash_attn_forward
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 1a6da78a..97399a2c 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -331,7 +331,7 @@ def patch_model(
     ):
         gen_config.do_sample = True
 
-    if model_args.resize_vocab:
+    if is_trainable and model_args.resize_vocab:
         _resize_embedding_layer(model, tokenizer)
 
     if is_trainable:
diff --git a/tests/test_toolcall.py b/tests/test_toolcall.py
index a54a0053..d36e7fec 100644
--- a/tests/test_toolcall.py
+++ b/tests/test_toolcall.py
@@ -15,7 +15,7 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
     for grade, hour in zip(grades, hours):
         total_score += grade_to_score[grade] * hour
         total_hour += hour
-    return total_score / total_hour
+    return round(total_score / total_hour, 2)
 
 
 def main():
@@ -45,16 +45,19 @@ def main():
     messages = []
     messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
     result = client.chat.completions.create(messages=messages, model="test", tools=tools)
+    if result.choices[0].message.tool_calls is None:
+        raise ValueError("Cannot retrieve function call from the response.")
+
+    messages.append(result.choices[0].message)
     tool_call = result.choices[0].message.tool_calls[0].function
+    print(tool_call)
+    # Function(arguments='{"grades": ["A", "A", "B", "C"], "hours": [3, 4, 3, 2]}', name='calculate_gpa')
     name, arguments = tool_call.name, json.loads(tool_call.arguments)
-    messages.append(
-        {"role": "function", "content": json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)}
-    )
     tool_result = tool_map[name](**arguments)
     messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)})
     result = client.chat.completions.create(messages=messages, model="test", tools=tools)
     print(result.choices[0].message.content)
-    # Based on your grades and credit hours, your calculated Grade Point Average (GPA) is 3.4166666666666665.
+    # Based on the grades and credit hours you provided, your Grade Point Average (GPA) is 3.42.
 
 
 if __name__ == "__main__":