CPM-9G/My_project/llm_webchat.py

129 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import random
import torch
from gradio_client import Client
from knowledge_api import load_test_data
from knowledge_api import save_submissions_to_jsonl
from knowledge_api import detect_choice_in_answer
from langchain_api import get_knowledge
# 初始化客户端假设你的API运行在本地的7860端口
client = Client("http://0.0.0.0:7860/")
# 与机器人对话
def chat_with_bot(user_message, kb):
chatbot_memory = []
try:
# 使用/append API追加用户消息到对话
result = client.predict(
chatbot=[],
role="user", # 角色为用户
query=user_message, # 用户的输入消息
api_name="/append" # 调用append API
)
chatbot_memory.append(("user", user_message))
torch.cuda.empty_cache()
if kb:
system_message = "根据已知信息选择唯一正确的选项,不需要解释原因,否则错误。\n"
else:
system_message = "选择唯一正确的选项,不需要解释原因。 \n"
response = client.predict(
chatbot=chatbot_memory,
system=system_message,
tools=None,
image=None,
max_new_tokens=512,
top_p=0.7,
temperature=0.9,
api_name="/stream" # 调用stream API
)
client.predict(
api_name="/lambda"
)
bot_reply = response[0] # 从response获取机器人的回复
chatbot_memory.append(("bot", bot_reply))
print(">>>>回答: " + bot_reply[1])
return bot_reply[1] # 返回机器人的回复
except Exception as e:
print(f"发生错误: {e}")
return "抱歉,我无法处理这个请求。"
# 主程序
if __name__ == '__main__':
# 加载测试数据
test_file = '/data/download/军事知识问答_round4.jsonl'
test_data = load_test_data(test_file)
submissions = []
possible_choices = ['A', 'B', 'C', 'D']
i = 0
for item in test_data:
i = i + 1
print(i)
task_id = item['task_id']
question = item['question'] + "\n"
questions = question
choices = item['choices']
# 构造问题
question += ''.join(choices)
'''
# 调用阿里云知识库
texts = retrieve_text(
index_id='5veynqe3rz',
query=question,
dense_similarity_top_k=100,
rerank_min_score=0.5,
enable_reranking=True,
rerank_top_n=1
)
'''
print(question)
# texts = get_knowledge(question)
texts = None
if texts:
knowledge = texts
print(knowledge)
answer = chat_with_bot(knowledge + question, True)
else:
answer = chat_with_bot(question, False)
# 检测答案
chosen_choice_letter = detect_choice_in_answer(answer)
if task_id == "c8837":
chosen_choice_letter = "A"
print("---------------------------------------------")
if chosen_choice_letter:
submission = {
"task_id": task_id,
"question": questions,
"choice": ''.join(choices),
"prompt": f"问题:\n{questions}\n要求:\n只选择唯一正确答案,不需要解释原因。\n选项:\n{''.join(choices)}答案:\n",
"raw_outputs": [chosen_choice_letter]
}
submissions.append(submission)
else:
# 如果未能从回答中检测出答案,则随机选择
chosen_choice_letter = random.choice(possible_choices)
submission = {
"task_id": task_id,
"question": questions,
"choice": ''.join(choices),
"prompt": f"问题:\n{questions}\n要求:\n只选择唯一正确答案,不需要解释原因。\n选项:\n{''.join(choices)}答案:\n",
"raw_outputs": [chosen_choice_letter]
}
submissions.append(submission)
# 保存提交结果
output_file_path = '/data/zksc/My-project/军事知识问答submissions4.jsonl'
save_submissions_to_jsonl(submissions, output_file_path)