support vllm
This commit is contained in:
parent f74f804a71
commit d07ad5cc1c

Makefile | 2
@@ -1,6 +1,6 @@
 .PHONY: quality style

-check_dirs := scripts src
+check_dirs := scripts src tests

 quality:
 	ruff check $(check_dirs)
README.md | 19
@@ -47,10 +47,11 @@ Choose your path:

 - **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO.
-- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA, 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: DoRA, LongLoRA, LLaMA Pro, LoftQ, agent tuning.
-- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune, rsLoRA.
+- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
+- **Advanced algorithms**: DoRA, LongLoRA, LLaMA Pro, LoftQ and Agent tuning.
+- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
+- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.

 ## Benchmark

@@ -69,6 +70,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

 ## Changelog

+[24/03/07] We integrated [vLLM](https://github.com/vllm-project/vllm) for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported; merge the LoRA weights first.)
+
 [24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.

 [24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `scripts/llama_pro.py` for usage.
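For context on the `--infer_backend` switch added above, here is a minimal sketch of how it is exercised from Python. The package path, model path, and template name are assumptions for illustration; only the `infer_backend` argument and the `ChatModel` interface come from this commit.

```python
# Minimal sketch of switching the inference backend introduced by this commit.
# Paths and the template name below are placeholders, not part of the commit.
from llmtuner.chat import ChatModel  # package name assumed from the repo layout

chat_model = ChatModel(
    dict(
        model_name_or_path="path_to_llama_model",  # merged model; LoRA adapters must be merged first for vLLM
        template="default",
        infer_backend="vllm",  # use "hf" to keep the original Transformers engine
    )
)

messages = [{"role": "user", "content": "Hello!"}]

# chat() is the synchronous wrapper; it drives the async engine internally.
for response in chat_model.chat(messages):
    print(response.response_text)
```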
@@ -79,7 +82,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

 [24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`.

-[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves 1.7x speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
+[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try the `--use_unsloth` argument to activate the unsloth patch. It achieves **170%** speed in our benchmark; check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.

 [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).

@@ -553,7 +556,7 @@ deepspeed --num_gpus 8 src/train_bash.py \
 ### Merge LoRA weights and export model

 ```bash
-python src/export_model.py \
+CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
     --model_name_or_path path_to_llama_model \
     --adapter_name_or_path path_to_checkpoint \
     --template default \
@@ -574,7 +577,7 @@ python src/export_model.py \
 ### Inference with OpenAI-style API

 ```bash
-python src/api_demo.py \
+CUDA_VISIBLE_DEVICES=0 python src/api_demo.py \
     --model_name_or_path path_to_llama_model \
     --adapter_name_or_path path_to_checkpoint \
     --template default \
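Since the API server above now serves requests through the asynchronous engine, a quick client-side check can be done with any OpenAI-compatible client. A hedged sketch follows; the host, port, and model name are assumptions, and only the `/v1/chat/completions` route is taken from this commit's `app.py`.

```python
# Hypothetical client for the OpenAI-style server started by api_demo.py.
# Base URL and model name are assumptions; adjust them to your deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")  # placeholder key; assumes no auth is enforced

result = client.chat.completions.create(
    model="llama2-7b",  # echoed back by the server; any string works
    messages=[{"role": "user", "content": "Hello!"}],
    stream=False,
)
print(result.choices[0].message.content)
```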
@@ -587,7 +590,7 @@ python src/api_demo.py \
 ### Inference with command line

 ```bash
-python src/cli_demo.py \
+CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
     --model_name_or_path path_to_llama_model \
     --adapter_name_or_path path_to_checkpoint \
     --template default \
@@ -597,7 +600,7 @@ python src/cli_demo.py \
 ### Inference with web browser

 ```bash
-python src/web_demo.py \
+CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
     --model_name_or_path path_to_llama_model \
     --adapter_name_or_path path_to_checkpoint \
     --template default \
README_zh.md | 17
@@ -51,6 +51,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 - **先进算法**:DoRA、LongLoRA、LLaMA Pro、LoftQ 和 Agent 微调。
 - **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
+- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。

 ## 性能指标

@@ -69,17 +70,19 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd

 ## 更新日志

+[24/03/07] 我们集成了 [vLLM](https://github.com/vllm-project/vllm) 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA,请先合并权重。)
+
 [24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。

 [24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `scripts/llama_pro.py`。

-[24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
-
 <details><summary>展开日志</summary>

+[24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
+
 [24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。

-[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 1.7 倍的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
+[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 **170%** 的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。

 [23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。

@@ -552,7 +555,7 @@ deepspeed --num_gpus 8 src/train_bash.py \
 ### 合并 LoRA 权重并导出模型

 ```bash
-python src/export_model.py \
+CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
     --model_name_or_path path_to_llama_model \
     --adapter_name_or_path path_to_checkpoint \
     --template default \
@@ -573,7 +576,7 @@ python src/export_model.py \
 ### 使用 OpenAI 风格 API 推理

 ```bash
-python src/api_demo.py \
+CUDA_VISIBLE_DEVICES=0 python src/api_demo.py \
     --model_name_or_path path_to_llama_model \
     --adapter_name_or_path path_to_checkpoint \
     --template default \
@@ -586,7 +589,7 @@ python src/api_demo.py \
 ### 使用命令行推理

 ```bash
-python src/cli_demo.py \
+CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
     --model_name_or_path path_to_llama_model \
     --adapter_name_or_path path_to_checkpoint \
     --template default \
@@ -596,7 +599,7 @@ python src/cli_demo.py \
 ### 使用浏览器推理

 ```bash
-python src/web_demo.py \
+CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
     --model_name_or_path path_to_llama_model \
     --adapter_name_or_path path_to_checkpoint \
     --template default \
@@ -1,5 +1,8 @@
 Usage:

-- `pretrain.sh`
-- `sft.sh` -> `reward.sh` -> `ppo.sh`
-- `sft.sh` -> `dpo.sh` -> `predict.sh`
+- `pretrain.sh`: run pre-training (optional)
+- `sft.sh`: run supervised fine-tuning
+- `reward.sh`: run reward modeling (run after `sft.sh`)
+- `ppo.sh`: run PPO training (run after `sft.sh` and `reward.sh`)
+- `dpo.sh`: run DPO training (run after `sft.sh`)
+- `predict.sh`: run prediction (run after `sft.sh` and `dpo.sh`)
@@ -1,3 +1,4 @@
 Usage:

-- `merge.sh` -> `quantize.sh`
+- `merge.sh`: merge the LoRA weights
+- `quantize.sh`: quantize the model with AutoGPTQ (run after `merge.sh`, optional)
@@ -1,6 +1,6 @@
 #!/bin/bash

-python ../../src/export_model.py \
+CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
     --model_name_or_path meta-llama/Llama-2-7b-hf \
     --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
     --template default \
@@ -1,6 +1,6 @@
 #!/bin/bash

-python ../../src/export_model.py \
+CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
     --model_name_or_path ../../models/llama2-7b-sft \
     --template default \
     --export_dir ../../models/llama2-7b-sft-int4 \
@ -1,4 +1,3 @@
|
||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
@ -73,7 +72,6 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
|
||||||
role_mapping = {
|
role_mapping = {
|
||||||
Role.USER: DataRole.USER.value,
|
Role.USER: DataRole.USER.value,
|
||||||
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||||
|
@ -89,7 +87,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
|
|
||||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
||||||
async def create_chat_completion(request: ChatCompletionRequest):
|
async def create_chat_completion(request: ChatCompletionRequest):
|
||||||
if not chat_model.can_generate:
|
if not chat_model.engine.can_generate:
|
||||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||||
|
|
||||||
if len(request.messages) == 0:
|
if len(request.messages) == 0:
|
||||||
|
@ -121,20 +119,15 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
else:
|
else:
|
||||||
tools = ""
|
tools = ""
|
||||||
|
|
||||||
async with semaphore:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
return await loop.run_in_executor(None, chat_completion, input_messages, system, tools, request)
|
|
||||||
|
|
||||||
def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
|
|
||||||
if request.stream:
|
if request.stream:
|
||||||
if tools:
|
if tools:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||||
|
|
||||||
generate = stream_chat_completion(messages, system, tools, request)
|
generate = stream_chat_completion(input_messages, system, tools, request)
|
||||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||||
|
|
||||||
responses = chat_model.chat(
|
responses = await chat_model.achat(
|
||||||
messages,
|
input_messages,
|
||||||
system,
|
system,
|
||||||
tools,
|
tools,
|
||||||
do_sample=request.do_sample,
|
do_sample=request.do_sample,
|
||||||
|
@ -148,7 +141,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
choices = []
|
choices = []
|
||||||
for i, response in enumerate(responses):
|
for i, response in enumerate(responses):
|
||||||
if tools:
|
if tools:
|
||||||
result = chat_model.template.format_tools.extract(response.response_text)
|
result = chat_model.engine.template.format_tools.extract(response.response_text)
|
||||||
else:
|
else:
|
||||||
result = response.response_text
|
result = response.response_text
|
||||||
|
|
||||||
|
@ -177,7 +170,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
|
|
||||||
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
||||||
|
|
||||||
def stream_chat_completion(
|
async def stream_chat_completion(
|
||||||
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
|
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
|
||||||
):
|
):
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
|
@ -186,7 +179,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||||
yield jsonify(chunk)
|
yield jsonify(chunk)
|
||||||
|
|
||||||
for new_text in chat_model.stream_chat(
|
async for new_token in chat_model.astream_chat(
|
||||||
messages,
|
messages,
|
||||||
system,
|
system,
|
||||||
tools,
|
tools,
|
||||||
|
@ -195,11 +188,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
max_new_tokens=request.max_tokens,
|
max_new_tokens=request.max_tokens,
|
||||||
):
|
):
|
||||||
if len(new_text) == 0:
|
if len(new_token) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None
|
index=0, delta=ChatCompletionMessage(content=new_token), finish_reason=None
|
||||||
)
|
)
|
||||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||||
yield jsonify(chunk)
|
yield jsonify(chunk)
|
||||||
|
@ -213,18 +206,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||||
|
|
||||||
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
|
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
|
||||||
async def create_score_evaluation(request: ScoreEvaluationRequest):
|
async def create_score_evaluation(request: ScoreEvaluationRequest):
|
||||||
if chat_model.can_generate:
|
if chat_model.engine.can_generate:
|
||||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||||
|
|
||||||
if len(request.messages) == 0:
|
if len(request.messages) == 0:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||||
|
|
||||||
async with semaphore:
|
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
return await loop.run_in_executor(None, get_score, request)
|
|
||||||
|
|
||||||
def get_score(request: ScoreEvaluationRequest):
|
|
||||||
scores = chat_model.get_scores(request.messages, max_length=request.max_length)
|
|
||||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
@@ -1,4 +1,5 @@
+from .base_engine import BaseEngine
 from .chat_model import ChatModel


-__all__ = ["ChatModel"]
+__all__ = ["BaseEngine", "ChatModel"]
@@ -0,0 +1,64 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel, PreTrainedTokenizer
+
+    from ..data import Template
+    from ..extras.packages import is_vllm_available
+    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+    if is_vllm_available():
+        from vllm import AsyncLLMEngine
+
+
+@dataclass
+class Response:
+    response_text: str
+    response_length: int
+    prompt_length: int
+    finish_reason: Literal["stop", "length"]
+
+
+class BaseEngine(ABC):
+    model: Union["PreTrainedModel", "AsyncLLMEngine"]
+    tokenizer: "PreTrainedTokenizer"
+    can_generate: bool
+    template: "Template"
+    generating_args: Dict[str, Any]
+
+    @abstractmethod
+    def __init__(
+        self,
+        model_args: "ModelArguments",
+        data_args: "DataArguments",
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
+    ) -> None: ...
+
+    @abstractmethod
+    async def chat(
+        self,
+        messages: Sequence[Dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        **input_kwargs,
+    ) -> List["Response"]: ...
+
+    @abstractmethod
+    async def stream_chat(
+        self,
+        messages: Sequence[Dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        **input_kwargs,
+    ) -> AsyncGenerator[str, None]: ...
+
+    @abstractmethod
+    async def get_scores(
+        self,
+        batch_input: List[str],
+        **input_kwargs,
+    ) -> List[float]: ...
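To make the contract above concrete, the following is a toy engine that satisfies `BaseEngine`. It is not part of the commit and simply echoes the last user message; the import path is assumed from the repository layout.

```python
# Illustrative only: a minimal BaseEngine implementation that echoes input.
from typing import AsyncGenerator, Dict, List, Optional, Sequence

from llmtuner.chat.base_engine import BaseEngine, Response  # module path assumed


class EchoEngine(BaseEngine):
    def __init__(self, model_args, data_args, finetuning_args, generating_args) -> None:
        self.can_generate = True
        self.generating_args = {}

    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> List["Response"]:
        text = messages[-1]["content"]
        # Lengths are character counts here, since no tokenizer is involved.
        return [Response(response_text=text, response_length=len(text), prompt_length=0, finish_reason="stop")]

    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        for token in messages[-1]["content"].split():
            yield token + " "

    async def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
        raise NotImplementedError("EchoEngine does not produce reward scores.")
```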
@ -1,124 +1,50 @@
|
||||||
from dataclasses import dataclass
|
import asyncio
|
||||||
from threading import Thread
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
||||||
from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers import GenerationConfig, TextIteratorStreamer
|
|
||||||
|
|
||||||
from ..data import get_template_and_fix_tokenizer
|
|
||||||
from ..extras.misc import get_logits_processor
|
|
||||||
from ..hparams import get_infer_args
|
from ..hparams import get_infer_args
|
||||||
from ..model import dispatch_model, load_model_and_tokenizer
|
from .hf_engine import HuggingfaceEngine
|
||||||
|
from .vllm_engine import VllmEngine
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
if TYPE_CHECKING:
|
||||||
class Response:
|
from .base_engine import BaseEngine, Response
|
||||||
response_text: str
|
|
||||||
response_length: int
|
|
||||||
prompt_length: int
|
|
||||||
finish_reason: Literal["stop", "length"]
|
|
||||||
|
|
||||||
|
|
||||||
class ChatModel:
|
class ChatModel:
|
||||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||||
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
|
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
|
||||||
self.can_generate = finetuning_args.stage == "sft"
|
if model_args.infer_backend == "hf":
|
||||||
self.model, self.tokenizer = load_model_and_tokenizer(
|
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
||||||
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
elif model_args.infer_backend == "vllm":
|
||||||
)
|
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
||||||
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
else:
|
||||||
self.model = dispatch_model(self.model)
|
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
|
||||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
|
||||||
|
|
||||||
def _process_args(
|
def _get_event_loop():
|
||||||
self,
|
try:
|
||||||
messages: Sequence[Dict[str, str]],
|
return asyncio.get_running_loop()
|
||||||
system: Optional[str] = None,
|
except RuntimeError:
|
||||||
tools: Optional[str] = None,
|
return asyncio.new_event_loop()
|
||||||
**input_kwargs,
|
|
||||||
) -> Tuple[Dict[str, Any], int]:
|
|
||||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
|
||||||
prompt, _ = self.template.encode_oneturn(
|
|
||||||
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
|
|
||||||
)
|
|
||||||
prompt_length = len(prompt)
|
|
||||||
input_ids = torch.tensor([prompt], device=self.model.device)
|
|
||||||
|
|
||||||
do_sample = input_kwargs.pop("do_sample", None)
|
|
||||||
temperature = input_kwargs.pop("temperature", None)
|
|
||||||
top_p = input_kwargs.pop("top_p", None)
|
|
||||||
top_k = input_kwargs.pop("top_k", None)
|
|
||||||
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
|
||||||
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
|
||||||
max_length = input_kwargs.pop("max_length", None)
|
|
||||||
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
|
||||||
|
|
||||||
generating_args = self.generating_args.to_dict()
|
|
||||||
generating_args.update(
|
|
||||||
dict(
|
|
||||||
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
|
||||||
temperature=temperature or generating_args["temperature"],
|
|
||||||
top_p=top_p or generating_args["top_p"],
|
|
||||||
top_k=top_k or generating_args["top_k"],
|
|
||||||
num_return_sequences=num_return_sequences or 1,
|
|
||||||
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
|
||||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
|
||||||
pad_token_id=self.tokenizer.pad_token_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
|
|
||||||
generating_args["do_sample"] = True
|
|
||||||
|
|
||||||
if max_length:
|
|
||||||
generating_args.pop("max_new_tokens", None)
|
|
||||||
generating_args["max_length"] = max_length
|
|
||||||
|
|
||||||
if max_new_tokens:
|
|
||||||
generating_args.pop("max_length", None)
|
|
||||||
generating_args["max_new_tokens"] = max_new_tokens
|
|
||||||
|
|
||||||
gen_kwargs = dict(
|
|
||||||
inputs=input_ids,
|
|
||||||
generation_config=GenerationConfig(**generating_args),
|
|
||||||
logits_processor=get_logits_processor(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return gen_kwargs, prompt_length
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def chat(
|
def chat(
|
||||||
self,
|
self,
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List[Response]:
|
) -> List["Response"]:
|
||||||
if not self.can_generate:
|
loop = self._get_event_loop()
|
||||||
raise ValueError("The current model does not support `chat`.")
|
return loop.run_until_complete(self.achat(messages, system, tools, **input_kwargs))
|
||||||
|
|
||||||
gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
|
async def achat(
|
||||||
generate_output = self.model.generate(**gen_kwargs)
|
self,
|
||||||
response_ids = generate_output[:, prompt_length:]
|
messages: Sequence[Dict[str, str]],
|
||||||
response = self.tokenizer.batch_decode(
|
system: Optional[str] = None,
|
||||||
response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
tools: Optional[str] = None,
|
||||||
)
|
**input_kwargs,
|
||||||
results = []
|
) -> List["Response"]:
|
||||||
for i in range(len(response)):
|
return await self.engine.chat(messages, system, tools, **input_kwargs)
|
||||||
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
|
|
||||||
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
|
|
||||||
results.append(
|
|
||||||
Response(
|
|
||||||
response_text=response[i],
|
|
||||||
response_length=response_length,
|
|
||||||
prompt_length=prompt_length,
|
|
||||||
finish_reason="stop" if len(eos_index) else "length",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def stream_chat(
|
def stream_chat(
|
||||||
self,
|
self,
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
|
@ -126,44 +52,35 @@ class ChatModel:
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
if not self.can_generate:
|
loop = self._get_event_loop()
|
||||||
raise ValueError("The current model does not support `stream_chat`.")
|
generator = self.astream_chat(messages, system, tools, **input_kwargs)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
yield loop.run_until_complete(generator.__anext__())
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
|
||||||
gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
|
async def astream_chat(
|
||||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
self,
|
||||||
gen_kwargs["streamer"] = streamer
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
|
||||||
|
yield new_token
|
||||||
|
|
||||||
thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
|
def get_scores(
|
||||||
thread.start()
|
self,
|
||||||
|
batch_input: List[str],
|
||||||
|
**input_kwargs,
|
||||||
|
) -> List[float]:
|
||||||
|
loop = self._get_event_loop()
|
||||||
|
return loop.run_until_complete(self.aget_scores(batch_input, **input_kwargs))
|
||||||
|
|
||||||
yield from streamer
|
async def aget_scores(
|
||||||
|
self,
|
||||||
@torch.inference_mode()
|
batch_input: List[str],
|
||||||
def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
|
**input_kwargs,
|
||||||
if self.can_generate:
|
) -> List[float]:
|
||||||
raise ValueError("Cannot get scores using an auto-regressive model.")
|
return await self.engine.get_scores(batch_input, **input_kwargs)
|
||||||
|
|
||||||
max_length = input_kwargs.pop("max_length", None)
|
|
||||||
device = getattr(self.model.pretrained_model, "device", "cuda")
|
|
||||||
inputs = self.tokenizer(
|
|
||||||
batch_input,
|
|
||||||
padding=True,
|
|
||||||
truncation=True,
|
|
||||||
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
|
|
||||||
return_tensors="pt",
|
|
||||||
add_special_tokens=True,
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
input_ids: torch.Tensor = inputs["input_ids"]
|
|
||||||
_, _, values = self.model(**inputs, output_hidden_states=True, return_dict=True)
|
|
||||||
|
|
||||||
if getattr(self.model.config, "model_type", None) == "chatglm":
|
|
||||||
values = torch.transpose(values, 0, 1)
|
|
||||||
|
|
||||||
scores = []
|
|
||||||
for i in range(input_ids.size(0)):
|
|
||||||
end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
|
|
||||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
|
||||||
scores.append(values[i, end_index].nan_to_num().item())
|
|
||||||
|
|
||||||
return scores
|
|
||||||
|
|
|
@ -0,0 +1,261 @@
|
||||||
|
import asyncio
|
||||||
|
import concurrent.futures
|
||||||
|
import os
|
||||||
|
from threading import Thread
|
||||||
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import GenerationConfig, TextIteratorStreamer
|
||||||
|
|
||||||
|
from ..data import get_template_and_fix_tokenizer
|
||||||
|
from ..extras.misc import get_logits_processor
|
||||||
|
from ..model import load_model_and_tokenizer
|
||||||
|
from .base_engine import BaseEngine, Response
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||||
|
from trl import PreTrainedModelWrapper
|
||||||
|
|
||||||
|
from ..data import Template
|
||||||
|
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingfaceEngine(BaseEngine):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
finetuning_args: "FinetuningArguments",
|
||||||
|
generating_args: "GeneratingArguments",
|
||||||
|
) -> None:
|
||||||
|
self.can_generate = finetuning_args.stage == "sft"
|
||||||
|
self.model, self.tokenizer = load_model_and_tokenizer(
|
||||||
|
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
||||||
|
)
|
||||||
|
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
||||||
|
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
||||||
|
self.generating_args = generating_args.to_dict()
|
||||||
|
self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _process_args(
|
||||||
|
model: "PreTrainedModel",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
template: "Template",
|
||||||
|
generating_args: Dict[str, Any],
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
|
) -> Tuple[Dict[str, Any], int]:
|
||||||
|
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||||
|
prompt_ids, _ = template.encode_oneturn(
|
||||||
|
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
|
||||||
|
)
|
||||||
|
prompt_length = len(prompt_ids)
|
||||||
|
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||||
|
|
||||||
|
do_sample = input_kwargs.pop("do_sample", None)
|
||||||
|
temperature = input_kwargs.pop("temperature", None)
|
||||||
|
top_p = input_kwargs.pop("top_p", None)
|
||||||
|
top_k = input_kwargs.pop("top_k", None)
|
||||||
|
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
||||||
|
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
||||||
|
max_length = input_kwargs.pop("max_length", None)
|
||||||
|
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||||
|
|
||||||
|
generating_args.update(
|
||||||
|
dict(
|
||||||
|
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
||||||
|
temperature=temperature or generating_args["temperature"],
|
||||||
|
top_p=top_p or generating_args["top_p"],
|
||||||
|
top_k=top_k or generating_args["top_k"],
|
||||||
|
num_return_sequences=num_return_sequences or 1,
|
||||||
|
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
||||||
|
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
|
||||||
|
generating_args["do_sample"] = True
|
||||||
|
|
||||||
|
if max_length:
|
||||||
|
generating_args.pop("max_new_tokens", None)
|
||||||
|
generating_args["max_length"] = max_length
|
||||||
|
|
||||||
|
if max_new_tokens:
|
||||||
|
generating_args.pop("max_length", None)
|
||||||
|
generating_args["max_new_tokens"] = max_new_tokens
|
||||||
|
|
||||||
|
gen_kwargs = dict(
|
||||||
|
inputs=inputs,
|
||||||
|
generation_config=GenerationConfig(**generating_args),
|
||||||
|
logits_processor=get_logits_processor(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return gen_kwargs, prompt_length
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.inference_mode()
|
||||||
|
def _chat(
|
||||||
|
model: "PreTrainedModel",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
template: "Template",
|
||||||
|
generating_args: Dict[str, Any],
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
|
) -> List["Response"]:
|
||||||
|
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
||||||
|
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
|
||||||
|
)
|
||||||
|
generate_output = model.generate(**gen_kwargs)
|
||||||
|
response_ids = generate_output[:, prompt_length:]
|
||||||
|
response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||||
|
results = []
|
||||||
|
for i in range(len(response)):
|
||||||
|
eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
|
||||||
|
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
|
||||||
|
results.append(
|
||||||
|
Response(
|
||||||
|
response_text=response[i],
|
||||||
|
response_length=response_length,
|
||||||
|
prompt_length=prompt_length,
|
||||||
|
finish_reason="stop" if len(eos_index) else "length",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.inference_mode()
|
||||||
|
def _stream_chat(
|
||||||
|
model: "PreTrainedModel",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
template: "Template",
|
||||||
|
generating_args: Dict[str, Any],
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
|
) -> Callable[[], str]:
|
||||||
|
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
||||||
|
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
|
||||||
|
)
|
||||||
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
gen_kwargs["streamer"] = streamer
|
||||||
|
thread = Thread(target=model.generate, kwargs=gen_kwargs)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
def stream():
|
||||||
|
try:
|
||||||
|
return streamer.__next__()
|
||||||
|
except StopIteration:
|
||||||
|
raise StopAsyncIteration()
|
||||||
|
|
||||||
|
return stream
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.inference_mode()
|
||||||
|
def _get_scores(
|
||||||
|
model: "PreTrainedModelWrapper",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
batch_input: List[str],
|
||||||
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
|
) -> List[float]:
|
||||||
|
max_length = input_kwargs.pop("max_length", None)
|
||||||
|
device = getattr(model.pretrained_model, "device", "cuda")
|
||||||
|
inputs = tokenizer(
|
||||||
|
batch_input,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
|
||||||
|
return_tensors="pt",
|
||||||
|
add_special_tokens=True,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
input_ids: torch.Tensor = inputs["input_ids"]
|
||||||
|
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
||||||
|
|
||||||
|
if getattr(model.config, "model_type", None) == "chatglm":
|
||||||
|
values = torch.transpose(values, 0, 1)
|
||||||
|
|
||||||
|
scores = []
|
||||||
|
for i in range(input_ids.size(0)):
|
||||||
|
end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
|
||||||
|
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||||
|
scores.append(values[i, end_index].nan_to_num().item())
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
|
) -> List["Response"]:
|
||||||
|
if not self.can_generate:
|
||||||
|
raise ValueError("The current model does not support `chat`.")
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
input_args = (
|
||||||
|
self.model,
|
||||||
|
self.tokenizer,
|
||||||
|
self.template,
|
||||||
|
self.generating_args,
|
||||||
|
messages,
|
||||||
|
system,
|
||||||
|
tools,
|
||||||
|
input_kwargs,
|
||||||
|
)
|
||||||
|
async with self._semaphore:
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
|
return await loop.run_in_executor(pool, self._chat, *input_args)
|
||||||
|
|
||||||
|
async def stream_chat(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
if not self.can_generate:
|
||||||
|
raise ValueError("The current model does not support `stream_chat`.")
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
input_args = (
|
||||||
|
self.model,
|
||||||
|
self.tokenizer,
|
||||||
|
self.template,
|
||||||
|
self.generating_args,
|
||||||
|
messages,
|
||||||
|
system,
|
||||||
|
tools,
|
||||||
|
input_kwargs,
|
||||||
|
)
|
||||||
|
async with self._semaphore:
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
|
stream = self._stream_chat(*input_args)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
yield await loop.run_in_executor(pool, stream)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
async def get_scores(
|
||||||
|
self,
|
||||||
|
batch_input: List[str],
|
||||||
|
**input_kwargs,
|
||||||
|
) -> List[float]:
|
||||||
|
if self.can_generate:
|
||||||
|
raise ValueError("Cannot get scores using an auto-regressive model.")
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
|
||||||
|
async with self._semaphore:
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
|
return await loop.run_in_executor(pool, self._get_scores, *input_args)
|
|
@ -0,0 +1,144 @@
|
||||||
|
import uuid
|
||||||
|
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
from ..data import get_template_and_fix_tokenizer
|
||||||
|
from ..extras.misc import get_device_count
|
||||||
|
from ..extras.packages import is_vllm_available
|
||||||
|
from ..model import load_tokenizer
|
||||||
|
from .base_engine import BaseEngine, Response
|
||||||
|
|
||||||
|
|
||||||
|
if is_vllm_available():
|
||||||
|
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
|
class VllmEngine(BaseEngine):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
finetuning_args: "FinetuningArguments",
|
||||||
|
generating_args: "GeneratingArguments",
|
||||||
|
) -> None:
|
||||||
|
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
|
||||||
|
self.can_generate = finetuning_args.stage == "sft"
|
||||||
|
engine_args = AsyncEngineArgs(
|
||||||
|
model=model_args.model_name_or_path,
|
||||||
|
trust_remote_code=True,
|
||||||
|
max_model_len=model_args.vllm_maxlen,
|
||||||
|
tensor_parallel_size=get_device_count(),
|
||||||
|
disable_log_stats=True,
|
||||||
|
disable_log_requests=True,
|
||||||
|
)
|
||||||
|
self.model = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
|
self.tokenizer = load_tokenizer(model_args)
|
||||||
|
self.tokenizer.padding_side = "left"
|
||||||
|
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
||||||
|
self.generating_args = generating_args.to_dict()
|
||||||
|
|
||||||
|
async def _generate(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
|
) -> AsyncIterator["RequestOutput"]:
|
||||||
|
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||||
|
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||||
|
prompt_ids, _ = self.template.encode_oneturn(
|
||||||
|
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
|
||||||
|
)
|
||||||
|
prompt_length = len(prompt_ids)
|
||||||
|
|
||||||
|
temperature = input_kwargs.pop("temperature", None)
|
||||||
|
top_p = input_kwargs.pop("top_p", None)
|
||||||
|
top_k = input_kwargs.pop("top_k", None)
|
||||||
|
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
||||||
|
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
||||||
|
max_length = input_kwargs.pop("max_length", None)
|
||||||
|
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||||
|
|
||||||
|
generating_args = self.generating_args.copy()
|
||||||
|
generating_args.update(
|
||||||
|
dict(
|
||||||
|
temperature=temperature or generating_args["temperature"],
|
||||||
|
top_p=top_p or generating_args["top_p"],
|
||||||
|
top_k=top_k or generating_args["top_k"],
|
||||||
|
num_return_sequences=num_return_sequences or 1,
|
||||||
|
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if max_length:
|
||||||
|
generating_args["max_new_tokens"] = max_length - prompt_length
|
||||||
|
|
||||||
|
if max_new_tokens:
|
||||||
|
generating_args["max_new_tokens"] = max_new_tokens
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
n=generating_args["num_return_sequences"],
|
||||||
|
repetition_penalty=generating_args["repetition_penalty"],
|
||||||
|
temperature=generating_args["temperature"],
|
||||||
|
top_p=generating_args["top_p"],
|
||||||
|
top_k=generating_args["top_k"],
|
||||||
|
use_beam_search=generating_args["num_beams"] > 1,
|
||||||
|
length_penalty=generating_args["length_penalty"],
|
||||||
|
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||||
|
max_tokens=generating_args["max_new_tokens"],
|
||||||
|
skip_special_tokens=True,
|
||||||
|
)
|
||||||
|
result_generator = self.model.generate(
|
||||||
|
prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids
|
||||||
|
)
|
||||||
|
return result_generator
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
|
) -> List["Response"]:
|
||||||
|
final_output = None
|
||||||
|
generator = await self._generate(messages, system, tools, **input_kwargs)
|
||||||
|
async for request_output in generator:
|
||||||
|
final_output = request_output
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for output in final_output.outputs:
|
||||||
|
results.append(
|
||||||
|
Response(
|
||||||
|
response_text=output.text,
|
||||||
|
response_length=len(output.token_ids),
|
||||||
|
prompt_length=len(final_output.prompt_token_ids),
|
||||||
|
finish_reason=output.finish_reason,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def stream_chat(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
generated_text = ""
|
||||||
|
generator = await self._generate(messages, system, tools, **input_kwargs)
|
||||||
|
async for result in generator:
|
||||||
|
delta_text = result.outputs[0].text[len(generated_text) :]
|
||||||
|
generated_text = result.outputs[0].text
|
||||||
|
yield delta_text
|
||||||
|
|
||||||
|
async def get_scores(
|
||||||
|
self,
|
||||||
|
batch_input: List[str],
|
||||||
|
**input_kwargs,
|
||||||
|
) -> List[float]:
|
||||||
|
raise NotImplementedError("vLLM engine does not support get_scores.")
|
|
@@ -1,6 +1,6 @@
 from .loader import get_dataset
-from .template import get_template_and_fix_tokenizer, templates
+from .template import Template, get_template_and_fix_tokenizer, templates
 from .utils import Role, split_dataset


-__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]
+__all__ = ["get_dataset", "Template", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]
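The new `Template` re-export mainly serves type hints in the chat engines; a small, hypothetical illustration follows (module paths assumed).

```python
# Sketch: type-hinting against the Template class now re-exported here,
# mirroring how base_engine.py annotates its `template` attribute.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from llmtuner.data import Template  # re-export added by this commit


class MyEngine:
    template: "Template"  # resolved only during static type checking
```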
@ -2,7 +2,7 @@ import json
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union
|
from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||||
|
@ -72,7 +72,7 @@ def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
|
||||||
@dataclass
|
@dataclass
|
||||||
class Formatter(ABC):
|
class Formatter(ABC):
|
||||||
slots: SLOTS = field(default_factory=list)
|
slots: SLOTS = field(default_factory=list)
|
||||||
tool_format: Literal["default"] = "default"
|
tool_format: Optional[Literal["default"]] = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(self, **kwargs) -> SLOTS: ...
|
def apply(self, **kwargs) -> SLOTS: ...
|
||||||
|
@ -83,12 +83,30 @@ class Formatter(ABC):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmptyFormatter(Formatter):
|
class EmptyFormatter(Formatter):
|
||||||
|
def __post_init__(self):
|
||||||
|
has_placeholder = False
|
||||||
|
for slot in filter(lambda s: isinstance(s, str), self.slots):
|
||||||
|
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
|
||||||
|
has_placeholder = True
|
||||||
|
|
||||||
|
if has_placeholder:
|
||||||
|
raise ValueError("Empty formatter should not contain any placeholder.")
|
||||||
|
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
return self.slots
|
return self.slots
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StringFormatter(Formatter):
|
class StringFormatter(Formatter):
|
||||||
|
def __post_init__(self):
|
||||||
|
has_placeholder = False
|
||||||
|
for slot in filter(lambda s: isinstance(s, str), self.slots):
|
||||||
|
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
|
||||||
|
has_placeholder = True
|
||||||
|
|
||||||
|
if not has_placeholder:
|
||||||
|
raise ValueError("A placeholder is required in the string formatter.")
|
||||||
|
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
elements = []
|
elements = []
|
||||||
for slot in self.slots:
|
for slot in self.slots:
|
||||||
|
@ -109,6 +127,17 @@ class StringFormatter(Formatter):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FunctionFormatter(Formatter):
|
class FunctionFormatter(Formatter):
|
||||||
|
def __post_init__(self):
|
||||||
|
has_name, has_args = False, False
|
||||||
|
for slot in filter(lambda s: isinstance(s, str), self.slots):
|
||||||
|
if "{{name}}" in slot:
|
||||||
|
has_name = True
|
||||||
|
if "{{arguments}}" in slot:
|
||||||
|
has_args = True
|
||||||
|
|
||||||
|
if not has_name or not has_args:
|
||||||
|
raise ValueError("Name and arguments placeholders are required in the function formatter.")
|
||||||
|
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
content = kwargs.pop("content")
|
content = kwargs.pop("content")
|
||||||
try:
|
try:
|
||||||
|
@ -133,6 +162,10 @@ class FunctionFormatter(Formatter):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolFormatter(Formatter):
|
class ToolFormatter(Formatter):
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.tool_format is None:
|
||||||
|
raise ValueError("Tool format was not found.")
|
||||||
|
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
content = kwargs.pop("content")
|
content = kwargs.pop("content")
|
||||||
try:
|
try:
|
||||||
|
|
@@ -44,7 +44,7 @@ def load_single_dataset(

     elif dataset_attr.load_from == "file":
         data_files = []
-        local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+        local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
         if os.path.isdir(local_path):  # is directory
             for file_name in os.listdir(local_path):
                 data_files.append(os.path.join(local_path, file_name))
@@ -19,13 +19,13 @@ class DatasetAttr:

     """ basic configs """
     load_from: Literal["hf_hub", "ms_hub", "script", "file"]
-    dataset_name: Optional[str] = None
+    dataset_name: str
     """ extra configs """
     file_sha1: Optional[str] = None
     subset: Optional[str] = None
     folder: Optional[str] = None
-    ranking: Optional[bool] = False
-    formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
+    ranking: bool = False
+    formatting: Literal["alpaca", "sharegpt"] = "alpaca"
     """ columns """
     system: Optional[str] = None
     """ columns for the alpaca format """
@@ -157,6 +157,12 @@ def get_current_device() -> torch.device:


 def get_device_count() -> int:
+    r"""
+    Gets the number of available GPU devices.
+    """
+    if not torch.cuda.is_available():
+        return 0
+
     return torch.cuda.device_count()


@@ -51,3 +51,7 @@ def is_unsloth_available():

 def is_uvicorn_available():
     return _is_package_available("uvicorn")
+
+
+def is_vllm_available():
+    return _is_package_available("vllm")
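The new `is_vllm_available` check gates the optional vLLM imports and engine construction; a condensed sketch of that pattern, based on the `vllm_engine.py` portion of this commit (package paths and the model name are assumptions):

```python
# Sketch of how the new helpers are combined when building the vLLM engine.
from llmtuner.extras.misc import get_device_count        # path assumed from the diff
from llmtuner.extras.packages import is_vllm_available   # path assumed from the diff

if is_vllm_available():
    from vllm import AsyncEngineArgs, AsyncLLMEngine

    engine_args = AsyncEngineArgs(
        model="meta-llama/Llama-2-7b-hf",                  # placeholder model
        trust_remote_code=True,
        tensor_parallel_size=max(get_device_count(), 1),   # vLLM needs at least one device
        disable_log_stats=True,
        disable_log_requests=True,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
```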
@@ -16,35 +16,35 @@ class DataArguments:
         default=None,
         metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
     )
-    dataset_dir: Optional[str] = field(
+    dataset_dir: str = field(
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
-    split: Optional[str] = field(
+    split: str = field(
         default="train",
         metadata={"help": "Which dataset split to use for training and evaluation."},
     )
-    cutoff_len: Optional[int] = field(
+    cutoff_len: int = field(
         default=1024,
         metadata={"help": "The cutoff length of the model inputs after tokenization."},
     )
-    reserved_label_len: Optional[int] = field(
+    reserved_label_len: int = field(
         default=1,
         metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
     )
-    train_on_prompt: Optional[bool] = field(
+    train_on_prompt: bool = field(
         default=False,
         metadata={"help": "Whether to disable the mask on the prompt or not."},
     )
-    streaming: Optional[bool] = field(
+    streaming: bool = field(
         default=False,
         metadata={"help": "Enable dataset streaming."},
     )
-    buffer_size: Optional[int] = field(
+    buffer_size: int = field(
         default=16384,
         metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
     )
-    mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
+    mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
         default="concat",
         metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
     )
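Most of the argument changes in this hunk (and the ones that follow) drop the `Optional[...]` wrapper on dataclass fields that already carry a concrete default, so the annotation matches the value the field actually holds. A minimal, self-contained sketch of how such fields behave with `HfArgumentParser`; the demo class and values are illustrative and not part of the project:

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DemoArguments:
    cutoff_len: int = field(default=1024, metadata={"help": "The cutoff length of the model inputs."})
    streaming: bool = field(default=False, metadata={"help": "Enable dataset streaming."})


parser = HfArgumentParser(DemoArguments)
(demo_args,) = parser.parse_args_into_dataclasses(["--cutoff_len", "2048", "--streaming", "true"])
print(demo_args.cutoff_len, demo_args.streaming)  # 2048 True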
@@ -52,13 +52,13 @@ class DataArguments:
         default=None,
         metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
     )
-    overwrite_cache: Optional[bool] = field(
+    overwrite_cache: bool = field(
         default=False,
         metadata={"help": "Overwrite the cached training and evaluation sets."},
     )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
-        metadata={"help": "The number of processes to use for the preprocessing."},
+        metadata={"help": "The number of processes to use for the pre-processing."},
     )
     max_samples: Optional[int] = field(
         default=None,
@@ -68,23 +68,23 @@ class DataArguments:
         default=None,
         metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
     )
-    ignore_pad_token_for_loss: Optional[bool] = field(
+    ignore_pad_token_for_loss: bool = field(
         default=True,
         metadata={
             "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
         },
     )
-    val_size: Optional[float] = field(
-        default=0,
+    val_size: float = field(
+        default=0.0,
         metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
     )
-    sft_packing: Optional[bool] = field(
+    sft_packing: bool = field(
         default=False,
         metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."},
     )
     cache_path: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to save or load the preprocessed datasets."},
+        metadata={"help": "Path to save or load the pre-processed datasets."},
     )

     def __post_init__(self):
@@ -14,23 +14,23 @@ class EvaluationArguments:
     task: str = field(
         metadata={"help": "Name of the evaluation task."},
     )
-    task_dir: Optional[str] = field(
+    task_dir: str = field(
         default="evaluation",
         metadata={"help": "Path to the folder containing the evaluation datasets."},
     )
-    batch_size: Optional[int] = field(
+    batch_size: int = field(
         default=4,
         metadata={"help": "The batch size per GPU for evaluation."},
     )
-    seed: Optional[int] = field(
+    seed: int = field(
         default=42,
         metadata={"help": "Random seed to be used with data loaders."},
     )
-    lang: Optional[Literal["en", "zh"]] = field(
+    lang: Literal["en", "zh"] = field(
         default="en",
         metadata={"help": "Language used at evaluation."},
     )
-    n_shot: Optional[int] = field(
+    n_shot: int = field(
         default=5,
         metadata={"help": "Number of examplars for few-shot learning."},
     )
@@ -38,7 +38,7 @@ class EvaluationArguments:
         default=None,
         metadata={"help": "Path to save the evaluation results."},
     )
-    download_mode: Optional[DownloadMode] = field(
+    download_mode: DownloadMode = field(
         default=DownloadMode.REUSE_DATASET_IF_EXISTS,
         metadata={"help": "Download mode used for the evaluation datasets."},
     )
@@ -22,7 +22,7 @@ class FreezeArguments:
                     Others choices: the same as LLaMA."""
         },
     )
-    num_layer_trainable: Optional[int] = field(
+    num_layer_trainable: int = field(
         default=3,
         metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
     )
@@ -44,11 +44,11 @@ class LoraArguments:
         default=None,
         metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
     )
-    lora_dropout: Optional[float] = field(
+    lora_dropout: float = field(
         default=0.0,
         metadata={"help": "Dropout rate for the LoRA fine-tuning."},
     )
-    lora_rank: Optional[int] = field(
+    lora_rank: int = field(
         default=8,
         metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
     )
@@ -66,18 +66,19 @@ class LoraArguments:
                     Others choices: the same as LLaMA."""
         },
     )
-    lora_bf16_mode: Optional[bool] = field(
+    lora_bf16_mode: bool = field(
         default=False,
         metadata={"help": "Whether or not to train lora adapters in bf16 precision."},
     )
-    use_rslora: Optional[bool] = field(
+    use_rslora: bool = field(
         default=False,
         metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
     )
-    use_dora: Optional[bool] = field(
-        default=False, metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."}
+    use_dora: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
     )
-    create_new_adapter: Optional[bool] = field(
+    create_new_adapter: bool = field(
         default=False,
         metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
     )
@@ -89,23 +90,23 @@ class RLHFArguments:
     Arguments pertaining to the PPO and DPO training.
     """

-    dpo_beta: Optional[float] = field(
+    dpo_beta: float = field(
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."},
     )
-    dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto_pair"]] = field(
+    dpo_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = field(
         default="sigmoid",
         metadata={"help": "The type of DPO loss to use."},
     )
-    dpo_ftx: Optional[float] = field(
-        default=0,
+    dpo_ftx: float = field(
+        default=0.0,
         metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
     )
-    ppo_buffer_size: Optional[int] = field(
+    ppo_buffer_size: int = field(
         default=1,
         metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
     )
-    ppo_epochs: Optional[int] = field(
+    ppo_epochs: int = field(
         default=4,
         metadata={"help": "The number of epochs to perform in a PPO optimization step."},
     )
@@ -113,15 +114,15 @@ class RLHFArguments:
         default=None,
         metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
     )
-    ppo_score_norm: Optional[bool] = field(
+    ppo_score_norm: bool = field(
         default=False,
         metadata={"help": "Use score normalization in PPO training."},
     )
-    ppo_target: Optional[float] = field(
+    ppo_target: float = field(
         default=6.0,
         metadata={"help": "Target KL value for adaptive KL control in PPO training."},
     )
-    ppo_whiten_rewards: Optional[bool] = field(
+    ppo_whiten_rewards: bool = field(
         default=False,
         metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
     )
@@ -149,7 +150,7 @@ class RLHFArguments:
         default=None,
         metadata={"help": "The number of bits to quantize the reward model."},
     )
-    reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
+    reward_model_type: Literal["lora", "full", "api"] = field(
         default="lora",
         metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
     )
@@ -161,19 +162,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """

-    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
+    stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field(
         default="sft",
         metadata={"help": "Which stage will be performed in training."},
     )
-    finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
+    finetuning_type: Literal["lora", "freeze", "full"] = field(
         default="lora",
         metadata={"help": "Which fine-tuning method to use."},
     )
-    use_llama_pro: Optional[bool] = field(
+    use_llama_pro: bool = field(
         default=False,
         metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
     )
-    plot_loss: Optional[bool] = field(
+    plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
     )
@@ -1,5 +1,5 @@
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, Optional
+from typing import Any, Dict


 @dataclass
@@ -8,41 +8,41 @@ class GeneratingArguments:
     Arguments pertaining to specify the decoding parameters.
     """

-    do_sample: Optional[bool] = field(
+    do_sample: bool = field(
         default=True,
         metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."},
     )
-    temperature: Optional[float] = field(
+    temperature: float = field(
         default=0.95,
         metadata={"help": "The value used to modulate the next token probabilities."},
     )
-    top_p: Optional[float] = field(
+    top_p: float = field(
         default=0.7,
         metadata={
             "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
         },
     )
-    top_k: Optional[int] = field(
+    top_k: int = field(
         default=50,
         metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
     )
-    num_beams: Optional[int] = field(
+    num_beams: int = field(
         default=1,
         metadata={"help": "Number of beams for beam search. 1 means no beam search."},
     )
-    max_length: Optional[int] = field(
+    max_length: int = field(
         default=512,
         metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
     )
-    max_new_tokens: Optional[int] = field(
+    max_new_tokens: int = field(
         default=512,
         metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
     )
-    repetition_penalty: Optional[float] = field(
+    repetition_penalty: float = field(
         default=1.0,
         metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."},
     )
-    length_penalty: Optional[float] = field(
+    length_penalty: float = field(
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
@@ -85,6 +85,10 @@ class ModelArguments:
         default="hf",
         metadata={"help": "Backend engine used at inference."},
     )
+    vllm_maxlen: int = field(
+        default=2048,
+        metadata={"help": "Maximum input length of the vLLM engine."},
+    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."},
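With this addition, the maximum vLLM input length can be configured alongside the backend selection. A hedged usage sketch, inferred from how `WebChatModel` (further below) passes a plain dict of arguments to its `ChatModel` parent; the import path and model name are assumptions, not taken from this diff:

from llmtuner.chat import ChatModel  # import path assumed

chat_model = ChatModel(dict(
    model_name_or_path="meta-llama/Llama-2-7b-hf",  # placeholder example model
    template="default",
    infer_backend="vllm",
    vllm_maxlen=2048,
))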
@@ -9,8 +9,8 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.trainer_utils import get_last_checkpoint

 from ..extras.logging import get_logger
-from ..extras.packages import is_unsloth_available
 from ..extras.misc import check_dependencies
+from ..extras.packages import is_unsloth_available
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -59,6 +59,9 @@ def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:


 def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
+    if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Adapter is only valid for the LoRA method.")
+
     if model_args.quantization_bit is not None:
         if finetuning_args.finetuning_type != "lora":
             raise ValueError("Quantization is only compatible with the LoRA method.")
@@ -69,8 +72,18 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

-    if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Adapter is only valid for the LoRA method.")
+    if model_args.infer_backend == "vllm":
+        if finetuning_args.stage != "sft":
+            raise ValueError("vLLM engine only supports auto-regressive models.")
+
+        if model_args.adapter_name_or_path is not None:
+            raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
+
+        if model_args.quantization_bit is not None:
+            raise ValueError("vLLM engine does not support quantization.")
+
+        if model_args.rope_scaling is not None:
+            raise ValueError("vLLM engine does not support RoPE scaling.")


 def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
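Because the vLLM path rejects `adapter_name_or_path`, LoRA weights have to be merged into the base model before serving. A generic, hedged sketch of that merge using `peft`; the model names and paths are placeholders, and this is not the project's own export utility:

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder base model
merged_model = PeftModel.from_pretrained(base_model, "path/to/lora_adapter").merge_and_unload()
merged_model.save_pretrained("path/to/merged_model")

# Keep the tokenizer next to the merged weights so the folder is self-contained.
AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf").save_pretrained("path/to/merged_model")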
@@ -1,11 +1,10 @@
 from .loader import load_model, load_model_and_tokenizer, load_tokenizer
-from .utils import dispatch_model, load_valuehead_params
+from .utils import load_valuehead_params


 __all__ = [
     "load_model",
     "load_model_and_tokenizer",
     "load_tokenizer",
-    "dispatch_model",
     "load_valuehead_params",
 ]
@@ -1,4 +1,3 @@
-import inspect
 from typing import TYPE_CHECKING, Dict, List

 import torch
@@ -7,7 +6,6 @@ from transformers.utils import cached_file

 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.logging import get_logger
-from ..extras.misc import get_current_device


 if TYPE_CHECKING:
@@ -19,36 +17,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
-    r"""
-    Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
-    Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
-    """
-    if getattr(model, "quantization_method", None):  # already set on current device
-        return model
-
-    if (
-        torch.cuda.device_count() > 1
-        and isinstance(model, PreTrainedModel)
-        and model._no_split_modules is not None
-        and model.config.model_type != "chatglm"
-    ):
-        from accelerate import dispatch_model
-        from accelerate.utils import get_balanced_memory, infer_auto_device_map
-
-        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
-        max_memory = get_balanced_memory(model, **kwargs)
-        # Make sure tied weights are tied before creating the device map.
-        model.tie_weights()
-        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
-        device_map_kwargs = {"device_map": device_map, "offload_dir": "offload"}
-        if "skip_keys" in inspect.signature(dispatch_model).parameters:
-            device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
-        return dispatch_model(model, **device_map_kwargs)
-    else:
-        return model.to(device=get_current_device())
-
-
 def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     r"""
     Finds all available modules to apply lora.
@@ -1,4 +1,5 @@
 import json
+import os
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple

 import gradio as gr
@@ -7,12 +8,12 @@ from gradio.components import Component  # cannot use TYPE_CHECKING here
 from ..chat import ChatModel
 from ..data import Role
 from ..extras.misc import torch_gc
-from ..hparams import GeneratingArguments
 from .common import get_save_dir
 from .locales import ALERTS


 if TYPE_CHECKING:
+    from ..chat import BaseEngine
     from .manager import Manager

@@ -22,29 +23,19 @@ class WebChatModel(ChatModel):
     ) -> None:
         self.manager = manager
         self.demo_mode = demo_mode
-        self.model = None
-        self.tokenizer = None
-        self.generating_args = GeneratingArguments()
+        self.engine: Optional["BaseEngine"] = None

         if not lazy_init:  # read arguments from command line
             super().__init__()

-        if demo_mode:  # load demo_config.json if exists
-            import json
-
-            try:
-                with open("demo_config.json", "r", encoding="utf-8") as f:
-                    args = json.load(f)
-                assert args.get("model_name_or_path", None) and args.get("template", None)
-                super().__init__(args)
-            except AssertionError:
-                print("Please provided model name and template in `demo_config.json`.")
-            except Exception:
-                print("Cannot find `demo_config.json` at current directory.")
+        if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"):  # load demo model
+            model_name_or_path = os.environ.get("DEMO_MODEL")
+            template = os.environ.get("DEMO_TEMPLATE")
+            super().__init__(dict(model_name_or_path=model_name_or_path, template=template))

     @property
     def loaded(self) -> bool:
-        return self.model is not None
+        return self.engine is not None

     def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
         get = lambda name: data[self.manager.get_elem_by_name(name)]
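Demo mode now reads its model and template from environment variables instead of a `demo_config.json` file. A minimal sketch of driving it that way; the values are placeholders:

import os

# Placeholder values; any model/template pair supported by the project would do.
os.environ["DEMO_MODEL"] = "meta-llama/Llama-2-7b-hf"
os.environ["DEMO_TEMPLATE"] = "default"

# A WebChatModel constructed with demo_mode=True will then initialize itself
# from these two variables, as shown in the hunk above.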
@@ -98,8 +89,7 @@ class WebChatModel(ChatModel):
             return

         yield ALERTS["info_unloading"][lang]
-        self.model = None
-        self.tokenizer = None
+        self.engine = None
         torch_gc()
         yield ALERTS["info_unloaded"][lang]

@@ -123,7 +113,7 @@ class WebChatModel(ChatModel):
         ):
             response += new_text
             if tools:
-                result = self.template.format_tools.extract(response)
+                result = self.engine.template.format_tools.extract(response)
             else:
                 result = response
@@ -28,10 +28,9 @@ def create_chat_box(
                 submit_btn = gr.Button(variant="primary")

             with gr.Column(scale=1):
-                gen_kwargs = engine.chatter.generating_args
-                max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1)
-                top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
-                temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
+                max_new_tokens = gr.Slider(8, 4096, value=512, step=1)
+                top_p = gr.Slider(0.01, 1.0, value=0.7, step=0.01)
+                temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
                 clear_btn = gr.Button()

     tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])
@@ -16,8 +16,8 @@ class Engine:
         self.demo_mode = demo_mode
         self.pure_chat = pure_chat
         self.manager = Manager()
-        self.runner = Runner(self.manager, demo_mode=demo_mode)
-        self.chatter = WebChatModel(manager=self.manager, demo_mode=demo_mode, lazy_init=(not pure_chat))
+        self.runner = Runner(self.manager, demo_mode)
+        self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat))

     def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
         return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}
@@ -0,0 +1,30 @@
+import os
+import time
+
+from openai import OpenAI
+from transformers.utils.versions import require_version
+
+
+require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
+
+
+def main():
+    client = OpenAI(
+        api_key="0",
+        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+    )
+    messages = [{"role": "user", "content": "Write a long essay about environment protection as long as possible."}]
+    num_tokens = 0
+    start_time = time.time()
+    for _ in range(8):
+        result = client.chat.completions.create(messages=messages, model="test")
+        num_tokens += result.usage.completion_tokens
+
+    elapsed_time = time.time() - start_time
+    print("Throughput: {:.2f} tokens/s".format(num_tokens / elapsed_time))
+    # --infer_backend hf: 27.22 tokens/s (1.0x)
+    # --infer_backend vllm: 73.03 tokens/s (2.7x)
+
+
+if __name__ == "__main__":
+    main()
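The two trailing comments in the new benchmark script record one measured run against the OpenAI-style API. As a quick sanity check, the quoted throughputs are consistent with the roughly 2.7x ratio noted in the comment:

# Ratio of the two throughputs quoted in the script's comments.
hf_tps = 27.22
vllm_tps = 73.03
print("{:.2f}x".format(vllm_tps / hf_tps))  # -> 2.68x, i.e. roughly 2.7x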
@@ -1,4 +1,5 @@
 import json
+import os
 from typing import Sequence

 from openai import OpenAI
@@ -17,13 +18,10 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
     return total_score / total_hour


-tool_map = {"calculate_gpa": calculate_gpa}
-
-
-if __name__ == "__main__":
+def main():
     client = OpenAI(
         api_key="0",
-        base_url="http://localhost:8000/v1",
+        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
     )
     tools = [
         {
@@ -42,6 +40,8 @@ if __name__ == "__main__":
             },
         }
     ]
+    tool_map = {"calculate_gpa": calculate_gpa}
+
     messages = []
     messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
     result = client.chat.completions.create(messages=messages, model="test", tools=tools)
@@ -55,3 +55,7 @@ if __name__ == "__main__":
     result = client.chat.completions.create(messages=messages, model="test", tools=tools)
     print(result.choices[0].message.content)
     # Based on your grades and credit hours, your calculated Grade Point Average (GPA) is 3.4166666666666665.
+
+
+if __name__ == "__main__":
+    main()