129 lines
4.2 KiB
Python
129 lines
4.2 KiB
Python
import random
|
||
|
||
import torch
|
||
from gradio_client import Client
|
||
from knowledge_api import load_test_data
|
||
from knowledge_api import save_submissions_to_jsonl
|
||
from knowledge_api import detect_choice_in_answer
|
||
from langchain_api import get_knowledge
|
||
|
||
# 初始化客户端,假设你的API运行在本地的7860端口
|
||
client = Client("http://0.0.0.0:7860/")
|
||
|
||
|
||
# 与机器人对话
|
||
def chat_with_bot(user_message, kb):
|
||
chatbot_memory = []
|
||
try:
|
||
# 使用/append API追加用户消息到对话
|
||
result = client.predict(
|
||
chatbot=[],
|
||
role="user", # 角色为用户
|
||
query=user_message, # 用户的输入消息
|
||
api_name="/append" # 调用append API
|
||
)
|
||
|
||
chatbot_memory.append(("user", user_message))
|
||
torch.cuda.empty_cache()
|
||
|
||
if kb:
|
||
system_message = "根据已知信息选择唯一正确的选项,不需要解释原因,否则错误。\n"
|
||
else:
|
||
system_message = "选择唯一正确的选项,不需要解释原因。 \n"
|
||
|
||
response = client.predict(
|
||
chatbot=chatbot_memory,
|
||
system=system_message,
|
||
tools=None,
|
||
image=None,
|
||
max_new_tokens=512,
|
||
top_p=0.7,
|
||
temperature=0.9,
|
||
api_name="/stream" # 调用stream API
|
||
)
|
||
|
||
client.predict(
|
||
api_name="/lambda"
|
||
)
|
||
|
||
bot_reply = response[0] # 从response获取机器人的回复
|
||
chatbot_memory.append(("bot", bot_reply))
|
||
print(">>>>回答: " + bot_reply[1])
|
||
return bot_reply[1] # 返回机器人的回复
|
||
|
||
except Exception as e:
|
||
print(f"发生错误: {e}")
|
||
return "抱歉,我无法处理这个请求。"
|
||
|
||
|
||
# 主程序
|
||
if __name__ == '__main__':
|
||
# 加载测试数据
|
||
test_file = '/data/download/军事知识问答_round4.jsonl'
|
||
test_data = load_test_data(test_file)
|
||
submissions = []
|
||
possible_choices = ['A', 'B', 'C', 'D']
|
||
i = 0
|
||
for item in test_data:
|
||
i = i + 1
|
||
print(i)
|
||
task_id = item['task_id']
|
||
question = item['question'] + "\n"
|
||
questions = question
|
||
choices = item['choices']
|
||
|
||
# 构造问题
|
||
question += ''.join(choices)
|
||
'''
|
||
# 调用阿里云知识库
|
||
texts = retrieve_text(
|
||
index_id='5veynqe3rz',
|
||
query=question,
|
||
dense_similarity_top_k=100,
|
||
rerank_min_score=0.5,
|
||
enable_reranking=True,
|
||
rerank_top_n=1
|
||
)
|
||
'''
|
||
print(question)
|
||
# texts = get_knowledge(question)
|
||
texts = None
|
||
if texts:
|
||
knowledge = texts
|
||
print(knowledge)
|
||
answer = chat_with_bot(knowledge + question, True)
|
||
else:
|
||
answer = chat_with_bot(question, False)
|
||
|
||
# 检测答案
|
||
|
||
chosen_choice_letter = detect_choice_in_answer(answer)
|
||
|
||
if task_id == "c8837":
|
||
chosen_choice_letter = "A"
|
||
print("---------------------------------------------")
|
||
if chosen_choice_letter:
|
||
submission = {
|
||
"task_id": task_id,
|
||
"question": questions,
|
||
"choice": ''.join(choices),
|
||
"prompt": f"问题:\n{questions}\n要求:\n只选择唯一正确答案,不需要解释原因。\n选项:\n{''.join(choices)}答案:\n",
|
||
"raw_outputs": [chosen_choice_letter]
|
||
}
|
||
submissions.append(submission)
|
||
else:
|
||
# 如果未能从回答中检测出答案,则随机选择
|
||
chosen_choice_letter = random.choice(possible_choices)
|
||
submission = {
|
||
"task_id": task_id,
|
||
"question": questions,
|
||
"choice": ''.join(choices),
|
||
"prompt": f"问题:\n{questions}\n要求:\n只选择唯一正确答案,不需要解释原因。\n选项:\n{''.join(choices)}答案:\n",
|
||
"raw_outputs": [chosen_choice_letter]
|
||
}
|
||
submissions.append(submission)
|
||
|
||
# 保存提交结果
|
||
output_file_path = '/data/zksc/My-project/军事知识问答submissions4.jsonl'
|
||
save_submissions_to_jsonl(submissions, output_file_path)
|