forked from p04798526/LLaMA-Factory-Mirror
Compatible with OpenAI API.
This commit is contained in:
parent
ded5aa3c3d
commit
6e4db0903f
131
src/api_demo.py
131
src/api_demo.py
|
@ -15,17 +15,15 @@
|
||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import datetime
|
||||||
import torch
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import datetime
|
from threading import Thread
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
|
from starlette.responses import StreamingResponse
|
||||||
|
from transformers import TextIteratorStreamer
|
||||||
|
|
||||||
from utils import (
|
from utils import Template, load_pretrained, prepare_infer_args, get_logits_processor
|
||||||
Template,
|
|
||||||
load_pretrained,
|
|
||||||
prepare_infer_args,
|
|
||||||
get_logits_processor
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def torch_gc():
|
def torch_gc():
|
||||||
|
@ -40,61 +38,124 @@ def torch_gc():
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/")
|
@app.post("/v1/chat/completions")
|
||||||
async def create_item(request: Request):
|
async def create_item(request: Request):
|
||||||
global model, tokenizer, prompt_template, source_prefix, generating_args
|
global model, tokenizer
|
||||||
|
|
||||||
# Parse the request JSON
|
|
||||||
json_post_raw = await request.json()
|
json_post_raw = await request.json()
|
||||||
json_post = json.dumps(json_post_raw)
|
prompt = json_post_raw.get("messages")[-1]["content"]
|
||||||
json_post_list = json.loads(json_post)
|
history = json_post_raw.get("messages")[:-1]
|
||||||
prompt = json_post_list.get("prompt")
|
max_token = json_post_raw.get("max_tokens", None)
|
||||||
history = json_post_list.get("history")
|
top_p = json_post_raw.get("top_p", None)
|
||||||
max_new_tokens = json_post_list.get("max_new_tokens", None)
|
temperature = json_post_raw.get("temperature", None)
|
||||||
top_p = json_post_list.get("top_p", None)
|
stream = check_stream(json_post_raw.get("stream"))
|
||||||
temperature = json_post_list.get("temperature", None)
|
|
||||||
|
|
||||||
# Tokenize the input prompt
|
if stream:
|
||||||
input_ids = tokenizer([prompt_template.get_prompt(prompt, history, source_prefix)], return_tensors="pt")["input_ids"]
|
generate = predict(prompt, max_token, top_p, temperature, history)
|
||||||
|
return StreamingResponse(generate, media_type="text/event-stream")
|
||||||
|
|
||||||
|
input_ids = tokenizer([prompt_template.get_prompt(prompt, history, source_prefix)], return_tensors="pt")[
|
||||||
|
"input_ids"]
|
||||||
input_ids = input_ids.to(model.device)
|
input_ids = input_ids.to(model.device)
|
||||||
|
|
||||||
# Generation arguments
|
|
||||||
gen_kwargs = generating_args.to_dict()
|
gen_kwargs = generating_args.to_dict()
|
||||||
gen_kwargs["input_ids"] = input_ids
|
gen_kwargs["input_ids"] = input_ids
|
||||||
gen_kwargs["logits_processor"] = get_logits_processor()
|
gen_kwargs["logits_processor"] = get_logits_processor()
|
||||||
gen_kwargs["max_new_tokens"] = max_new_tokens if max_new_tokens else gen_kwargs["max_new_tokens"]
|
gen_kwargs["max_new_tokens"] = max_token if max_token else gen_kwargs["max_new_tokens"]
|
||||||
gen_kwargs["top_p"] = top_p if top_p else gen_kwargs["top_p"]
|
gen_kwargs["top_p"] = top_p if top_p else gen_kwargs["top_p"]
|
||||||
gen_kwargs["temperature"] = temperature if temperature else gen_kwargs["temperature"]
|
gen_kwargs["temperature"] = temperature if temperature else gen_kwargs["temperature"]
|
||||||
|
|
||||||
# Generate response
|
generation_output = model.generate(**gen_kwargs)
|
||||||
with torch.no_grad():
|
|
||||||
generation_output = model.generate(**gen_kwargs)
|
|
||||||
outputs = generation_output.tolist()[0][len(input_ids[0]):]
|
outputs = generation_output.tolist()[0][len(input_ids[0]):]
|
||||||
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
||||||
|
|
||||||
# Update history
|
|
||||||
history = history + [(prompt, response)]
|
|
||||||
|
|
||||||
# Prepare response
|
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
time = now.strftime("%Y-%m-%d %H:%M:%S")
|
time = now.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
answer = {
|
answer = {
|
||||||
"response": repr(response),
|
"choices": [
|
||||||
"history": repr(history),
|
{
|
||||||
"status": 200,
|
"message": {
|
||||||
"time": time
|
"role": "assistant",
|
||||||
|
"content": response
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Log and clean up
|
log = (
|
||||||
log = "[" + time + "] " + "\", prompt:\"" + prompt + "\", response:\"" + repr(response) + "\""
|
"["
|
||||||
|
+ time
|
||||||
|
+ "] "
|
||||||
|
+ "\", prompt:\""
|
||||||
|
+ prompt
|
||||||
|
+ "\", response:\""
|
||||||
|
+ repr(response)
|
||||||
|
+ "\""
|
||||||
|
)
|
||||||
print(log)
|
print(log)
|
||||||
torch_gc()
|
torch_gc()
|
||||||
|
|
||||||
return answer
|
return answer
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def check_stream(stream):
|
||||||
|
if isinstance(stream, bool):
|
||||||
|
# stream 是布尔类型,直接使用
|
||||||
|
stream_value = stream
|
||||||
|
else:
|
||||||
|
# 不是布尔类型,尝试进行类型转换
|
||||||
|
if isinstance(stream, str):
|
||||||
|
stream = stream.lower()
|
||||||
|
if stream in ["true", "false"]:
|
||||||
|
# 使用字符串值转换为布尔值
|
||||||
|
stream_value = stream == "true"
|
||||||
|
else:
|
||||||
|
# 非法的字符串值
|
||||||
|
stream_value = False
|
||||||
|
else:
|
||||||
|
# 非布尔类型也非字符串类型
|
||||||
|
stream_value = False
|
||||||
|
return stream_value
|
||||||
|
|
||||||
|
|
||||||
|
async def predict(query, max_length, top_p, temperature, history):
|
||||||
|
global model, tokenizer
|
||||||
|
input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
|
||||||
|
input_ids = input_ids.to(model.device)
|
||||||
|
|
||||||
|
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
|
||||||
|
gen_kwargs = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"do_sample": generating_args.do_sample,
|
||||||
|
"top_p": top_p,
|
||||||
|
"temperature": temperature,
|
||||||
|
"num_beams": generating_args.num_beams,
|
||||||
|
"max_length": max_length,
|
||||||
|
"repetition_penalty": generating_args.repetition_penalty,
|
||||||
|
"logits_processor": get_logits_processor(),
|
||||||
|
"streamer": streamer
|
||||||
|
}
|
||||||
|
|
||||||
|
thread = Thread(target=model.generate, kwargs=gen_kwargs)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
for new_text in streamer:
|
||||||
|
answer = {
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": new_text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
yield "data: " + json.dumps(answer) + '\n\n'
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
|
model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
|
||||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
model, tokenizer = load_pretrained(model_args, finetuning_args)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue