64 lines
1.8 KiB
Python
64 lines
1.8 KiB
Python
|
import re
|
||
|
import requests
|
||
|
import json
|
||
|
|
||
|
|
||
|
def get_knowledge(query, *, url='http://127.0.0.1:7861/chat/knowledge_base_chat',
                  knowledge_base_name="My_cx", top_k=1, score_threshold=0.7,
                  timeout=60):
    """Query the knowledge-base chat endpoint and return retrieved docs as context.

    Args:
        query: The question sent as the request's ``query`` field.
        url: Endpoint the request is POSTed to.
        knowledge_base_name: Value sent as ``knowledge_base_name``.
        top_k: Value sent as ``top_k``.
        score_threshold: Value sent as ``score_threshold``; its exact meaning
            is defined by the server.
        timeout: Seconds to wait for the HTTP request. ``None`` waits
            indefinitely (the behaviour before this parameter existed).

    Returns:
        The docs from the response's ``docs`` list, concatenated with every
        parenthesised span removed and wrapped in ``<已知信息>`` ...
        ``</已知信息>`` tags (the opening tag is followed by a newline).
        Returns ``None`` in three cases:

        - the response body contains no parsable JSON object;
        - the object has no (or empty) ``docs``;
        - any doc contains the "未找到相关文档" (no relevant documents
          found) marker.

        If the HTTP request itself fails (connection error, timeout, or
        non-2xx status), returns the string ``"请求失败: <exception>"``.
    """
    headers = {
        'accept': 'application/json',
        'Content-Type': 'application/json',
    }
    data = {
        "query": query,
        "knowledge_base_name": knowledge_base_name,
        "top_k": top_k,
        "score_threshold": score_threshold,
        "history": [],
        "stream": False,
        "model_name": "fm9g_2b_hf_models",
        "temperature": 0.8,
        # Only the retrieved "docs" are used below; presumably max_tokens=1
        # keeps the generated answer negligible. TODO confirm with server docs.
        "max_tokens": 1,
        "prompt_name": "empty",
    }

    try:
        response = requests.post(url, headers=headers, data=json.dumps(data),
                                 timeout=timeout)
        response.raise_for_status()  # Treat non-2xx statuses as failures.
    except requests.exceptions.RequestException as e:
        # Request failures are reported as a string, not raised (original contract).
        return f"请求失败: {e}"

    payload = _parse_response_json(response.text)
    if payload is None:
        return None
    return _format_docs(payload.get("docs", []))


def _parse_response_json(text):
    """Return the first JSON object embedded in *text*, or None if none parses.

    The body may carry non-JSON text before the object (the original code
    searched for the first '{' for this reason; presumably an SSE-style
    ``data:`` prefix). ``raw_decode`` also tolerates trailing content after
    the object, which ``json.loads`` would reject.
    """
    start = text.find('{')
    if start == -1:
        return None
    try:
        payload, _ = json.JSONDecoder().raw_decode(text, start)
    except ValueError:  # json.JSONDecodeError subclasses ValueError.
        return None
    return payload


def _format_docs(docs):
    """Wrap *docs* in <已知信息> tags, or return None if there is nothing useful.

    Docs are assumed to be strings. Every parenthesised span is removed from
    each doc; presumably these are markdown link targets, judging by the
    ``出处 [...] [...](...)`` source-citation pattern this code used to strip.
    """
    if not docs or any("未找到相关文档" in doc for doc in docs):
        return None
    body = "".join(re.sub(r'\(.*?\)', '', doc) for doc in docs)
    return "<已知信息>\n" + body + "</已知信息>"
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Manual smoke check: print the retrieval result for a couple of sample queries.
    for sample_query in ("队列", "海湾战争"):
        print(get_knowledge(sample_query))
|