Merge pull request #68 from mMrBun/main

Compatible with OpenAI API.
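
The demo server now mirrors the OpenAI chat completions interface: requests are POSTed to /v1/chat/completions and the reply comes back under choices[0].message.content. A sketch of the request and response shapes handled by the updated endpoint (field names follow the diff below; the values are only illustrative):

# Request body parsed by create_item() -- illustrative values
request_body = {
    "messages": [                 # last entry is the current prompt, the rest is history
        {"role": "user", "content": "Hello, who are you?"}
    ],
    "max_tokens": 256,            # optional, falls back to generating_args
    "top_p": 0.7,                 # optional
    "temperature": 0.95,          # optional
    "stream": False               # bool or "true"/"false" string, normalized by check_stream()
}

# Shape of the non-streaming response
response_body = {
    "choices": [
        {"message": {"role": "assistant", "content": "<generated text>"}}
    ]
}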
hoshi-hiyouga 2023-06-22 15:52:34 +08:00 committed by GitHub
commit 9324940b76
1 changed file with 96 additions and 35 deletions


@@ -15,17 +15,15 @@
 import json
+import datetime
 import torch
 import uvicorn
-import datetime
+from threading import Thread
 from fastapi import FastAPI, Request
+from starlette.responses import StreamingResponse
+from transformers import TextIteratorStreamer
-from utils import (
-    Template,
-    load_pretrained,
-    prepare_infer_args,
-    get_logits_processor
-)
+from utils import Template, load_pretrained, prepare_infer_args, get_logits_processor


 def torch_gc():
@@ -40,61 +38,124 @@ def torch_gc():
 app = FastAPI()

-@app.post("/")
+@app.post("/v1/chat/completions")
 async def create_item(request: Request):
-    global model, tokenizer, prompt_template, source_prefix, generating_args
+    global model, tokenizer

-    # Parse the request JSON
     json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get("prompt")
-    history = json_post_list.get("history")
-    max_new_tokens = json_post_list.get("max_new_tokens", None)
-    top_p = json_post_list.get("top_p", None)
-    temperature = json_post_list.get("temperature", None)
+    prompt = json_post_raw.get("messages")[-1]["content"]
+    history = json_post_raw.get("messages")[:-1]
+    max_token = json_post_raw.get("max_tokens", None)
+    top_p = json_post_raw.get("top_p", None)
+    temperature = json_post_raw.get("temperature", None)
+    stream = check_stream(json_post_raw.get("stream"))

-    # Tokenize the input prompt
-    input_ids = tokenizer([prompt_template.get_prompt(prompt, history, source_prefix)], return_tensors="pt")["input_ids"]
+    if stream:
+        generate = predict(prompt, max_token, top_p, temperature, history)
+        return StreamingResponse(generate, media_type="text/event-stream")
+
+    input_ids = tokenizer([prompt_template.get_prompt(prompt, history, source_prefix)], return_tensors="pt")[
+        "input_ids"]
     input_ids = input_ids.to(model.device)

-    # Generation arguments
     gen_kwargs = generating_args.to_dict()
     gen_kwargs["input_ids"] = input_ids
     gen_kwargs["logits_processor"] = get_logits_processor()
-    gen_kwargs["max_new_tokens"] = max_new_tokens if max_new_tokens else gen_kwargs["max_new_tokens"]
+    gen_kwargs["max_new_tokens"] = max_token if max_token else gen_kwargs["max_new_tokens"]
     gen_kwargs["top_p"] = top_p if top_p else gen_kwargs["top_p"]
     gen_kwargs["temperature"] = temperature if temperature else gen_kwargs["temperature"]

-    # Generate response
-    with torch.no_grad():
-        generation_output = model.generate(**gen_kwargs)
+    generation_output = model.generate(**gen_kwargs)
     outputs = generation_output.tolist()[0][len(input_ids[0]):]
     response = tokenizer.decode(outputs, skip_special_tokens=True)

-    # Update history
-    history = history + [(prompt, response)]
-
-    # Prepare response
     now = datetime.datetime.now()
     time = now.strftime("%Y-%m-%d %H:%M:%S")
     answer = {
-        "response": repr(response),
-        "history": repr(history),
-        "status": 200,
-        "time": time
+        "choices": [
+            {
+                "message": {
+                    "role": "assistant",
+                    "content": response
+                }
+            }
+        ]
     }

-    # Log and clean up
-    log = "[" + time + "] " + "\", prompt:\"" + prompt + "\", response:\"" + repr(response) + "\""
+    log = (
+        "["
+        + time
+        + "] "
+        + "\", prompt:\""
+        + prompt
+        + "\", response:\""
+        + repr(response)
+        + "\""
+    )
     print(log)
     torch_gc()
     return answer

-if __name__ == "__main__":
+
+def check_stream(stream):
+    if isinstance(stream, bool):
+        # stream is already a boolean, use it directly
+        stream_value = stream
+    else:
+        # not a boolean, attempt a type conversion
+        if isinstance(stream, str):
+            stream = stream.lower()
+            if stream in ["true", "false"]:
+                # convert the string value to a boolean
+                stream_value = stream == "true"
+            else:
+                # invalid string value
+                stream_value = False
+        else:
+            # neither a boolean nor a string
+            stream_value = False
+    return stream_value
+
+
+async def predict(query, max_length, top_p, temperature, history):
+    global model, tokenizer
+    input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
+    input_ids = input_ids.to(model.device)
+    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+    gen_kwargs = {
+        "input_ids": input_ids,
+        "do_sample": generating_args.do_sample,
+        "top_p": top_p,
+        "temperature": temperature,
+        "num_beams": generating_args.num_beams,
+        "max_length": max_length,
+        "repetition_penalty": generating_args.repetition_penalty,
+        "logits_processor": get_logits_processor(),
+        "streamer": streamer
+    }
+    thread = Thread(target=model.generate, kwargs=gen_kwargs)
+    thread.start()
+    for new_text in streamer:
+        answer = {
+            "choices": [
+                {
+                    "message": {
+                        "role": "assistant",
+                        "content": new_text
+                    }
+                }
+            ]
+        }
+        yield "data: " + json.dumps(answer) + '\n\n'
+
+
+if __name__ == "__main__":
     model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
     model, tokenizer = load_pretrained(model_args, finetuning_args)
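
For reference, a minimal client sketch against the updated route might look like the following. The base URL is an assumption, since the uvicorn.run() call and the default generating arguments sit outside the diffed hunks; note also that the streamed chunks place each token under "message" rather than OpenAI's usual "delta".

# Client-side sketch for the new endpoint (not part of the commit)
import json
import requests

API_URL = "http://localhost:8000/v1/chat/completions"  # assumed host and port

payload = {
    "messages": [{"role": "user", "content": "Hello, who are you?"}],
    "max_tokens": 256,
    "top_p": 0.7,
    "temperature": 0.95,
}

# Non-streaming: the full reply is returned in choices[0].message.content
resp = requests.post(API_URL, json=payload)
print(resp.json()["choices"][0]["message"]["content"])

# Streaming: the server emits server-sent events, one JSON chunk per "data:" line
with requests.post(API_URL, json={**payload, "stream": True}, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue  # skip the blank separator lines between events
        line = line.decode("utf-8")
        if line.startswith("data: "):
            chunk = json.loads(line[len("data: "):])
            print(chunk["choices"][0]["message"]["content"], end="", flush=True)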