from gradio_client import Client
from knowledge_api import extract_code_between_triple_backticks
from knowledge_api import load_test_data
from knowledge_api import save_submissions_to_jsonl
import torch
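
# NOTE (assumption): knowledge_api is a project-local module that is not shown
# in this file. Based on how its helpers are used below, they are assumed to
# behave roughly as follows (the real implementations may differ):
#   - load_test_data(path): read a JSONL file and return a list of dicts,
#     each containing at least "task_id" and "question".
#   - extract_code_between_triple_backticks(text): return the code enclosed in
#     the first ``` ... ``` block of a model reply.
#   - save_submissions_to_jsonl(submissions, path): write a list of dicts back
#     out as one JSON object per line.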

# Initialise the client; this assumes your API is running locally on port 7860.
client = Client("http://0.0.0.0:7860/")


# Chat with the bot
def chat_with_bot(client, user_message, system_message="You are a code assistant, please directly generate a Python "
                                                        "code according to the requirement.\n"):
    """
    Chat with the bot: send a single user message and fetch the reply via the
    /stream API.
    """

    chatbot_memory = []
    try:

        # Use the /append API to append the user message to the conversation.
        result = client.predict(
            chatbot=[],
            role="user",          # the message is sent in the user role
            query=user_message,   # the user's input message
            api_name="/append"    # call the append API
        )

        # Update the chat history by appending the user input.
        chatbot_memory.append(("user", user_message))
        torch.cuda.empty_cache()

        # Use the /stream API to get the bot's reply.
        response = client.predict(
            chatbot=chatbot_memory,
            system=system_message,
            tools=None,
            image=None,
            max_new_tokens=None,
            top_p=0.7,
            temperature=0.85,
            api_name="/stream"
        )

        client.predict(
            api_name="/lambda"
        )

        # Add the bot's reply to the chat history.
        bot_reply = response[0]  # the (user, assistant) turn returned in the response
        chatbot_memory.append(("bot", bot_reply))
        print(bot_reply[1] + "\n--------------------")
        return bot_reply[1]  # return the bot's reply text

    except Exception as e:
        print(f"An error occurred: {e}")
        return "Sorry, I was unable to handle this request."
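

# Reference sketch (not called anywhere in this script): gradio_client can list
# the endpoints a remote app exposes, which is how names such as "/append",
# "/stream" and "/lambda" used above can be checked against the running demo.
# Client.view_api() is part of the gradio_client package and prints the
# parameters each endpoint expects.
def describe_endpoints(client):
    """Print the remote Gradio app's API signature (debugging aid)."""
    client.view_api()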


# Main program
if __name__ == '__main__':
    # Load the test data.
    pili_files = '/data/download/代码生成_round4.jsonl'
    test_data = load_test_data(pili_files)

    submissions = []
    for i, item in enumerate(test_data):
        task_id = item['task_id']
        questions = item['question']

        # Build the prompt.
        question = f"{questions}\n>"
        print(i)

        # Call chat_with_bot to get the bot's answer.
        answer = chat_with_bot(client, question)

        the_answer = extract_code_between_triple_backticks(answer)

        submission = {
            "task_id": task_id,
            "question": questions,
            "prompt": f"{questions}\n",
            "raw_outputs": the_answer
        }
        submissions.append(submission)

    # Save the submission results.
    output_file_path = '代码生成submissions4.jsonl'
    save_submissions_to_jsonl(submissions, output_file_path)