diff --git a/src/api_demo.py b/src/api_demo.py
index e31c8b9d..28125db7 100644
--- a/src/api_demo.py
+++ b/src/api_demo.py
@@ -15,17 +15,15 @@
 import json
+import datetime
 import torch
 import uvicorn
-import datetime
+from threading import Thread
 from fastapi import FastAPI, Request
+from starlette.responses import StreamingResponse
+from transformers import TextIteratorStreamer
 
-from utils import (
-    Template,
-    load_pretrained,
-    prepare_infer_args,
-    get_logits_processor
-)
+from utils import Template, load_pretrained, prepare_infer_args, get_logits_processor
 
 
 def torch_gc():
@@ -40,61 +38,124 @@ def torch_gc():
 app = FastAPI()
 
 
-@app.post("/")
+@app.post("/v1/chat/completions")
 async def create_item(request: Request):
-    global model, tokenizer, prompt_template, source_prefix, generating_args
+    global model, tokenizer
 
-    # Parse the request JSON
     json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get("prompt")
-    history = json_post_list.get("history")
-    max_new_tokens = json_post_list.get("max_new_tokens", None)
-    top_p = json_post_list.get("top_p", None)
-    temperature = json_post_list.get("temperature", None)
+    prompt = json_post_raw.get("messages")[-1]["content"]
+    history = json_post_raw.get("messages")[:-1]
+    max_token = json_post_raw.get("max_tokens", None)
+    top_p = json_post_raw.get("top_p", None)
+    temperature = json_post_raw.get("temperature", None)
+    stream = check_stream(json_post_raw.get("stream"))
 
-    # Tokenize the input prompt
+    if stream:
+        generate = predict(prompt, max_token, top_p, temperature, history)
+        return StreamingResponse(generate, media_type="text/event-stream")
+
     input_ids = tokenizer([prompt_template.get_prompt(prompt, history, source_prefix)], return_tensors="pt")["input_ids"]
     input_ids = input_ids.to(model.device)
 
-    # Generation arguments
     gen_kwargs = generating_args.to_dict()
     gen_kwargs["input_ids"] = input_ids
     gen_kwargs["logits_processor"] = get_logits_processor()
-    gen_kwargs["max_new_tokens"] = max_new_tokens if max_new_tokens else gen_kwargs["max_new_tokens"]
+    gen_kwargs["max_new_tokens"] = max_token if max_token else gen_kwargs["max_new_tokens"]
     gen_kwargs["top_p"] = top_p if top_p else gen_kwargs["top_p"]
     gen_kwargs["temperature"] = temperature if temperature else gen_kwargs["temperature"]
 
-    # Generate response
-    with torch.no_grad():
-        generation_output = model.generate(**gen_kwargs)
+    generation_output = model.generate(**gen_kwargs)
+
     outputs = generation_output.tolist()[0][len(input_ids[0]):]
     response = tokenizer.decode(outputs, skip_special_tokens=True)
 
-    # Update history
-    history = history + [(prompt, response)]
-
-    # Prepare response
     now = datetime.datetime.now()
     time = now.strftime("%Y-%m-%d %H:%M:%S")
     answer = {
-        "response": repr(response),
-        "history": repr(history),
-        "status": 200,
-        "time": time
+        "choices": [
+            {
+                "message": {
+                    "role": "assistant",
+                    "content": response
+                }
+            }
+        ]
     }
 
-    # Log and clean up
-    log = "[" + time + "] " + "\", prompt:\"" + prompt + "\", response:\"" + repr(response) + "\""
+    log = (
+        "["
+        + time
+        + "] "
+        + "\", prompt:\""
+        + prompt
+        + "\", response:\""
+        + repr(response)
+        + "\""
+    )
     print(log)
     torch_gc()
 
     return answer
 
 
-if __name__ == "__main__":
+def check_stream(stream):
+    if isinstance(stream, bool):
+        # stream is already a boolean; use it directly
+        stream_value = stream
+    elif isinstance(stream, str):
+        # not a boolean; try to parse the string value
+        stream = stream.lower()
+        if stream in ["true", "false"]:
+            stream_value = stream == "true"
+        else:
+            # unrecognized string value
+            stream_value = False
+    else:
+        # neither a boolean nor a string
+        stream_value = False
+    return stream_value
+
+
+async def predict(query, max_length, top_p, temperature, history):
+    global model, tokenizer
+    input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
+    input_ids = input_ids.to(model.device)
+
+    # streamer yields decoded text pieces as they are generated
+    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
+    gen_kwargs = {
+        "input_ids": input_ids,
+        "do_sample": generating_args.do_sample,
+        "top_p": top_p,
+        "temperature": temperature,
+        "num_beams": generating_args.num_beams,
+        "max_length": max_length,
+        "repetition_penalty": generating_args.repetition_penalty,
+        "logits_processor": get_logits_processor(),
+        "streamer": streamer
+    }
+
+    # run generation in a background thread so text can be streamed while decoding
+    thread = Thread(target=model.generate, kwargs=gen_kwargs)
+    thread.start()
+
+    # wrap each new text piece in an OpenAI-style chunk and emit it as a server-sent event
+    for new_text in streamer:
+        answer = {
+            "choices": [
+                {
+                    "message": {
+                        "role": "assistant",
+                        "content": new_text
+                    }
+                }
+            ]
+        }
+        yield "data: " + json.dumps(answer) + "\n\n"
+
+
+if __name__ == "__main__":
     model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
     model, tokenizer = load_pretrained(model_args, finetuning_args)