Compatible with OpenAI API.

This commit is contained in:
Bun 2023-06-21 14:45:04 +08:00
parent ded5aa3c3d
commit 6e4db0903f
1 changed files with 96 additions and 35 deletions

View File

@ -15,17 +15,15 @@
import json
import datetime
import torch
import uvicorn
import datetime
from threading import Thread
from fastapi import FastAPI, Request
from starlette.responses import StreamingResponse
from transformers import TextIteratorStreamer
from utils import (
Template,
load_pretrained,
prepare_infer_args,
get_logits_processor
)
from utils import Template, load_pretrained, prepare_infer_args, get_logits_processor
def torch_gc():
@ -40,61 +38,124 @@ def torch_gc():
app = FastAPI()
@app.post("/")
@app.post("/v1/chat/completions")
async def create_item(request: Request):
global model, tokenizer, prompt_template, source_prefix, generating_args
global model, tokenizer
# Parse the request JSON
json_post_raw = await request.json()
json_post = json.dumps(json_post_raw)
json_post_list = json.loads(json_post)
prompt = json_post_list.get("prompt")
history = json_post_list.get("history")
max_new_tokens = json_post_list.get("max_new_tokens", None)
top_p = json_post_list.get("top_p", None)
temperature = json_post_list.get("temperature", None)
prompt = json_post_raw.get("messages")[-1]["content"]
history = json_post_raw.get("messages")[:-1]
max_token = json_post_raw.get("max_tokens", None)
top_p = json_post_raw.get("top_p", None)
temperature = json_post_raw.get("temperature", None)
stream = check_stream(json_post_raw.get("stream"))
# Tokenize the input prompt
input_ids = tokenizer([prompt_template.get_prompt(prompt, history, source_prefix)], return_tensors="pt")["input_ids"]
if stream:
generate = predict(prompt, max_token, top_p, temperature, history)
return StreamingResponse(generate, media_type="text/event-stream")
input_ids = tokenizer([prompt_template.get_prompt(prompt, history, source_prefix)], return_tensors="pt")[
"input_ids"]
input_ids = input_ids.to(model.device)
# Generation arguments
gen_kwargs = generating_args.to_dict()
gen_kwargs["input_ids"] = input_ids
gen_kwargs["logits_processor"] = get_logits_processor()
gen_kwargs["max_new_tokens"] = max_new_tokens if max_new_tokens else gen_kwargs["max_new_tokens"]
gen_kwargs["max_new_tokens"] = max_token if max_token else gen_kwargs["max_new_tokens"]
gen_kwargs["top_p"] = top_p if top_p else gen_kwargs["top_p"]
gen_kwargs["temperature"] = temperature if temperature else gen_kwargs["temperature"]
# Generate response
with torch.no_grad():
generation_output = model.generate(**gen_kwargs)
outputs = generation_output.tolist()[0][len(input_ids[0]):]
response = tokenizer.decode(outputs, skip_special_tokens=True)
# Update history
history = history + [(prompt, response)]
# Prepare response
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": repr(response),
"history": repr(history),
"status": 200,
"time": time
"choices": [
{
"message": {
"role": "assistant",
"content": response
}
}
]
}
# Log and clean up
log = "[" + time + "] " + "\", prompt:\"" + prompt + "\", response:\"" + repr(response) + "\""
log = (
"["
+ time
+ "] "
+ "\", prompt:\""
+ prompt
+ "\", response:\""
+ repr(response)
+ "\""
)
print(log)
torch_gc()
return answer
if __name__ == "__main__":
def check_stream(stream):
if isinstance(stream, bool):
# stream 是布尔类型,直接使用
stream_value = stream
else:
# 不是布尔类型,尝试进行类型转换
if isinstance(stream, str):
stream = stream.lower()
if stream in ["true", "false"]:
# 使用字符串值转换为布尔值
stream_value = stream == "true"
else:
# 非法的字符串值
stream_value = False
else:
# 非布尔类型也非字符串类型
stream_value = False
return stream_value
async def predict(query, max_length, top_p, temperature, history):
global model, tokenizer
input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = {
"input_ids": input_ids,
"do_sample": generating_args.do_sample,
"top_p": top_p,
"temperature": temperature,
"num_beams": generating_args.num_beams,
"max_length": max_length,
"repetition_penalty": generating_args.repetition_penalty,
"logits_processor": get_logits_processor(),
"streamer": streamer
}
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
for new_text in streamer:
answer = {
"choices": [
{
"message": {
"role": "assistant",
"content": new_text
}
}
]
}
yield "data: " + json.dumps(answer) + '\n\n'
if __name__ == "__main__":
model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
model, tokenizer = load_pretrained(model_args, finetuning_args)