This commit is contained in:
hiyouga 2024-04-01 21:35:18 +08:00
parent eb259cc573
commit aee634cd20
7 changed files with 37 additions and 16 deletions

View File

@ -264,8 +264,8 @@ huggingface-cli login
| ------------ | ------- | --------- |
| python | 3.8 | 3.10 |
| torch | 1.13.1 | 2.2.0 |
| transformers | 4.37.2 | 4.39.1 |
| datasets | 2.14.3 | 2.17.1 |
| transformers | 4.37.2 | 4.39.2 |
| datasets | 2.14.3 | 2.18.0 |
| accelerate | 0.27.2 | 0.28.0 |
| peft | 0.9.0 | 0.10.0 |
| trl | 0.8.1 | 0.8.1 |

View File

@ -264,8 +264,8 @@ huggingface-cli login
| ------------ | ------- | --------- |
| python | 3.8 | 3.10 |
| torch | 1.13.1 | 2.2.0 |
| transformers | 4.37.2 | 4.39.1 |
| datasets | 2.14.3 | 2.17.1 |
| transformers | 4.37.2 | 4.39.2 |
| datasets | 2.14.3 | 2.18.0 |
| accelerate | 0.27.2 | 0.28.0 |
| peft | 0.9.0 | 0.10.0 |
| trl | 0.8.1 | 0.8.1 |

View File

@ -108,12 +108,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
name = message.tool_calls[0].function.name
arguments = message.tool_calls[0].function.arguments
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
input_messages.append({"role": role_mapping[Role.FUNCTION], "content": content})
else:
input_messages.append({"role": role_mapping[message.role], "content": message.content})
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:

View File

@ -1,6 +1,6 @@
import time
from enum import Enum, unique
from typing import List, Optional
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from typing_extensions import Literal
@ -39,6 +39,17 @@ class Function(BaseModel):
arguments: str
class FunctionDefinition(BaseModel):
name: str
description: str
parameters: Dict[str, Any]
class FunctionAvailable(BaseModel):
type: Literal["function", "code_interpreter"] = "function"
function: Optional[FunctionDefinition] = None
class FunctionCall(BaseModel):
id: Literal["call_default"] = "call_default"
type: Literal["function"] = "function"
@ -47,7 +58,8 @@ class FunctionCall(BaseModel):
class ChatMessage(BaseModel):
role: Role
content: str
content: Optional[str] = None
tool_calls: Optional[List[FunctionCall]] = None
class ChatCompletionMessage(BaseModel):
@ -59,7 +71,7 @@ class ChatCompletionMessage(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
tools: list = []
tools: Optional[List[FunctionAvailable]] = None
do_sample: bool = True
temperature: Optional[float] = None
top_p: Optional[float] = None

View File

@ -193,6 +193,6 @@ def llama_flash_attn_forward(
def apply_llama_patch() -> None:
require_version("transformers==4.39.1", "To fix: pip install transformers==4.39.1")
require_version("transformers==4.39.2", "To fix: pip install transformers==4.39.2")
LlamaAttention.forward = llama_torch_attn_forward
LlamaFlashAttention2.forward = llama_flash_attn_forward

View File

@ -331,7 +331,7 @@ def patch_model(
):
gen_config.do_sample = True
if model_args.resize_vocab:
if is_trainable and model_args.resize_vocab:
_resize_embedding_layer(model, tokenizer)
if is_trainable:

View File

@ -15,7 +15,7 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
for grade, hour in zip(grades, hours):
total_score += grade_to_score[grade] * hour
total_hour += hour
return total_score / total_hour
return round(total_score / total_hour, 2)
def main():
@ -45,16 +45,19 @@ def main():
messages = []
messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
result = client.chat.completions.create(messages=messages, model="test", tools=tools)
if result.choices[0].message.tool_calls is None:
raise ValueError("Cannot retrieve function call from the response.")
messages.append(result.choices[0].message)
tool_call = result.choices[0].message.tool_calls[0].function
print(tool_call)
# Function(arguments='{"grades": ["A", "A", "B", "C"], "hours": [3, 4, 3, 2]}', name='calculate_gpa')
name, arguments = tool_call.name, json.loads(tool_call.arguments)
messages.append(
{"role": "function", "content": json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)}
)
tool_result = tool_map[name](**arguments)
messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)})
result = client.chat.completions.create(messages=messages, model="test", tools=tools)
print(result.choices[0].message.content)
# Based on your grades and credit hours, your calculated Grade Point Average (GPA) is 3.4166666666666665.
# Based on the grades and credit hours you provided, your Grade Point Average (GPA) is 3.42.
if __name__ == "__main__":