CPM-9G_Code/My_project/llm_web_code.py

94 lines
2.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from gradio_client import Client
from knowledge_api import extract_code_between_triple_backticks
from knowledge_api import load_test_data
from knowledge_api import save_submissions_to_jsonl
import torch
# 初始化客户端假设你的API运行在本地的7860端口
client = Client("http://0.0.0.0:7860/")
# 与机器人对话
def chat_with_bott(client, user_message, system_message="You are a code assistant, please directly generate a Python "
"code according to the requirement.\n"):
"""
You are a code helperplease directly generate a simple Python function without explanation
与机器人进行对话的函数,使用/stream API来获取响应。
"""
chatbot_memory = []
try:
# 使用/append API追加用户消息到对话
result = client.predict(
chatbot=[],
role="user", # 角色为用户
query=user_message, # 用户的输入消息
api_name="/append" # 调用append API
)
# 更新聊天记录,追加用户输入
chatbot_memory.append(("user", user_message))
torch.cuda.empty_cache()
# 使用/stream API获取机器人的回复
response = client.predict(
chatbot=chatbot_memory,
system=system_message,
tools=None,
image=None,
max_new_tokens=None,
top_p=0.7,
temperature=0.85,
api_name="/stream"
)
client.predict(
api_name="/lambda"
)
# 将机器人的回复添加到聊天记录中
bot_reply = response[0] # 从response获取机器人的回复
chatbot_memory.append(("bot", bot_reply))
print(bot_reply[1] + "\n--------------------")
return bot_reply[1] # 返回机器人的回复
except Exception as e:
print(f"发生错误: {e}")
return "抱歉,我无法处理这个请求。"
# 主程序
if __name__ == '__main__':
# 加载测试数据
pili_files = '/data/download/代码生成_round4.jsonl'
test_data = load_test_data(pili_files)
submissions = []
i = 0
for item in test_data:
task_id = item['task_id']
questions = item['question']
# 构造问题
question = f"{questions}\n>"
print(i)
i = i + 1
# 调用 chat_with_bot 函数来获得机器人的回答
answer = chat_with_bott(client, question)
the_answer = extract_code_between_triple_backticks(answer)
submission = {
"task_id": task_id,
"question": questions,
"prompt": f"{questions}\n",
"raw_outputs": the_answer
}
submissions.append(submission)
# 保存提交结果
output_file_path = '代码生成submissions4.jsonl'
save_submissions_to_jsonl(submissions, output_file_path)