CPM-9G/My_project/llm_webchat.py

129 lines
4.2 KiB
Python
Raw Normal View History

2024-10-31 21:52:29 +08:00
import random
import torch
from gradio_client import Client
from knowledge_api import load_test_data
from knowledge_api import save_submissions_to_jsonl
from knowledge_api import detect_choice_in_answer
from langchain_api import get_knowledge
# Address of the Gradio chat app; assumed to be served locally on port 7860.
# NOTE(review): 0.0.0.0 is a bind-all address. Connecting to it reaches
# localhost on Linux, but 127.0.0.1 is the portable spelling — confirm.
_SERVER_URL = "http://0.0.0.0:7860/"
# Module-level client shared by chat_with_bot(). gradio_client's Client
# typically contacts the server when constructed, so importing this module
# presumably requires the app to be running — confirm.
client = Client(_SERVER_URL)
# Chat with the bot: one independent single-turn exchange per call.
def chat_with_bot(user_message, kb):
    """Ask the Gradio chat app one question and return the reply text.

    The user turn is sent to the ``/append`` endpoint, a reply is generated
    via ``/stream``, and ``/lambda`` is then invoked (presumably to reset the
    app's conversation state — confirm against the server).

    Args:
        user_message: Full prompt text (question plus choices, optionally
            prefixed with retrieved knowledge).
        kb: Truthy when ``user_message`` contains retrieved knowledge; selects
            a system prompt telling the model to answer from that information.

    Returns:
        The reply text taken from the ``/stream`` response, or the fixed
        apology string ``"抱歉,我无法处理这个请求。"`` if any step fails.
    """
    # History passed to /stream; only the current user turn is needed.
    chatbot_memory = [("user", user_message)]
    if kb:
        system_message = "根据已知信息选择唯一正确的选项,不需要解释原因,否则错误。\n"
    else:
        system_message = "选择唯一正确的选项,不需要解释原因。 \n"
    try:
        # Register the user turn with the app. Its return value is not used;
        # the history sent to /stream is built locally above instead.
        client.predict(
            chatbot=[],
            role="user",  # message author role
            query=user_message,
            api_name="/append",
        )
        # NOTE(review): this releases CUDA memory cached by *this* process
        # only; nothing visible here allocates GPU memory, so it likely has
        # no effect on the model server.
        torch.cuda.empty_cache()
        response = client.predict(
            chatbot=chatbot_memory,
            system=system_message,
            tools=None,
            image=None,
            max_new_tokens=512,
            top_p=0.7,
            temperature=0.9,
            api_name="/stream",
        )
        client.predict(api_name="/lambda")
        # Assumes response[0] is a (role, text)-style pair whose index 1 is
        # the reply text — TODO confirm against /stream's output schema.
        reply_text = response[0][1]
        print(">>>>回答: " + reply_text)
        return reply_text
    except Exception as e:
        # Best-effort boundary: a failed request yields an apology string in
        # which detect_choice_in_answer presumably finds no option letter, so
        # the caller falls back to a random choice.
        print(f"发生错误: {e}")
        return "抱歉,我无法处理这个请求。"
# Main program: answer every multiple-choice question in the test set and
# write one submission record per question.
if __name__ == '__main__':
    test_file = '/data/download/军事知识问答_round4.jsonl'
    test_data = load_test_data(test_file)
    submissions = []
    # Fallback pool used when no option letter is detected in the answer.
    possible_choices = ['A', 'B', 'C', 'D']
    # HACK: answers pinned by task_id, overriding whatever the model returns.
    answer_overrides = {"c8837": "A"}

    for index, item in enumerate(test_data, start=1):
        print(index)
        task_id = item['task_id']
        # Question text alone (ending in a newline); stored in the submission.
        question_text = item['question'] + "\n"
        choices = item['choices']
        choices_text = ''.join(choices)
        # Prompt sent to the model: the question followed by its choices.
        question = question_text + choices_text
        print(question)

        # Knowledge retrieval is disabled: ``texts`` is always None, so the
        # model is queried without context. To enable it, use
        # ``texts = get_knowledge(question)``.
        texts = None
        if texts:
            knowledge = texts
            print(knowledge)
            answer = chat_with_bot(knowledge + question, True)
        else:
            answer = chat_with_bot(question, False)

        # Extract the chosen option letter from the model's reply.
        chosen_choice_letter = detect_choice_in_answer(answer)
        if task_id in answer_overrides:
            chosen_choice_letter = answer_overrides[task_id]
        print("---------------------------------------------")
        if not chosen_choice_letter:
            # No answer detected in the reply: pick an option at random.
            chosen_choice_letter = random.choice(possible_choices)

        # NOTE(review): question_text already ends with "\n", so the prompt
        # has a blank line before "要求" — presumably the expected submission
        # format; confirm before changing.
        submissions.append({
            "task_id": task_id,
            "question": question_text,
            "choice": choices_text,
            "prompt": f"问题:\n{question_text}\n要求:\n只选择唯一正确答案,不需要解释原因。\n选项:\n{choices_text}答案:\n",
            "raw_outputs": [chosen_choice_letter],
        })

    # Save the submission results.
    output_file_path = '/data/zksc/My-project/军事知识问答submissions4.jsonl'
    save_submissions_to_jsonl(submissions, output_file_path)