support vllm

This commit is contained in:
hiyouga 2024-03-07 20:26:31 +08:00
parent f74f804a71
commit d07ad5cc1c
32 changed files with 752 additions and 316 deletions

View File

@ -1,6 +1,6 @@
.PHONY: quality style .PHONY: quality style
check_dirs := scripts src check_dirs := scripts src tests
quality: quality:
ruff check $(check_dirs) ruff check $(check_dirs)

View File

@ -47,10 +47,11 @@ Choose your path:
- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. - **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO. - **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO.
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA, 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8. - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
- **Advanced algorithms**: DoRA, LongLoRA, LLaMA Pro, LoftQ, agent tuning. - **Advanced algorithms**: DoRA, LongLoRA, LLaMA Pro, LoftQ and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune, rsLoRA. - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc. - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
## Benchmark ## Benchmark
@ -69,6 +70,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog ## Changelog
[24/03/07] We integrated [vLLM](https://github.com/vllm-project/vllm) for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training. [24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `scripts/llama_pro.py` for usage. [24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `scripts/llama_pro.py` for usage.
@ -79,7 +82,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`. [24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`.
[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves 1.7x speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details. [23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement). [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
@ -553,7 +556,7 @@ deepspeed --num_gpus 8 src/train_bash.py \
### Merge LoRA weights and export model ### Merge LoRA weights and export model
```bash ```bash
python src/export_model.py \ CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
--model_name_or_path path_to_llama_model \ --model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \ --adapter_name_or_path path_to_checkpoint \
--template default \ --template default \
@ -574,7 +577,7 @@ python src/export_model.py \
### Inference with OpenAI-style API ### Inference with OpenAI-style API
```bash ```bash
python src/api_demo.py \ CUDA_VISIBLE_DEVICES=0 python src/api_demo.py \
--model_name_or_path path_to_llama_model \ --model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \ --adapter_name_or_path path_to_checkpoint \
--template default \ --template default \
@ -587,7 +590,7 @@ python src/api_demo.py \
### Inference with command line ### Inference with command line
```bash ```bash
python src/cli_demo.py \ CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
--model_name_or_path path_to_llama_model \ --model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \ --adapter_name_or_path path_to_checkpoint \
--template default \ --template default \
@ -597,7 +600,7 @@ python src/cli_demo.py \
### Inference with web browser ### Inference with web browser
```bash ```bash
python src/web_demo.py \ CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
--model_name_or_path path_to_llama_model \ --model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \ --adapter_name_or_path path_to_checkpoint \
--template default \ --template default \

View File

@ -51,6 +51,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- **先进算法**DoRA、LongLoRA、LLaMA Pro、LoftQ 和 Agent 微调。 - **先进算法**DoRA、LongLoRA、LLaMA Pro、LoftQ 和 Agent 微调。
- **实用技巧**FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。 - **实用技巧**FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
- **实验监控**LlamaBoard、TensorBoard、Wandb、MLflow 等等。 - **实验监控**LlamaBoard、TensorBoard、Wandb、MLflow 等等。
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
## 性能指标 ## 性能指标
@ -69,17 +70,19 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 更新日志 ## 更新日志
[24/03/07] 我们集成了 [vLLM](https://github.com/vllm-project/vllm) 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA请先合并权重。
[24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。 [24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `scripts/llama_pro.py` [24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `scripts/llama_pro.py`
[24/02/05] Qwen1.5Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
<details><summary>展开日志</summary> <details><summary>展开日志</summary>
[24/02/05] Qwen1.5Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。 [24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。
[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 1.7 的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。 [23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 **170%** 的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
[23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。 [23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。
@ -552,7 +555,7 @@ deepspeed --num_gpus 8 src/train_bash.py \
### 合并 LoRA 权重并导出模型 ### 合并 LoRA 权重并导出模型
```bash ```bash
python src/export_model.py \ CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
--model_name_or_path path_to_llama_model \ --model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \ --adapter_name_or_path path_to_checkpoint \
--template default \ --template default \
@ -573,7 +576,7 @@ python src/export_model.py \
### 使用 OpenAI 风格 API 推理 ### 使用 OpenAI 风格 API 推理
```bash ```bash
python src/api_demo.py \ CUDA_VISIBLE_DEVICES=0 python src/api_demo.py \
--model_name_or_path path_to_llama_model \ --model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \ --adapter_name_or_path path_to_checkpoint \
--template default \ --template default \
@ -586,7 +589,7 @@ python src/api_demo.py \
### 使用命令行推理 ### 使用命令行推理
```bash ```bash
python src/cli_demo.py \ CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
--model_name_or_path path_to_llama_model \ --model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \ --adapter_name_or_path path_to_checkpoint \
--template default \ --template default \
@ -596,7 +599,7 @@ python src/cli_demo.py \
### 使用浏览器推理 ### 使用浏览器推理
```bash ```bash
python src/web_demo.py \ CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
--model_name_or_path path_to_llama_model \ --model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \ --adapter_name_or_path path_to_checkpoint \
--template default \ --template default \

View File

@ -1,5 +1,8 @@
Usage: Usage:
- `pretrain.sh` - `pretrain.sh`: do pre-train (optional)
- `sft.sh` -> `reward.sh` -> `ppo.sh` - `sft.sh`: do supervised fine-tune
- `sft.sh` -> `dpo.sh` -> `predict.sh` - `reward.sh`: do reward modeling (must after sft.sh)
- `ppo.sh`: do PPO training (must after sft.sh and reward.sh)
- `dpo.sh`: do DPO training (must after sft.sh)
- `predict.sh`: do predict (must after sft.sh and dpo.sh)

View File

@ -1,3 +1,4 @@
Usage: Usage:
- `merge.sh` -> `quantize.sh` - `merge.sh`: merge the lora weights
- `quantize.sh`: quantize the model with AutoGPTQ (must after merge.sh, optional)

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
python ../../src/export_model.py \ CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \ --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \ --template default \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
python ../../src/export_model.py \ CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
--model_name_or_path ../../models/llama2-7b-sft \ --model_name_or_path ../../models/llama2-7b-sft \
--template default \ --template default \
--export_dir ../../models/llama2-7b-sft-int4 \ --export_dir ../../models/llama2-7b-sft-int4 \

View File

@ -1,4 +1,3 @@
import asyncio
import json import json
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -73,7 +72,6 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
allow_headers=["*"], allow_headers=["*"],
) )
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
role_mapping = { role_mapping = {
Role.USER: DataRole.USER.value, Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value, Role.ASSISTANT: DataRole.ASSISTANT.value,
@ -89,7 +87,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK) @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
async def create_chat_completion(request: ChatCompletionRequest): async def create_chat_completion(request: ChatCompletionRequest):
if not chat_model.can_generate: if not chat_model.engine.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if len(request.messages) == 0: if len(request.messages) == 0:
@ -121,20 +119,15 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
else: else:
tools = "" tools = ""
async with semaphore:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, chat_completion, input_messages, system, tools, request)
def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
if request.stream: if request.stream:
if tools: if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
generate = stream_chat_completion(messages, system, tools, request) generate = stream_chat_completion(input_messages, system, tools, request)
return EventSourceResponse(generate, media_type="text/event-stream") return EventSourceResponse(generate, media_type="text/event-stream")
responses = chat_model.chat( responses = await chat_model.achat(
messages, input_messages,
system, system,
tools, tools,
do_sample=request.do_sample, do_sample=request.do_sample,
@ -148,7 +141,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
choices = [] choices = []
for i, response in enumerate(responses): for i, response in enumerate(responses):
if tools: if tools:
result = chat_model.template.format_tools.extract(response.response_text) result = chat_model.engine.template.format_tools.extract(response.response_text)
else: else:
result = response.response_text result = response.response_text
@ -177,7 +170,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
def stream_chat_completion( async def stream_chat_completion(
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
): ):
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
@ -186,7 +179,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield jsonify(chunk) yield jsonify(chunk)
for new_text in chat_model.stream_chat( async for new_token in chat_model.astream_chat(
messages, messages,
system, system,
tools, tools,
@ -195,11 +188,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
top_p=request.top_p, top_p=request.top_p,
max_new_tokens=request.max_tokens, max_new_tokens=request.max_tokens,
): ):
if len(new_text) == 0: if len(new_token) == 0:
continue continue
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None index=0, delta=ChatCompletionMessage(content=new_token), finish_reason=None
) )
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield jsonify(chunk) yield jsonify(chunk)
@ -213,18 +206,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK) @app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
async def create_score_evaluation(request: ScoreEvaluationRequest): async def create_score_evaluation(request: ScoreEvaluationRequest):
if chat_model.can_generate: if chat_model.engine.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if len(request.messages) == 0: if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
async with semaphore: scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, get_score, request)
def get_score(request: ScoreEvaluationRequest):
scores = chat_model.get_scores(request.messages, max_length=request.max_length)
return ScoreEvaluationResponse(model=request.model, scores=scores) return ScoreEvaluationResponse(model=request.model, scores=scores)
return app return app

View File

@ -1,4 +1,5 @@
from .base_engine import BaseEngine
from .chat_model import ChatModel from .chat_model import ChatModel
__all__ = ["ChatModel"] __all__ = ["BaseEngine", "ChatModel"]

View File

@ -0,0 +1,64 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from ..data import Template
from ..extras.packages import is_vllm_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
if is_vllm_available():
from vllm import AsyncLLMEngine
@dataclass
class Response:
response_text: str
response_length: int
prompt_length: int
finish_reason: Literal["stop", "length"]
class BaseEngine(ABC):
model: Union["PreTrainedModel", "AsyncLLMEngine"]
tokenizer: "PreTrainedTokenizer"
can_generate: bool
template: "Template"
generating_args: Dict[str, Any]
@abstractmethod
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None: ...
@abstractmethod
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> List["Response"]: ...
@abstractmethod
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]: ...
@abstractmethod
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]: ...

View File

@ -1,124 +1,50 @@
from dataclasses import dataclass import asyncio
from threading import Thread from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Tuple
import torch
from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_logits_processor
from ..hparams import get_infer_args from ..hparams import get_infer_args
from ..model import dispatch_model, load_model_and_tokenizer from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine
@dataclass if TYPE_CHECKING:
class Response: from .base_engine import BaseEngine, Response
response_text: str
response_length: int
prompt_length: int
finish_reason: Literal["stop", "length"]
class ChatModel: class ChatModel:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args) model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
self.can_generate = finetuning_args.stage == "sft" if model_args.infer_backend == "hf":
self.model, self.tokenizer = load_model_and_tokenizer( self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate) elif model_args.infer_backend == "vllm":
) self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
self.tokenizer.padding_side = "left" if self.can_generate else "right" else:
self.model = dispatch_model(self.model) raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
def _process_args( def _get_event_loop():
self, try:
messages: Sequence[Dict[str, str]], return asyncio.get_running_loop()
system: Optional[str] = None, except RuntimeError:
tools: Optional[str] = None, return asyncio.new_event_loop()
**input_kwargs,
) -> Tuple[Dict[str, Any], int]:
paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
)
prompt_length = len(prompt)
input_ids = torch.tensor([prompt], device=self.model.device)
do_sample = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
generating_args = self.generating_args.to_dict()
generating_args.update(
dict(
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
temperature=temperature or generating_args["temperature"],
top_p=top_p or generating_args["top_p"],
top_k=top_k or generating_args["top_k"],
num_return_sequences=num_return_sequences or 1,
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
pad_token_id=self.tokenizer.pad_token_id,
)
)
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
generating_args["do_sample"] = True
if max_length:
generating_args.pop("max_new_tokens", None)
generating_args["max_length"] = max_length
if max_new_tokens:
generating_args.pop("max_length", None)
generating_args["max_new_tokens"] = max_new_tokens
gen_kwargs = dict(
inputs=input_ids,
generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor(),
)
return gen_kwargs, prompt_length
@torch.inference_mode()
def chat( def chat(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
**input_kwargs, **input_kwargs,
) -> List[Response]: ) -> List["Response"]:
if not self.can_generate: loop = self._get_event_loop()
raise ValueError("The current model does not support `chat`.") return loop.run_until_complete(self.achat(messages, system, tools, **input_kwargs))
gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs) async def achat(
generate_output = self.model.generate(**gen_kwargs) self,
response_ids = generate_output[:, prompt_length:] messages: Sequence[Dict[str, str]],
response = self.tokenizer.batch_decode( system: Optional[str] = None,
response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True tools: Optional[str] = None,
) **input_kwargs,
results = [] ) -> List["Response"]:
for i in range(len(response)): return await self.engine.chat(messages, system, tools, **input_kwargs)
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
results.append(
Response(
response_text=response[i],
response_length=response_length,
prompt_length=prompt_length,
finish_reason="stop" if len(eos_index) else "length",
)
)
return results
@torch.inference_mode()
def stream_chat( def stream_chat(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
@ -126,44 +52,35 @@ class ChatModel:
tools: Optional[str] = None, tools: Optional[str] = None,
**input_kwargs, **input_kwargs,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
if not self.can_generate: loop = self._get_event_loop()
raise ValueError("The current model does not support `stream_chat`.") generator = self.astream_chat(messages, system, tools, **input_kwargs)
while True:
try:
yield loop.run_until_complete(generator.__anext__())
except StopAsyncIteration:
break
gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs) async def astream_chat(
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) self,
gen_kwargs["streamer"] = streamer messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
yield new_token
thread = Thread(target=self.model.generate, kwargs=gen_kwargs) def get_scores(
thread.start() self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
loop = self._get_event_loop()
return loop.run_until_complete(self.aget_scores(batch_input, **input_kwargs))
yield from streamer async def aget_scores(
self,
@torch.inference_mode() batch_input: List[str],
def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]: **input_kwargs,
if self.can_generate: ) -> List[float]:
raise ValueError("Cannot get scores using an auto-regressive model.") return await self.engine.get_scores(batch_input, **input_kwargs)
max_length = input_kwargs.pop("max_length", None)
device = getattr(self.model.pretrained_model, "device", "cuda")
inputs = self.tokenizer(
batch_input,
padding=True,
truncation=True,
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
return_tensors="pt",
add_special_tokens=True,
).to(device)
input_ids: torch.Tensor = inputs["input_ids"]
_, _, values = self.model(**inputs, output_hidden_states=True, return_dict=True)
if getattr(self.model.config, "model_type", None) == "chatglm":
values = torch.transpose(values, 0, 1)
scores = []
for i in range(input_ids.size(0)):
end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
scores.append(values[i, end_index].nan_to_num().item())
return scores

View File

@ -0,0 +1,261 @@
import asyncio
import concurrent.futures
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple
import torch
from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_logits_processor
from ..model import load_model_and_tokenizer
from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from trl import PreTrainedModelWrapper
from ..data import Template
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
class HuggingfaceEngine(BaseEngine):
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.can_generate = finetuning_args.stage == "sft"
self.model, self.tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.generating_args = generating_args.to_dict()
self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
@staticmethod
def _process_args(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
)
prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)
do_sample = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
generating_args.update(
dict(
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
temperature=temperature or generating_args["temperature"],
top_p=top_p or generating_args["top_p"],
top_k=top_k or generating_args["top_k"],
num_return_sequences=num_return_sequences or 1,
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
pad_token_id=tokenizer.pad_token_id,
)
)
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
generating_args["do_sample"] = True
if max_length:
generating_args.pop("max_new_tokens", None)
generating_args["max_length"] = max_length
if max_new_tokens:
generating_args.pop("max_length", None)
generating_args["max_new_tokens"] = max_new_tokens
gen_kwargs = dict(
inputs=inputs,
generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor(),
)
return gen_kwargs, prompt_length
@staticmethod
@torch.inference_mode()
def _chat(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
)
generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
results = []
for i in range(len(response)):
eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
results.append(
Response(
response_text=response[i],
response_length=response_length,
prompt_length=prompt_length,
finish_reason="stop" if len(eos_index) else "length",
)
)
return results
@staticmethod
@torch.inference_mode()
def _stream_chat(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
def stream():
try:
return streamer.__next__()
except StopIteration:
raise StopAsyncIteration()
return stream
@staticmethod
@torch.inference_mode()
def _get_scores(
model: "PreTrainedModelWrapper",
tokenizer: "PreTrainedTokenizer",
batch_input: List[str],
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List[float]:
max_length = input_kwargs.pop("max_length", None)
device = getattr(model.pretrained_model, "device", "cuda")
inputs = tokenizer(
batch_input,
padding=True,
truncation=True,
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
return_tensors="pt",
add_special_tokens=True,
).to(device)
input_ids: torch.Tensor = inputs["input_ids"]
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
if getattr(model.config, "model_type", None) == "chatglm":
values = torch.transpose(values, 0, 1)
scores = []
for i in range(input_ids.size(0)):
end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
scores.append(values[i, end_index].nan_to_num().item())
return scores
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
loop = asyncio.get_running_loop()
input_args = (
self.model,
self.tokenizer,
self.template,
self.generating_args,
messages,
system,
tools,
input_kwargs,
)
async with self._semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._chat, *input_args)
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
loop = asyncio.get_running_loop()
input_args = (
self.model,
self.tokenizer,
self.template,
self.generating_args,
messages,
system,
tools,
input_kwargs,
)
async with self._semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
stream = self._stream_chat(*input_args)
while True:
try:
yield await loop.run_in_executor(pool, stream)
except StopAsyncIteration:
break
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
loop = asyncio.get_running_loop()
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
async with self._semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._get_scores, *input_args)

View File

@ -0,0 +1,144 @@
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
from transformers.utils.versions import require_version
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
from ..model import load_tokenizer
from .base_engine import BaseEngine, Response
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
if TYPE_CHECKING:
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
class VllmEngine(BaseEngine):
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
self.can_generate = finetuning_args.stage == "sft"
engine_args = AsyncEngineArgs(
model=model_args.model_name_or_path,
trust_remote_code=True,
max_model_len=model_args.vllm_maxlen,
tensor_parallel_size=get_device_count(),
disable_log_stats=True,
disable_log_requests=True,
)
self.model = AsyncLLMEngine.from_engine_args(engine_args)
self.tokenizer = load_tokenizer(model_args)
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.generating_args = generating_args.to_dict()
async def _generate(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
paired_messages = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
)
prompt_length = len(prompt_ids)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
generating_args = self.generating_args.copy()
generating_args.update(
dict(
temperature=temperature or generating_args["temperature"],
top_p=top_p or generating_args["top_p"],
top_k=top_k or generating_args["top_k"],
num_return_sequences=num_return_sequences or 1,
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
)
)
if max_length:
generating_args["max_new_tokens"] = max_length - prompt_length
if max_new_tokens:
generating_args["max_new_tokens"] = max_new_tokens
sampling_params = SamplingParams(
n=generating_args["num_return_sequences"],
repetition_penalty=generating_args["repetition_penalty"],
temperature=generating_args["temperature"],
top_p=generating_args["top_p"],
top_k=generating_args["top_k"],
use_beam_search=generating_args["num_beams"] > 1,
length_penalty=generating_args["length_penalty"],
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
max_tokens=generating_args["max_new_tokens"],
skip_special_tokens=True,
)
result_generator = self.model.generate(
prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids
)
return result_generator
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> List["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, **input_kwargs)
async for request_output in generator:
final_output = request_output
results = []
for output in final_output.outputs:
results.append(
Response(
response_text=output.text,
response_length=len(output.token_ids),
prompt_length=len(final_output.prompt_token_ids),
finish_reason=output.finish_reason,
)
)
return results
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, **input_kwargs)
async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text
yield delta_text
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
raise NotImplementedError("vLLM engine does not support get_scores.")

View File

@ -1,6 +1,6 @@
from .loader import get_dataset from .loader import get_dataset
from .template import get_template_and_fix_tokenizer, templates from .template import Template, get_template_and_fix_tokenizer, templates
from .utils import Role, split_dataset from .utils import Role, split_dataset
__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"] __all__ = ["get_dataset", "Template", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]

View File

@ -2,7 +2,7 @@ import json
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@ -72,7 +72,7 @@ def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
@dataclass @dataclass
class Formatter(ABC): class Formatter(ABC):
slots: SLOTS = field(default_factory=list) slots: SLOTS = field(default_factory=list)
tool_format: Literal["default"] = "default" tool_format: Optional[Literal["default"]] = None
@abstractmethod @abstractmethod
def apply(self, **kwargs) -> SLOTS: ... def apply(self, **kwargs) -> SLOTS: ...
@ -83,12 +83,30 @@ class Formatter(ABC):
@dataclass @dataclass
class EmptyFormatter(Formatter): class EmptyFormatter(Formatter):
def __post_init__(self):
has_placeholder = False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
has_placeholder = True
if has_placeholder:
raise ValueError("Empty formatter should not contain any placeholder.")
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
return self.slots return self.slots
@dataclass @dataclass
class StringFormatter(Formatter): class StringFormatter(Formatter):
def __post_init__(self):
has_placeholder = False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
has_placeholder = True
if not has_placeholder:
raise ValueError("A placeholder is required in the string formatter.")
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
elements = [] elements = []
for slot in self.slots: for slot in self.slots:
@ -109,6 +127,17 @@ class StringFormatter(Formatter):
@dataclass @dataclass
class FunctionFormatter(Formatter): class FunctionFormatter(Formatter):
def __post_init__(self):
has_name, has_args = False, False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if "{{name}}" in slot:
has_name = True
if "{{arguments}}" in slot:
has_args = True
if not has_name or not has_args:
raise ValueError("Name and arguments placeholders are required in the function formatter.")
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content") content = kwargs.pop("content")
try: try:
@ -133,6 +162,10 @@ class FunctionFormatter(Formatter):
@dataclass @dataclass
class ToolFormatter(Formatter): class ToolFormatter(Formatter):
def __post_init__(self):
if self.tool_format is None:
raise ValueError("Tool format was not found.")
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content") content = kwargs.pop("content")
try: try:

View File

@ -44,7 +44,7 @@ def load_single_dataset(
elif dataset_attr.load_from == "file": elif dataset_attr.load_from == "file":
data_files = [] data_files = []
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
if os.path.isdir(local_path): # is directory if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path): for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name)) data_files.append(os.path.join(local_path, file_name))

View File

@ -19,13 +19,13 @@ class DatasetAttr:
""" basic configs """ """ basic configs """
load_from: Literal["hf_hub", "ms_hub", "script", "file"] load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: Optional[str] = None dataset_name: str
""" extra configs """ """ extra configs """
file_sha1: Optional[str] = None file_sha1: Optional[str] = None
subset: Optional[str] = None subset: Optional[str] = None
folder: Optional[str] = None folder: Optional[str] = None
ranking: Optional[bool] = False ranking: bool = False
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca" formatting: Literal["alpaca", "sharegpt"] = "alpaca"
""" columns """ """ columns """
system: Optional[str] = None system: Optional[str] = None
""" columns for the alpaca format """ """ columns for the alpaca format """

View File

@ -157,6 +157,12 @@ def get_current_device() -> torch.device:
def get_device_count() -> int: def get_device_count() -> int:
r"""
Gets the number of available GPU devices.
"""
if not torch.cuda.is_available():
return 0
return torch.cuda.device_count() return torch.cuda.device_count()

View File

@ -51,3 +51,7 @@ def is_unsloth_available():
def is_uvicorn_available(): def is_uvicorn_available():
return _is_package_available("uvicorn") return _is_package_available("uvicorn")
def is_vllm_available():
return _is_package_available("vllm")

View File

@ -16,35 +16,35 @@ class DataArguments:
default=None, default=None,
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}, metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
) )
dataset_dir: Optional[str] = field( dataset_dir: str = field(
default="data", default="data",
metadata={"help": "Path to the folder containing the datasets."}, metadata={"help": "Path to the folder containing the datasets."},
) )
split: Optional[str] = field( split: str = field(
default="train", default="train",
metadata={"help": "Which dataset split to use for training and evaluation."}, metadata={"help": "Which dataset split to use for training and evaluation."},
) )
cutoff_len: Optional[int] = field( cutoff_len: int = field(
default=1024, default=1024,
metadata={"help": "The cutoff length of the model inputs after tokenization."}, metadata={"help": "The cutoff length of the model inputs after tokenization."},
) )
reserved_label_len: Optional[int] = field( reserved_label_len: int = field(
default=1, default=1,
metadata={"help": "The minimum cutoff length reserved for label after tokenization."}, metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
) )
train_on_prompt: Optional[bool] = field( train_on_prompt: bool = field(
default=False, default=False,
metadata={"help": "Whether to disable the mask on the prompt or not."}, metadata={"help": "Whether to disable the mask on the prompt or not."},
) )
streaming: Optional[bool] = field( streaming: bool = field(
default=False, default=False,
metadata={"help": "Enable dataset streaming."}, metadata={"help": "Enable dataset streaming."},
) )
buffer_size: Optional[int] = field( buffer_size: int = field(
default=16384, default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}, metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
) )
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field( mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
default="concat", default="concat",
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}, metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
) )
@ -52,13 +52,13 @@ class DataArguments:
default=None, default=None,
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}, metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
) )
overwrite_cache: Optional[bool] = field( overwrite_cache: bool = field(
default=False, default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."}, metadata={"help": "Overwrite the cached training and evaluation sets."},
) )
preprocessing_num_workers: Optional[int] = field( preprocessing_num_workers: Optional[int] = field(
default=None, default=None,
metadata={"help": "The number of processes to use for the preprocessing."}, metadata={"help": "The number of processes to use for the pre-processing."},
) )
max_samples: Optional[int] = field( max_samples: Optional[int] = field(
default=None, default=None,
@ -68,23 +68,23 @@ class DataArguments:
default=None, default=None,
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}, metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
) )
ignore_pad_token_for_loss: Optional[bool] = field( ignore_pad_token_for_loss: bool = field(
default=True, default=True,
metadata={ metadata={
"help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation." "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
}, },
) )
val_size: Optional[float] = field( val_size: float = field(
default=0, default=0.0,
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}, metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
) )
sft_packing: Optional[bool] = field( sft_packing: bool = field(
default=False, default=False,
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}, metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."},
) )
cache_path: Optional[str] = field( cache_path: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to save or load the preprocessed datasets."}, metadata={"help": "Path to save or load the pre-processed datasets."},
) )
def __post_init__(self): def __post_init__(self):

View File

@ -14,23 +14,23 @@ class EvaluationArguments:
task: str = field( task: str = field(
metadata={"help": "Name of the evaluation task."}, metadata={"help": "Name of the evaluation task."},
) )
task_dir: Optional[str] = field( task_dir: str = field(
default="evaluation", default="evaluation",
metadata={"help": "Path to the folder containing the evaluation datasets."}, metadata={"help": "Path to the folder containing the evaluation datasets."},
) )
batch_size: Optional[int] = field( batch_size: int = field(
default=4, default=4,
metadata={"help": "The batch size per GPU for evaluation."}, metadata={"help": "The batch size per GPU for evaluation."},
) )
seed: Optional[int] = field( seed: int = field(
default=42, default=42,
metadata={"help": "Random seed to be used with data loaders."}, metadata={"help": "Random seed to be used with data loaders."},
) )
lang: Optional[Literal["en", "zh"]] = field( lang: Literal["en", "zh"] = field(
default="en", default="en",
metadata={"help": "Language used at evaluation."}, metadata={"help": "Language used at evaluation."},
) )
n_shot: Optional[int] = field( n_shot: int = field(
default=5, default=5,
metadata={"help": "Number of examplars for few-shot learning."}, metadata={"help": "Number of examplars for few-shot learning."},
) )
@ -38,7 +38,7 @@ class EvaluationArguments:
default=None, default=None,
metadata={"help": "Path to save the evaluation results."}, metadata={"help": "Path to save the evaluation results."},
) )
download_mode: Optional[DownloadMode] = field( download_mode: DownloadMode = field(
default=DownloadMode.REUSE_DATASET_IF_EXISTS, default=DownloadMode.REUSE_DATASET_IF_EXISTS,
metadata={"help": "Download mode used for the evaluation datasets."}, metadata={"help": "Download mode used for the evaluation datasets."},
) )

View File

@ -22,7 +22,7 @@ class FreezeArguments:
Others choices: the same as LLaMA.""" Others choices: the same as LLaMA."""
}, },
) )
num_layer_trainable: Optional[int] = field( num_layer_trainable: int = field(
default=3, default=3,
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}, metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
) )
@ -44,11 +44,11 @@ class LoraArguments:
default=None, default=None,
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}, metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
) )
lora_dropout: Optional[float] = field( lora_dropout: float = field(
default=0.0, default=0.0,
metadata={"help": "Dropout rate for the LoRA fine-tuning."}, metadata={"help": "Dropout rate for the LoRA fine-tuning."},
) )
lora_rank: Optional[int] = field( lora_rank: int = field(
default=8, default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
) )
@ -66,18 +66,19 @@ class LoraArguments:
Others choices: the same as LLaMA.""" Others choices: the same as LLaMA."""
}, },
) )
lora_bf16_mode: Optional[bool] = field( lora_bf16_mode: bool = field(
default=False, default=False,
metadata={"help": "Whether or not to train lora adapters in bf16 precision."}, metadata={"help": "Whether or not to train lora adapters in bf16 precision."},
) )
use_rslora: Optional[bool] = field( use_rslora: bool = field(
default=False, default=False,
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."}, metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
) )
use_dora: Optional[bool] = field( use_dora: bool = field(
default=False, metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."} default=False,
metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
) )
create_new_adapter: Optional[bool] = field( create_new_adapter: bool = field(
default=False, default=False,
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}, metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
) )
@ -89,23 +90,23 @@ class RLHFArguments:
Arguments pertaining to the PPO and DPO training. Arguments pertaining to the PPO and DPO training.
""" """
dpo_beta: Optional[float] = field( dpo_beta: float = field(
default=0.1, default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}, metadata={"help": "The beta parameter for the DPO loss."},
) )
dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto_pair"]] = field( dpo_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = field(
default="sigmoid", default="sigmoid",
metadata={"help": "The type of DPO loss to use."}, metadata={"help": "The type of DPO loss to use."},
) )
dpo_ftx: Optional[float] = field( dpo_ftx: float = field(
default=0, default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
) )
ppo_buffer_size: Optional[int] = field( ppo_buffer_size: int = field(
default=1, default=1,
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}, metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
) )
ppo_epochs: Optional[int] = field( ppo_epochs: int = field(
default=4, default=4,
metadata={"help": "The number of epochs to perform in a PPO optimization step."}, metadata={"help": "The number of epochs to perform in a PPO optimization step."},
) )
@ -113,15 +114,15 @@ class RLHFArguments:
default=None, default=None,
metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'}, metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
) )
ppo_score_norm: Optional[bool] = field( ppo_score_norm: bool = field(
default=False, default=False,
metadata={"help": "Use score normalization in PPO training."}, metadata={"help": "Use score normalization in PPO training."},
) )
ppo_target: Optional[float] = field( ppo_target: float = field(
default=6.0, default=6.0,
metadata={"help": "Target KL value for adaptive KL control in PPO training."}, metadata={"help": "Target KL value for adaptive KL control in PPO training."},
) )
ppo_whiten_rewards: Optional[bool] = field( ppo_whiten_rewards: bool = field(
default=False, default=False,
metadata={"help": "Whiten the rewards before compute advantages in PPO training."}, metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
) )
@ -149,7 +150,7 @@ class RLHFArguments:
default=None, default=None,
metadata={"help": "The number of bits to quantize the reward model."}, metadata={"help": "The number of bits to quantize the reward model."},
) )
reward_model_type: Optional[Literal["lora", "full", "api"]] = field( reward_model_type: Literal["lora", "full", "api"] = field(
default="lora", default="lora",
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}, metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
) )
@ -161,19 +162,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
Arguments pertaining to which techniques we are going to fine-tuning with. Arguments pertaining to which techniques we are going to fine-tuning with.
""" """
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field(
default="sft", default="sft",
metadata={"help": "Which stage will be performed in training."}, metadata={"help": "Which stage will be performed in training."},
) )
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field( finetuning_type: Literal["lora", "freeze", "full"] = field(
default="lora", default="lora",
metadata={"help": "Which fine-tuning method to use."}, metadata={"help": "Which fine-tuning method to use."},
) )
use_llama_pro: Optional[bool] = field( use_llama_pro: bool = field(
default=False, default=False,
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."}, metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
) )
plot_loss: Optional[bool] = field( plot_loss: bool = field(
default=False, default=False,
metadata={"help": "Whether or not to save the training loss curves."}, metadata={"help": "Whether or not to save the training loss curves."},
) )

View File

@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional from typing import Any, Dict
@dataclass @dataclass
@ -8,41 +8,41 @@ class GeneratingArguments:
Arguments pertaining to specify the decoding parameters. Arguments pertaining to specify the decoding parameters.
""" """
do_sample: Optional[bool] = field( do_sample: bool = field(
default=True, default=True,
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}, metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."},
) )
temperature: Optional[float] = field( temperature: float = field(
default=0.95, default=0.95,
metadata={"help": "The value used to modulate the next token probabilities."}, metadata={"help": "The value used to modulate the next token probabilities."},
) )
top_p: Optional[float] = field( top_p: float = field(
default=0.7, default=0.7,
metadata={ metadata={
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept." "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
}, },
) )
top_k: Optional[int] = field( top_k: int = field(
default=50, default=50,
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}, metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
) )
num_beams: Optional[int] = field( num_beams: int = field(
default=1, default=1,
metadata={"help": "Number of beams for beam search. 1 means no beam search."}, metadata={"help": "Number of beams for beam search. 1 means no beam search."},
) )
max_length: Optional[int] = field( max_length: int = field(
default=512, default=512,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}, metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
) )
max_new_tokens: Optional[int] = field( max_new_tokens: int = field(
default=512, default=512,
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}, metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
) )
repetition_penalty: Optional[float] = field( repetition_penalty: float = field(
default=1.0, default=1.0,
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}, metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."},
) )
length_penalty: Optional[float] = field( length_penalty: float = field(
default=1.0, default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}, metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
) )

View File

@ -85,6 +85,10 @@ class ModelArguments:
default="hf", default="hf",
metadata={"help": "Backend engine used at inference."}, metadata={"help": "Backend engine used at inference."},
) )
vllm_maxlen: int = field(
default=2048,
metadata={"help": "Maximum input length of the vLLM engine."},
)
hf_hub_token: Optional[str] = field( hf_hub_token: Optional[str] = field(
default=None, default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}, metadata={"help": "Auth token to log in with Hugging Face Hub."},

View File

@ -9,8 +9,8 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.packages import is_unsloth_available
from ..extras.misc import check_dependencies from ..extras.misc import check_dependencies
from ..extras.packages import is_unsloth_available
from .data_args import DataArguments from .data_args import DataArguments
from .evaluation_args import EvaluationArguments from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments from .finetuning_args import FinetuningArguments
@ -59,6 +59,9 @@ def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None: def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Adapter is only valid for the LoRA method.")
if model_args.quantization_bit is not None: if model_args.quantization_bit is not None:
if finetuning_args.finetuning_type != "lora": if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.") raise ValueError("Quantization is only compatible with the LoRA method.")
@ -69,8 +72,18 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1: if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("Quantized model only accepts a single adapter. Merge them first.") raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora": if model_args.infer_backend == "vllm":
raise ValueError("Adapter is only valid for the LoRA method.") if finetuning_args.stage != "sft":
raise ValueError("vLLM engine only supports auto-regressive models.")
if model_args.adapter_name_or_path is not None:
raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
if model_args.quantization_bit is not None:
raise ValueError("vLLM engine does not support quantization.")
if model_args.rope_scaling is not None:
raise ValueError("vLLM engine does not support RoPE scaling.")
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:

View File

@ -1,11 +1,10 @@
from .loader import load_model, load_model_and_tokenizer, load_tokenizer from .loader import load_model, load_model_and_tokenizer, load_tokenizer
from .utils import dispatch_model, load_valuehead_params from .utils import load_valuehead_params
__all__ = [ __all__ = [
"load_model", "load_model",
"load_model_and_tokenizer", "load_model_and_tokenizer",
"load_tokenizer", "load_tokenizer",
"dispatch_model",
"load_valuehead_params", "load_valuehead_params",
] ]

View File

@ -1,4 +1,3 @@
import inspect
from typing import TYPE_CHECKING, Dict, List from typing import TYPE_CHECKING, Dict, List
import torch import torch
@ -7,7 +6,6 @@ from transformers.utils import cached_file
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import get_current_device
if TYPE_CHECKING: if TYPE_CHECKING:
@ -19,36 +17,6 @@ if TYPE_CHECKING:
logger = get_logger(__name__) logger = get_logger(__name__)
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r"""
Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
"""
if getattr(model, "quantization_method", None): # already set on current device
return model
if (
torch.cuda.device_count() > 1
and isinstance(model, PreTrainedModel)
and model._no_split_modules is not None
and model.config.model_type != "chatglm"
):
from accelerate import dispatch_model
from accelerate.utils import get_balanced_memory, infer_auto_device_map
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
device_map_kwargs = {"device_map": device_map, "offload_dir": "offload"}
if "skip_keys" in inspect.signature(dispatch_model).parameters:
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
return dispatch_model(model, **device_map_kwargs)
else:
return model.to(device=get_current_device())
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]: def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
r""" r"""
Finds all available modules to apply lora. Finds all available modules to apply lora.

View File

@ -1,4 +1,5 @@
import json import json
import os
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
import gradio as gr import gradio as gr
@ -7,12 +8,12 @@ from gradio.components import Component # cannot use TYPE_CHECKING here
from ..chat import ChatModel from ..chat import ChatModel
from ..data import Role from ..data import Role
from ..extras.misc import torch_gc from ..extras.misc import torch_gc
from ..hparams import GeneratingArguments
from .common import get_save_dir from .common import get_save_dir
from .locales import ALERTS from .locales import ALERTS
if TYPE_CHECKING: if TYPE_CHECKING:
from ..chat import BaseEngine
from .manager import Manager from .manager import Manager
@ -22,29 +23,19 @@ class WebChatModel(ChatModel):
) -> None: ) -> None:
self.manager = manager self.manager = manager
self.demo_mode = demo_mode self.demo_mode = demo_mode
self.model = None self.engine: Optional["BaseEngine"] = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
if not lazy_init: # read arguments from command line if not lazy_init: # read arguments from command line
super().__init__() super().__init__()
if demo_mode: # load demo_config.json if exists if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"): # load demo model
import json model_name_or_path = os.environ.get("DEMO_MODEL")
template = os.environ.get("DEMO_TEMPLATE")
try: super().__init__(dict(model_name_or_path=model_name_or_path, template=template))
with open("demo_config.json", "r", encoding="utf-8") as f:
args = json.load(f)
assert args.get("model_name_or_path", None) and args.get("template", None)
super().__init__(args)
except AssertionError:
print("Please provided model name and template in `demo_config.json`.")
except Exception:
print("Cannot find `demo_config.json` at current directory.")
@property @property
def loaded(self) -> bool: def loaded(self) -> bool:
return self.model is not None return self.engine is not None
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
get = lambda name: data[self.manager.get_elem_by_name(name)] get = lambda name: data[self.manager.get_elem_by_name(name)]
@ -98,8 +89,7 @@ class WebChatModel(ChatModel):
return return
yield ALERTS["info_unloading"][lang] yield ALERTS["info_unloading"][lang]
self.model = None self.engine = None
self.tokenizer = None
torch_gc() torch_gc()
yield ALERTS["info_unloaded"][lang] yield ALERTS["info_unloaded"][lang]
@ -123,7 +113,7 @@ class WebChatModel(ChatModel):
): ):
response += new_text response += new_text
if tools: if tools:
result = self.template.format_tools.extract(response) result = self.engine.template.format_tools.extract(response)
else: else:
result = response result = response

View File

@ -28,10 +28,9 @@ def create_chat_box(
submit_btn = gr.Button(variant="primary") submit_btn = gr.Button(variant="primary")
with gr.Column(scale=1): with gr.Column(scale=1):
gen_kwargs = engine.chatter.generating_args max_new_tokens = gr.Slider(8, 4096, value=512, step=1)
max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1) top_p = gr.Slider(0.01, 1.0, value=0.7, step=0.01)
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01) temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
clear_btn = gr.Button() clear_btn = gr.Button()
tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")]) tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])

View File

@ -16,8 +16,8 @@ class Engine:
self.demo_mode = demo_mode self.demo_mode = demo_mode
self.pure_chat = pure_chat self.pure_chat = pure_chat
self.manager = Manager() self.manager = Manager()
self.runner = Runner(self.manager, demo_mode=demo_mode) self.runner = Runner(self.manager, demo_mode)
self.chatter = WebChatModel(manager=self.manager, demo_mode=demo_mode, lazy_init=(not pure_chat)) self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat))
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]): def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()} return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}

30
tests/test_throughput.py Normal file
View File

@ -0,0 +1,30 @@
import os
import time
from openai import OpenAI
from transformers.utils.versions import require_version
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
def main():
client = OpenAI(
api_key="0",
base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
)
messages = [{"role": "user", "content": "Write a long essay about environment protection as long as possible."}]
num_tokens = 0
start_time = time.time()
for _ in range(8):
result = client.chat.completions.create(messages=messages, model="test")
num_tokens += result.usage.completion_tokens
elapsed_time = time.time() - start_time
print("Throughput: {:.2f} tokens/s".format(num_tokens / elapsed_time))
# --infer_backend hf: 27.22 tokens/s (1.0x)
# --infer_backend vllm: 73.03 tokens/s (2.7x)
if __name__ == "__main__":
main()

View File

@ -1,4 +1,5 @@
import json import json
import os
from typing import Sequence from typing import Sequence
from openai import OpenAI from openai import OpenAI
@ -17,13 +18,10 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
return total_score / total_hour return total_score / total_hour
tool_map = {"calculate_gpa": calculate_gpa} def main():
if __name__ == "__main__":
client = OpenAI( client = OpenAI(
api_key="0", api_key="0",
base_url="http://localhost:8000/v1", base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
) )
tools = [ tools = [
{ {
@ -42,6 +40,8 @@ if __name__ == "__main__":
}, },
} }
] ]
tool_map = {"calculate_gpa": calculate_gpa}
messages = [] messages = []
messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."}) messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
result = client.chat.completions.create(messages=messages, model="test", tools=tools) result = client.chat.completions.create(messages=messages, model="test", tools=tools)
@ -55,3 +55,7 @@ if __name__ == "__main__":
result = client.chat.completions.create(messages=messages, model="test", tools=tools) result = client.chat.completions.create(messages=messages, model="test", tools=tools)
print(result.choices[0].message.content) print(result.choices[0].message.content)
# Based on your grades and credit hours, your calculated Grade Point Average (GPA) is 3.4166666666666665. # Based on your grades and credit hours, your calculated Grade Point Average (GPA) is 3.4166666666666665.
if __name__ == "__main__":
main()