update readme
parent 614d3a996c
commit 0697643358
README.md
@@ -9,11 +9,13 @@
## Changelog

[23/06/15] Now we support training the baichuan-7B model in this repo. Try `--model_name_or_path baichuan-inc/baichuan-7B` argument to use the baichuan-7B model.

[23/06/22] Now we align the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in arbitrary ChatGPT-based applications.

[23/06/15] Now we support training the baichuan-7B model in this repo. Try `--model_name_or_path baichuan-inc/baichuan-7B` and `--lora_target W_pack` arguments to use the baichuan-7B model.

[23/06/03] Now we support quantized training and inference (aka [QLoRA](https://github.com/artidoro/qlora)). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)

[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.

[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments to use the BLOOMZ model.
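
The `--lora_target` values mentioned in these changelog entries name the attention projection modules LoRA attaches to: baichuan-7B fuses Q/K/V into a single `W_pack` layer, while BLOOMZ uses `query_key_value`. A minimal PEFT sketch of what such a flag presumably maps to, assuming the training script forwards it to `LoraConfig.target_modules`; the other hyperparameters here are illustrative defaults, not taken from the repo:

```python
# Illustrative only: how a --lora_target value would typically be wired into PEFT.
from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                          # assumed rank, not taken from the repo
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["W_pack"],    # baichuan-7B; use ["query_key_value"] for BLOOMZ
)
```
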
## Supported Models

@@ -75,9 +77,9 @@ huggingface-cli login
- Python 3.8+ and PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
- protobuf, cpm_kernels and sentencepiece
- jieba, rouge_chinese and nltk (used at evaluation)
- gradio and mdtex2html (used in web_demo.py)
- uvicorn and fastapi (used in api_demo.py)

And **powerful GPUs**!

@@ -99,7 +101,7 @@ cd LLaMA-Efficient-Tuning
pip install -r requirements.txt
```

### LLaMA Weights Preparation
### LLaMA Weights Preparation (optional)

1. Download the weights of the LLaMA models.
2. Convert them to HF format using the following command.
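
The conversion command itself sits outside this hunk. For context, a sketch of the standard converter that ships with 🤗 Transformers, which is what step 2 usually refers to; the paths and the `7B` size below are placeholders, and the repo's README may use slightly different flags:

```python
# Hedged sketch: invoke the HF conversion script bundled with transformers.
# The paths are placeholders, not values from the repo.
import subprocess

subprocess.run(
    [
        "python", "-m", "transformers.models.llama.convert_llama_weights_to_hf",
        "--input_dir", "path_to_original_llama_weights",
        "--model_size", "7B",
        "--output_dir", "path_to_llama_hf_model",
    ],
    check=True,
)
```
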

@@ -216,17 +218,10 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.

### CLI Demo
### API / CLI / Web Demo

```bash
python src/cli_demo.py \
    --model_name_or_path path_to_your_model \
    --checkpoint_dir path_to_checkpoint
```

### Web Demo

```bash
python src/web_demo.py \
python src/xxx_demo.py \
    --model_name_or_path path_to_your_model \
    --checkpoint_dir path_to_checkpoint
```

requirements.txt
@@ -1,7 +1,4 @@
torch>=1.13.1
protobuf
cpm_kernels
sentencepiece
transformers>=4.29.1
datasets>=2.12.0
accelerate>=0.19.0
@@ -12,3 +9,5 @@ rouge_chinese
nltk
gradio
mdtex2html
uvicorn
fastapi

src/api_demo.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Implements API for fine-tuned models.
# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
# Visit http://localhost:8000/docs for document.
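
Since the commit aligns the demo with OpenAI's chat-completions format, a minimal client-side sketch of calling it may help. The base URL and endpoint come from the comments and route in this file; the model id and message contents are placeholders, and the non-streaming response is assumed to mirror OpenAI's `choices[0].message.content` shape:

```python
# Hedged example, not part of the commit: query the running demo API.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-finetuned-model",  # placeholder id, echoed back by the server
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```
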
@@ -7,11 +7,10 @@
import time
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from threading import Thread
from contextlib import asynccontextmanager

from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from contextlib import asynccontextmanager
from transformers import TextIteratorStreamer
from starlette.responses import StreamingResponse
from typing import Any, Dict, List, Literal, Optional, Union
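
These imports set up the usual 🤗 Transformers streaming pattern: generation runs in a background thread while `TextIteratorStreamer` yields decoded text pieces for the API to forward. A generic sketch of that pattern, not the commit's own code; `stream_generate` and its arguments are illustrative names:

```python
# Illustrative pattern only: background generation with an iterator streamer.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, inputs, **gen_kwargs):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(target=model.generate, kwargs={**inputs, **gen_kwargs, "streamer": streamer})
    thread.start()
    for new_text in streamer:  # yields text as soon as it is generated
        yield new_text
```
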
@@ -68,14 +67,14 @@ class ChatCompletionResponseStreamChoice(BaseModel):

class ChatCompletionResponse(BaseModel):
    model: str
    object: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer, source_prefix
    global model, tokenizer, source_prefix, generating_args

    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
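
This hunk narrows `object` from a plain `str` to a `Literal` of the two OpenAI object types. A small standalone illustration of what that buys with pydantic validation; `Chunk` is a hypothetical stand-in, not a class from the repo:

```python
# Illustration only: pydantic rejects object values outside the Literal set.
from typing import Literal
from pydantic import BaseModel, ValidationError

class Chunk(BaseModel):
    object: Literal["chat.completion", "chat.completion.chunk"]

print(Chunk(object="chat.completion.chunk"))  # accepted
try:
    Chunk(object="completion")                # raises ValidationError
except ValidationError as err:
    print(err)
```
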
@@ -83,7 +82,9 @@ async def create_chat_completion(request: ChatCompletionRequest):

    prev_messages = request.messages[:-1]
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
        source_prefix = prev_messages.pop(0).content
        prefix = prev_messages.pop(0).content
    else:
        prefix = source_prefix

    history = []
    if len(prev_messages) % 2 == 0:
@@ -91,7 +92,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
                history.append([prev_messages[i].content, prev_messages[i+1].content])

    inputs = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")
    inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
    inputs = inputs.to(model.device)

    gen_kwargs = generating_args.to_dict()
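
The pairing logic shown here turns the OpenAI-style message list into (user, assistant) history pairs for the prompt template. A toy, self-contained rendering of the same check, assuming the surrounding loop (whose header sits outside this hunk) steps through `prev_messages` two at a time; plain dicts stand in for the request's pydantic message objects:

```python
# Toy illustration of the history-pairing logic, not the commit's code.
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
    {"role": "user", "content": "Tell me a joke."},  # the final query is handled separately
]
prev_messages = messages[:-1]
history = []
if len(prev_messages) % 2 == 0:
    for i in range(0, len(prev_messages), 2):
        if prev_messages[i]["role"] == "user" and prev_messages[i + 1]["role"] == "assistant":
            history.append([prev_messages[i]["content"], prev_messages[i + 1]["content"]])
print(history)  # [['Hi', 'Hello! How can I help?']]
```
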
@@ -134,7 +135,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    for new_text in streamer:
@@ -146,7 +147,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
            delta=DeltaMessage(content=new_text),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    choice_data = ChatCompletionResponseStreamChoice(
@@ -154,7 +155,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
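
The hunks above emit server-sent events of the form `data: {...}`. As a closing illustration, a minimal client sketch for consuming that stream, assuming the request model exposes a `stream` flag (as the streaming branch suggests) and the server runs on localhost:8000; the model id and prompt are placeholders:

```python
# Hedged example, not part of the commit: read the streamed chunks emitted above.
import json
import requests

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-finetuned-model",   # placeholder id
        "stream": True,                  # assumed flag for the streaming branch
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            chunk = json.loads(line[len("data: "):])
            delta = chunk["choices"][0]["delta"]
            print(delta.get("content", ""), end="", flush=True)
```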