Merge pull request #1 from mMrBun/main

Support conversation via API.
This commit is contained in:
hoshi-hiyouga 2023-05-30 16:34:00 +08:00 committed by GitHub
commit c36620ece4
1 changed file with 115 additions and 0 deletions

src/api.py Normal file

@@ -0,0 +1,115 @@
# coding=utf-8
# Chat with LLaMA in API mode.
# Usage: python src/api.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
# Call:
# curl --location 'http://127.0.0.1:8000' \
# --header 'Content-Type: application/json' \
# --data '{"prompt": "Hello there!","history": []}'
# Response:
# {
# "response":"'I am a second year student at the University of British Columbia, in Vancouver.\\nMy major
# is Computer Science and my minor (double degree) area was Mathematics/Statistics with an emphasis on Operations
# Research & Management Sciences which means that when it comes to solving problems using computers or any kind data
# analysis; whether its from businesses , governments etc., i can help you out :) .'",
# "history":"[('Hello there!',
# 'I am a second year student at the University of British Columbia, in Vancouver.\\nMy major is Computer Science and
# my minor (double degree) area was Mathematics/Statistics with an emphasis on Operations Research & Management
# Sciences which means that when it comes to solving problems using computers or any kind data analysis; whether its
# from businesses , governments etc., i can help you out :) .')]",
# "status":200,
# "time":"2023-05-30 06:57:38" }
import datetime

import torch
import uvicorn
from fastapi import FastAPI, Request
from transformers import HfArgumentParser

from utils import ModelArguments, auto_configure_device_map, load_pretrained

DEVICE = "cuda"


def torch_gc():
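    """Release cached CUDA memory (and collect IPC handles) on every visible GPU."""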
if torch.cuda.is_available():
num_gpus = torch.cuda.device_count()
for device_id in range(num_gpus):
with torch.cuda.device(device_id):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()


app = FastAPI()


@app.post("/")
async def create_item(request: Request):
global model, tokenizer
# Parse the request JSON
    json_post = await request.json()
    prompt = json_post.get("prompt")
    history = json_post.get("history") or []
# Tokenize the input prompt
inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(model.device)
# Generation arguments
gen_kwargs = {
"do_sample": True,
"top_p": 0.9,
"top_k": 40,
"temperature": 0.7,
"num_beams": 1,
"max_new_tokens": 256,
"repetition_penalty": 1.5
}
# Generate response
with torch.no_grad():
generation_output = model.generate(**inputs, **gen_kwargs)
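    # Slice off the prompt tokens so only the newly generated text is decoded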
outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
response = tokenizer.decode(outputs, skip_special_tokens=True)
# Update history
history = history + [(prompt, response)]
# Prepare response
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": repr(response),
"history": repr(history),
"status": 200,
"time": time
}
# Log and clean up
    log = "[" + time + "] " + 'prompt: "' + prompt + '", response: "' + repr(response) + '"'
print(log)
torch_gc()
return answer


if __name__ == "__main__":
parser = HfArgumentParser(ModelArguments)
model_args, = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args)
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
device_map = auto_configure_device_map(torch.cuda.device_count())
model = dispatch_model(model, device_map)
else:
model = model.cuda()
model.eval()
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
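
For reference, a minimal client-side sketch of the call documented in the file header. This is a hypothetical helper script, not part of this commit; it assumes the server is running on the default host/port above and that the requests package is installed. Note that the server serializes "response" and "history" with repr(), so the client keeps its own (prompt, response) list to send on the next turn.

# client_demo.py -- hypothetical example, not part of this commit.
import ast

import requests

API_URL = "http://127.0.0.1:8000"  # assumed default host/port from api.py

def chat(prompt, history):
    # The server accepts history as a list of (prompt, response) pairs;
    # JSON encodes the tuples as arrays, which the server handles fine.
    payload = {"prompt": prompt, "history": history}
    resp = requests.post(API_URL, json=payload, timeout=300)
    resp.raise_for_status()
    data = resp.json()
    response = ast.literal_eval(data["response"])  # undo the server-side repr()
    history.append((prompt, response))
    return response

if __name__ == "__main__":
    history = []
    print(chat("Hello there!", history))
    print(chat("What can you do?", history))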