update readme

hiyouga 2023-06-23 00:17:05 +08:00
parent 614d3a996c
commit 0697643358
3 changed files with 22 additions and 27 deletions

README.md

@@ -9,11 +9,13 @@
## Changelog
[23/06/15] Now we support training the baichuan-7B model in this repo. Try the `--model_name_or_path baichuan-inc/baichuan-7B` argument to use the baichuan-7B model.
[23/06/22] Now we align the [demo API](src/api_demo.py) with [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format, so you can plug the fine-tuned model into arbitrary ChatGPT-based applications.
[23/06/15] Now we support training the baichuan-7B model in this repo. Try the `--model_name_or_path baichuan-inc/baichuan-7B` and `--lora_target W_pack` arguments to use the baichuan-7B model.
[23/06/03] Now we support quantized training and inference (aka [QLoRA](https://github.com/artidoro/qlora)). Try the `--quantization_bit 4/8` argument to work with a quantized model. (experimental feature)
[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try the `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try the `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments to use the BLOOMZ model.
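Taken together, these changelog flags compose into a single fine-tuning invocation. The sketch below is illustrative only: it reuses the `src/train_sft.py` entry point referenced further down this README, while the dataset name, output directory, and batch size are placeholder assumptions rather than values from this commit.

```bash
# Illustrative QLoRA fine-tuning of baichuan-7B; only --model_name_or_path,
# --lora_target and --quantization_bit come from the changelog above,
# the remaining arguments are placeholder assumptions.
CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
    --model_name_or_path baichuan-inc/baichuan-7B \
    --lora_target W_pack \
    --quantization_bit 4 \
    --do_train \
    --dataset alpaca_gpt4_zh \
    --output_dir path_to_sft_checkpoint \
    --per_device_train_batch_size 4
```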
## Supported Models
@@ -75,9 +77,9 @@ huggingface-cli login
- Python 3.8+ and PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
- protobuf, cpm_kernels and sentencepiece
- jieba, rouge_chinese and nltk (used for evaluation)
- gradio and mdtex2html (used in web_demo.py)
- uvicorn and fastapi (used in api_demo.py)
And **powerful GPUs**!
@@ -99,7 +101,7 @@ cd LLaMA-Efficient-Tuning
pip install -r requirements.txt
```
### LLaMA Weights Preparation
### LLaMA Weights Preparation (optional)
1. Download the weights of the LLaMA models.
2. Convert them to HF format using the following command.
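The conversion command itself sits outside this hunk. One common way to do it, assuming the conversion script that ships with 🤗 Transformers, is sketched below (both paths are placeholders):

```bash
# Convert original LLaMA checkpoints to the Hugging Face format
# using the script bundled with the transformers package.
python -m transformers.models.llama.convert_llama_weights_to_hf \
    --input_dir path_to_original_llama_weights \
    --model_size 7B \
    --output_dir path_to_llama_hf_model
```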
@@ -216,17 +218,10 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` for 4/8-bit evaluation.
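As a rough sketch, those flags slot into the `src/train_sft.py` evaluation command this hunk is attached to; everything except `--quantization_bit`, `--per_device_eval_batch_size` and `--max_target_length` is an assumption, not part of this commit.

```bash
# Hypothetical 4-bit evaluation run; only --quantization_bit,
# --per_device_eval_batch_size and --max_target_length come from the README text,
# the other arguments are placeholder assumptions.
CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
    --do_eval \
    --model_name_or_path path_to_your_model \
    --checkpoint_dir path_to_checkpoint \
    --dataset alpaca_gpt4_zh \
    --output_dir path_to_eval_result \
    --quantization_bit 4 \
    --per_device_eval_batch_size 1 \
    --max_target_length 128 \
    --predict_with_generate
```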
### CLI Demo
### API / CLI / Web Demo
```bash
python src/cli_demo.py \
--model_name_or_path path_to_your_model \
--checkpoint_dir path_to_checkpoint
```
### Web Demo
```bash
python src/web_demo.py \
python src/xxx_demo.py \
--model_name_or_path path_to_your_model \
--checkpoint_dir path_to_checkpoint
```
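With `api_demo.py` running, the server exposes an OpenAI-style endpoint at `http://localhost:8000/v1/chat/completions` (the path and port appear in the api_demo.py changes below). A minimal request could look like the following; the `model` value is assumed to be a free-form identifier, as in OpenAI's schema, rather than a routing key.

```bash
# Minimal sketch of an OpenAI-style request against the running demo API.
# The "model" field is assumed to be a free-form identifier, as in OpenAI's schema.
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
          "model": "llama-sft",
          "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"}
          ]
        }'
```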

requirements.txt

@@ -1,7 +1,4 @@
torch>=1.13.1
protobuf
cpm_kernels
sentencepiece
transformers>=4.29.1
datasets>=2.12.0
accelerate>=0.19.0
@@ -12,3 +9,5 @@ rouge_chinese
nltk
gradio
mdtex2html
uvicorn
fastapi

src/api_demo.py

@@ -1,5 +1,5 @@
# coding=utf-8
# Implements API for fine-tuned models.
# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
# Visit http://localhost:8000/docs for the API documentation.
@@ -7,11 +7,10 @@
import time
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from threading import Thread
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from contextlib import asynccontextmanager
from transformers import TextIteratorStreamer
from starlette.responses import StreamingResponse
from typing import Any, Dict, List, Literal, Optional, Union
@@ -68,14 +67,14 @@ class ChatCompletionResponseStreamChoice(BaseModel):
class ChatCompletionResponse(BaseModel):
model: str
object: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer, source_prefix
global model, tokenizer, source_prefix, generating_args
if request.messages[-1].role != "user":
raise HTTPException(status_code=400, detail="Invalid request")
@@ -83,7 +82,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == "system":
source_prefix = prev_messages.pop(0).content
prefix = prev_messages.pop(0).content
else:
prefix = source_prefix
history = []
if len(prev_messages) % 2 == 0:
@@ -91,7 +92,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
history.append([prev_messages[i].content, prev_messages[i+1].content])
inputs = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")
inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
inputs = inputs.to(model.device)
gen_kwargs = generating_args.to_dict()
@@ -134,7 +135,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
delta=DeltaMessage(role="assistant"),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
for new_text in streamer:
@@ -146,7 +147,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
delta=DeltaMessage(content=new_text),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
choice_data = ChatCompletionResponseStreamChoice(
@@ -154,7 +155,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
delta=DeltaMessage(),
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))