diff --git a/Makefile b/Makefile
index 5c754167..3a4a12c9 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
.PHONY: quality style
-check_dirs := scripts src
+check_dirs := scripts src tests
quality:
	ruff check $(check_dirs)
diff --git a/README.md b/README.md
index 0e1bec64..ae9d28b6 100644
--- a/README.md
+++ b/README.md
@@ -47,10 +47,11 @@ Choose your path:
- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO.
-- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA, 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: DoRA, LongLoRA, LLaMA Pro, LoftQ, agent tuning.
-- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune, rsLoRA.
+- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
+- **Advanced algorithms**: DoRA, LongLoRA, LLaMA Pro, LoftQ and Agent tuning.
+- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
+- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
## Benchmark
@@ -69,6 +70,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
+[24/03/07] We integrated [vLLM](https://github.com/vllm-project/vllm) for faster and concurrent inference. Try `--infer_backend vllm` to achieve **270%** inference speed. (LoRA is not yet supported; merge the weights first.)
+
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `scripts/llama_pro.py` for usage.
@@ -79,7 +82,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`.
-[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves 1.7x speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
+[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try the `--use_unsloth` argument to activate the unsloth patch. It achieves **170%** speed in our benchmark; check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
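The `--infer_backend vllm` flag announced above also has a programmatic counterpart through the `ChatModel` facade reworked later in this diff. A minimal sketch, assuming `get_infer_args` accepts the same keys as the CLI flags (the model path below is a placeholder):

```python
import asyncio

from llmtuner.chat import ChatModel


async def main() -> None:
    # Backend selection mirrors the CLI flag: "hf" keeps the Transformers engine,
    # "vllm" routes through the new AsyncLLMEngine-based engine.
    chat_model = ChatModel(
        dict(
            model_name_or_path="path_to_llama_model",  # placeholder path
            template="default",
            infer_backend="vllm",
        )
    )
    responses = await chat_model.achat([{"role": "user", "content": "Hello, who are you?"}])
    print(responses[0].response_text)


asyncio.run(main())
```

The synchronous `chat()` / `stream_chat()` wrappers added in `chat_model.py` drive these same coroutines for callers that are not async.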
@@ -553,7 +556,7 @@ deepspeed --num_gpus 8 src/train_bash.py \ ### Merge LoRA weights and export model ```bash -python src/export_model.py \ +CUDA_VISIBLE_DEVICES=0 python src/export_model.py \ --model_name_or_path path_to_llama_model \ --adapter_name_or_path path_to_checkpoint \ --template default \ @@ -574,7 +577,7 @@ python src/export_model.py \ ### Inference with OpenAI-style API ```bash -python src/api_demo.py \ +CUDA_VISIBLE_DEVICES=0 python src/api_demo.py \ --model_name_or_path path_to_llama_model \ --adapter_name_or_path path_to_checkpoint \ --template default \ @@ -587,7 +590,7 @@ python src/api_demo.py \ ### Inference with command line ```bash -python src/cli_demo.py \ +CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \ --model_name_or_path path_to_llama_model \ --adapter_name_or_path path_to_checkpoint \ --template default \ @@ -597,7 +600,7 @@ python src/cli_demo.py \ ### Inference with web browser ```bash -python src/web_demo.py \ +CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \ --model_name_or_path path_to_llama_model \ --adapter_name_or_path path_to_checkpoint \ --template default \ diff --git a/README_zh.md b/README_zh.md index bf1b8c9d..79e2b67d 100644 --- a/README_zh.md +++ b/README_zh.md @@ -51,6 +51,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd - **先进算法**:DoRA、LongLoRA、LLaMA Pro、LoftQ 和 Agent 微调。 - **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。 +- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。 ## 性能指标 @@ -69,17 +70,19 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd ## 更新日志 +[24/03/07] 我们集成了 [vLLM](https://github.com/vllm-project/vllm) 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA,请先合并权重。) + [24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。 [24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `scripts/llama_pro.py`。 -[24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。 -
展开日志 +[24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。 + [24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。 -[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 1.7 倍的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。 +[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 **170%** 的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。 [23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。 @@ -552,7 +555,7 @@ deepspeed --num_gpus 8 src/train_bash.py \ ### 合并 LoRA 权重并导出模型 ```bash -python src/export_model.py \ +CUDA_VISIBLE_DEVICES=0 python src/export_model.py \ --model_name_or_path path_to_llama_model \ --adapter_name_or_path path_to_checkpoint \ --template default \ @@ -573,7 +576,7 @@ python src/export_model.py \ ### 使用 OpenAI 风格 API 推理 ```bash -python src/api_demo.py \ +CUDA_VISIBLE_DEVICES=0 python src/api_demo.py \ --model_name_or_path path_to_llama_model \ --adapter_name_or_path path_to_checkpoint \ --template default \ @@ -586,7 +589,7 @@ python src/api_demo.py \ ### 使用命令行推理 ```bash -python src/cli_demo.py \ +CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \ --model_name_or_path path_to_llama_model \ --adapter_name_or_path path_to_checkpoint \ --template default \ @@ -596,7 +599,7 @@ python src/cli_demo.py \ ### 使用浏览器推理 ```bash -python src/web_demo.py \ +CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \ --model_name_or_path path_to_llama_model \ --adapter_name_or_path path_to_checkpoint \ --template default \ diff --git a/examples/lora_single_gpu/README.md b/examples/lora_single_gpu/README.md index 28d4a4fe..ae0f4722 100644 --- a/examples/lora_single_gpu/README.md +++ b/examples/lora_single_gpu/README.md @@ -1,5 +1,8 @@ Usage: -- `pretrain.sh` -- `sft.sh` -> `reward.sh` -> `ppo.sh` -- `sft.sh` -> `dpo.sh` -> `predict.sh` +- `pretrain.sh`: do pre-train (optional) +- `sft.sh`: do supervised fine-tune +- `reward.sh`: do reward modeling (must after sft.sh) +- `ppo.sh`: do PPO training (must after sft.sh and reward.sh) +- `dpo.sh`: do DPO training (must after sft.sh) +- `predict.sh`: do predict (must after sft.sh and dpo.sh) diff --git a/examples/merge_lora/README.md b/examples/merge_lora/README.md index a396767b..c6c16071 100644 --- a/examples/merge_lora/README.md +++ b/examples/merge_lora/README.md @@ -1,3 +1,4 @@ Usage: -- `merge.sh` -> `quantize.sh` +- `merge.sh`: merge the lora weights +- `quantize.sh`: quantize the model with AutoGPTQ (must after merge.sh, optional) diff --git a/examples/merge_lora/merge.sh b/examples/merge_lora/merge.sh index dcad7e6f..42b9fcdd 100644 --- a/examples/merge_lora/merge.sh +++ b/examples/merge_lora/merge.sh @@ -1,6 +1,6 @@ #!/bin/bash -python ../../src/export_model.py \ +CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \ --model_name_or_path meta-llama/Llama-2-7b-hf \ --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \ --template default \ diff --git a/examples/merge_lora/quantize.sh b/examples/merge_lora/quantize.sh index f53afbb9..143bce50 100644 --- a/examples/merge_lora/quantize.sh +++ b/examples/merge_lora/quantize.sh @@ -1,6 +1,6 @@ #!/bin/bash -python ../../src/export_model.py \ +CUDA_VISIBLE_DEVICES=0 python 
../../src/export_model.py \ --model_name_or_path ../../models/llama2-7b-sft \ --template default \ --export_dir ../../models/llama2-7b-sft-int4 \ diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index b47734e0..c5a18bc7 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -1,4 +1,3 @@ -import asyncio import json import os from contextlib import asynccontextmanager @@ -73,7 +72,6 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": allow_headers=["*"], ) - semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1))) role_mapping = { Role.USER: DataRole.USER.value, Role.ASSISTANT: DataRole.ASSISTANT.value, @@ -89,7 +87,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK) async def create_chat_completion(request: ChatCompletionRequest): - if not chat_model.can_generate: + if not chat_model.engine.can_generate: raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") if len(request.messages) == 0: @@ -121,20 +119,15 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": else: tools = "" - async with semaphore: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, chat_completion, input_messages, system, tools, request) - - def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest): if request.stream: if tools: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.") - generate = stream_chat_completion(messages, system, tools, request) + generate = stream_chat_completion(input_messages, system, tools, request) return EventSourceResponse(generate, media_type="text/event-stream") - responses = chat_model.chat( - messages, + responses = await chat_model.achat( + input_messages, system, tools, do_sample=request.do_sample, @@ -148,7 +141,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": choices = [] for i, response in enumerate(responses): if tools: - result = chat_model.template.format_tools.extract(response.response_text) + result = chat_model.engine.template.format_tools.extract(response.response_text) else: result = response.response_text @@ -177,7 +170,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) - def stream_chat_completion( + async def stream_chat_completion( messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest ): choice_data = ChatCompletionResponseStreamChoice( @@ -186,7 +179,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) yield jsonify(chunk) - for new_text in chat_model.stream_chat( + async for new_token in chat_model.astream_chat( messages, system, tools, @@ -195,11 +188,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": top_p=request.top_p, max_new_tokens=request.max_tokens, ): - if len(new_text) == 0: + if len(new_token) == 0: continue choice_data = ChatCompletionResponseStreamChoice( - index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None + index=0, delta=ChatCompletionMessage(content=new_token), finish_reason=None ) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) yield jsonify(chunk) @@ -213,18 +206,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": 
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK) async def create_score_evaluation(request: ScoreEvaluationRequest): - if chat_model.can_generate: + if chat_model.engine.can_generate: raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") if len(request.messages) == 0: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") - async with semaphore: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, get_score, request) - - def get_score(request: ScoreEvaluationRequest): - scores = chat_model.get_scores(request.messages, max_length=request.max_length) + scores = await chat_model.aget_scores(request.messages, max_length=request.max_length) return ScoreEvaluationResponse(model=request.model, scores=scores) return app diff --git a/src/llmtuner/chat/__init__.py b/src/llmtuner/chat/__init__.py index 702d0ac7..a1a79de6 100644 --- a/src/llmtuner/chat/__init__.py +++ b/src/llmtuner/chat/__init__.py @@ -1,4 +1,5 @@ +from .base_engine import BaseEngine from .chat_model import ChatModel -__all__ = ["ChatModel"] +__all__ = ["BaseEngine", "ChatModel"] diff --git a/src/llmtuner/chat/base_engine.py b/src/llmtuner/chat/base_engine.py new file mode 100644 index 00000000..ea46bfba --- /dev/null +++ b/src/llmtuner/chat/base_engine.py @@ -0,0 +1,64 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + + from ..data import Template + from ..extras.packages import is_vllm_available + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + if is_vllm_available(): + from vllm import AsyncLLMEngine + + +@dataclass +class Response: + response_text: str + response_length: int + prompt_length: int + finish_reason: Literal["stop", "length"] + + +class BaseEngine(ABC): + model: Union["PreTrainedModel", "AsyncLLMEngine"] + tokenizer: "PreTrainedTokenizer" + can_generate: bool + template: "Template" + generating_args: Dict[str, Any] + + @abstractmethod + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: ... + + @abstractmethod + async def chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> List["Response"]: ... + + @abstractmethod + async def stream_chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: ... + + @abstractmethod + async def get_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: ... 
diff --git a/src/llmtuner/chat/chat_model.py b/src/llmtuner/chat/chat_model.py index 0d9c6395..bc52fe67 100644 --- a/src/llmtuner/chat/chat_model.py +++ b/src/llmtuner/chat/chat_model.py @@ -1,124 +1,50 @@ -from dataclasses import dataclass -from threading import Thread -from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Tuple +import asyncio +from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence -import torch -from transformers import GenerationConfig, TextIteratorStreamer - -from ..data import get_template_and_fix_tokenizer -from ..extras.misc import get_logits_processor from ..hparams import get_infer_args -from ..model import dispatch_model, load_model_and_tokenizer +from .hf_engine import HuggingfaceEngine +from .vllm_engine import VllmEngine -@dataclass -class Response: - response_text: str - response_length: int - prompt_length: int - finish_reason: Literal["stop", "length"] +if TYPE_CHECKING: + from .base_engine import BaseEngine, Response class ChatModel: def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: - model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args) - self.can_generate = finetuning_args.stage == "sft" - self.model, self.tokenizer = load_model_and_tokenizer( - model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate) - ) - self.tokenizer.padding_side = "left" if self.can_generate else "right" - self.model = dispatch_model(self.model) - self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) + model_args, data_args, finetuning_args, generating_args = get_infer_args(args) + if model_args.infer_backend == "hf": + self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args) + elif model_args.infer_backend == "vllm": + self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args) + else: + raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend)) - def _process_args( - self, - messages: Sequence[Dict[str, str]], - system: Optional[str] = None, - tools: Optional[str] = None, - **input_kwargs, - ) -> Tuple[Dict[str, Any], int]: - paired_messages = messages + [{"role": "assistant", "content": ""}] - prompt, _ = self.template.encode_oneturn( - tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools - ) - prompt_length = len(prompt) - input_ids = torch.tensor([prompt], device=self.model.device) + def _get_event_loop(): + try: + return asyncio.get_running_loop() + except RuntimeError: + return asyncio.new_event_loop() - do_sample = input_kwargs.pop("do_sample", None) - temperature = input_kwargs.pop("temperature", None) - top_p = input_kwargs.pop("top_p", None) - top_k = input_kwargs.pop("top_k", None) - num_return_sequences = input_kwargs.pop("num_return_sequences", None) - repetition_penalty = input_kwargs.pop("repetition_penalty", None) - max_length = input_kwargs.pop("max_length", None) - max_new_tokens = input_kwargs.pop("max_new_tokens", None) - - generating_args = self.generating_args.to_dict() - generating_args.update( - dict( - do_sample=do_sample if do_sample is not None else generating_args["do_sample"], - temperature=temperature or generating_args["temperature"], - top_p=top_p or generating_args["top_p"], - top_k=top_k or generating_args["top_k"], - num_return_sequences=num_return_sequences or 1, - repetition_penalty=repetition_penalty or generating_args["repetition_penalty"], - 
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, - pad_token_id=self.tokenizer.pad_token_id, - ) - ) - - if isinstance(num_return_sequences, int) and num_return_sequences > 1: - generating_args["do_sample"] = True - - if max_length: - generating_args.pop("max_new_tokens", None) - generating_args["max_length"] = max_length - - if max_new_tokens: - generating_args.pop("max_length", None) - generating_args["max_new_tokens"] = max_new_tokens - - gen_kwargs = dict( - inputs=input_ids, - generation_config=GenerationConfig(**generating_args), - logits_processor=get_logits_processor(), - ) - - return gen_kwargs, prompt_length - - @torch.inference_mode() def chat( self, messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, **input_kwargs, - ) -> List[Response]: - if not self.can_generate: - raise ValueError("The current model does not support `chat`.") + ) -> List["Response"]: + loop = self._get_event_loop() + return loop.run_until_complete(self.achat(messages, system, tools, **input_kwargs)) - gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs) - generate_output = self.model.generate(**gen_kwargs) - response_ids = generate_output[:, prompt_length:] - response = self.tokenizer.batch_decode( - response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True - ) - results = [] - for i in range(len(response)): - eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero() - response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i]) - results.append( - Response( - response_text=response[i], - response_length=response_length, - prompt_length=prompt_length, - finish_reason="stop" if len(eos_index) else "length", - ) - ) + async def achat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> List["Response"]: + return await self.engine.chat(messages, system, tools, **input_kwargs) - return results - - @torch.inference_mode() def stream_chat( self, messages: Sequence[Dict[str, str]], @@ -126,44 +52,35 @@ class ChatModel: tools: Optional[str] = None, **input_kwargs, ) -> Generator[str, None, None]: - if not self.can_generate: - raise ValueError("The current model does not support `stream_chat`.") + loop = self._get_event_loop() + generator = self.astream_chat(messages, system, tools, **input_kwargs) + while True: + try: + yield loop.run_until_complete(generator.__anext__()) + except StopAsyncIteration: + break - gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs) - streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) - gen_kwargs["streamer"] = streamer + async def astream_chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs): + yield new_token - thread = Thread(target=self.model.generate, kwargs=gen_kwargs) - thread.start() + def get_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: + loop = self._get_event_loop() + return loop.run_until_complete(self.aget_scores(batch_input, **input_kwargs)) - yield from streamer - - @torch.inference_mode() - def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]: - if self.can_generate: - raise 
ValueError("Cannot get scores using an auto-regressive model.") - - max_length = input_kwargs.pop("max_length", None) - device = getattr(self.model.pretrained_model, "device", "cuda") - inputs = self.tokenizer( - batch_input, - padding=True, - truncation=True, - max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024), - return_tensors="pt", - add_special_tokens=True, - ).to(device) - - input_ids: torch.Tensor = inputs["input_ids"] - _, _, values = self.model(**inputs, output_hidden_states=True, return_dict=True) - - if getattr(self.model.config, "model_type", None) == "chatglm": - values = torch.transpose(values, 0, 1) - - scores = [] - for i in range(input_ids.size(0)): - end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero() - end_index = end_indexes[-1].item() if len(end_indexes) else 0 - scores.append(values[i, end_index].nan_to_num().item()) - - return scores + async def aget_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: + return await self.engine.get_scores(batch_input, **input_kwargs) diff --git a/src/llmtuner/chat/hf_engine.py b/src/llmtuner/chat/hf_engine.py new file mode 100644 index 00000000..9d8220a7 --- /dev/null +++ b/src/llmtuner/chat/hf_engine.py @@ -0,0 +1,261 @@ +import asyncio +import concurrent.futures +import os +from threading import Thread +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple + +import torch +from transformers import GenerationConfig, TextIteratorStreamer + +from ..data import get_template_and_fix_tokenizer +from ..extras.misc import get_logits_processor +from ..model import load_model_and_tokenizer +from .base_engine import BaseEngine, Response + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + from trl import PreTrainedModelWrapper + + from ..data import Template + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +class HuggingfaceEngine(BaseEngine): + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + self.can_generate = finetuning_args.stage == "sft" + self.model, self.tokenizer = load_model_and_tokenizer( + model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate) + ) + self.tokenizer.padding_side = "left" if self.can_generate else "right" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) + self.generating_args = generating_args.to_dict() + self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1))) + + @staticmethod + def _process_args( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + template: "Template", + generating_args: Dict[str, Any], + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + input_kwargs: Optional[Dict[str, Any]] = {}, + ) -> Tuple[Dict[str, Any], int]: + paired_messages = messages + [{"role": "assistant", "content": ""}] + prompt_ids, _ = template.encode_oneturn( + tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools + ) + prompt_length = len(prompt_ids) + inputs = torch.tensor([prompt_ids], device=model.device) + + do_sample = input_kwargs.pop("do_sample", None) + temperature = input_kwargs.pop("temperature", None) + top_p = input_kwargs.pop("top_p", None) + top_k = input_kwargs.pop("top_k", None) + 
num_return_sequences = input_kwargs.pop("num_return_sequences", None) + repetition_penalty = input_kwargs.pop("repetition_penalty", None) + max_length = input_kwargs.pop("max_length", None) + max_new_tokens = input_kwargs.pop("max_new_tokens", None) + + generating_args.update( + dict( + do_sample=do_sample if do_sample is not None else generating_args["do_sample"], + temperature=temperature or generating_args["temperature"], + top_p=top_p or generating_args["top_p"], + top_k=top_k or generating_args["top_k"], + num_return_sequences=num_return_sequences or 1, + repetition_penalty=repetition_penalty or generating_args["repetition_penalty"], + eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids, + pad_token_id=tokenizer.pad_token_id, + ) + ) + + if isinstance(num_return_sequences, int) and num_return_sequences > 1: + generating_args["do_sample"] = True + + if max_length: + generating_args.pop("max_new_tokens", None) + generating_args["max_length"] = max_length + + if max_new_tokens: + generating_args.pop("max_length", None) + generating_args["max_new_tokens"] = max_new_tokens + + gen_kwargs = dict( + inputs=inputs, + generation_config=GenerationConfig(**generating_args), + logits_processor=get_logits_processor(), + ) + + return gen_kwargs, prompt_length + + @staticmethod + @torch.inference_mode() + def _chat( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + template: "Template", + generating_args: Dict[str, Any], + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + input_kwargs: Optional[Dict[str, Any]] = {}, + ) -> List["Response"]: + gen_kwargs, prompt_length = HuggingfaceEngine._process_args( + model, tokenizer, template, generating_args, messages, system, tools, input_kwargs + ) + generate_output = model.generate(**gen_kwargs) + response_ids = generate_output[:, prompt_length:] + response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) + results = [] + for i in range(len(response)): + eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero() + response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i]) + results.append( + Response( + response_text=response[i], + response_length=response_length, + prompt_length=prompt_length, + finish_reason="stop" if len(eos_index) else "length", + ) + ) + + return results + + @staticmethod + @torch.inference_mode() + def _stream_chat( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + template: "Template", + generating_args: Dict[str, Any], + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + input_kwargs: Optional[Dict[str, Any]] = {}, + ) -> Callable[[], str]: + gen_kwargs, _ = HuggingfaceEngine._process_args( + model, tokenizer, template, generating_args, messages, system, tools, input_kwargs + ) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + gen_kwargs["streamer"] = streamer + thread = Thread(target=model.generate, kwargs=gen_kwargs) + thread.start() + + def stream(): + try: + return streamer.__next__() + except StopIteration: + raise StopAsyncIteration() + + return stream + + @staticmethod + @torch.inference_mode() + def _get_scores( + model: "PreTrainedModelWrapper", + tokenizer: "PreTrainedTokenizer", + batch_input: List[str], + input_kwargs: Optional[Dict[str, Any]] = {}, + ) -> List[float]: + max_length = input_kwargs.pop("max_length", None) + device = 
getattr(model.pretrained_model, "device", "cuda") + inputs = tokenizer( + batch_input, + padding=True, + truncation=True, + max_length=max_length or getattr(model.config, "max_position_embeddings", 1024), + return_tensors="pt", + add_special_tokens=True, + ).to(device) + + input_ids: torch.Tensor = inputs["input_ids"] + _, _, values = model(**inputs, output_hidden_states=True, return_dict=True) + + if getattr(model.config, "model_type", None) == "chatglm": + values = torch.transpose(values, 0, 1) + + scores = [] + for i in range(input_ids.size(0)): + end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero() + end_index = end_indexes[-1].item() if len(end_indexes) else 0 + scores.append(values[i, end_index].nan_to_num().item()) + + return scores + + async def chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> List["Response"]: + if not self.can_generate: + raise ValueError("The current model does not support `chat`.") + + loop = asyncio.get_running_loop() + input_args = ( + self.model, + self.tokenizer, + self.template, + self.generating_args, + messages, + system, + tools, + input_kwargs, + ) + async with self._semaphore: + with concurrent.futures.ThreadPoolExecutor() as pool: + return await loop.run_in_executor(pool, self._chat, *input_args) + + async def stream_chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + if not self.can_generate: + raise ValueError("The current model does not support `stream_chat`.") + + loop = asyncio.get_running_loop() + input_args = ( + self.model, + self.tokenizer, + self.template, + self.generating_args, + messages, + system, + tools, + input_kwargs, + ) + async with self._semaphore: + with concurrent.futures.ThreadPoolExecutor() as pool: + stream = self._stream_chat(*input_args) + while True: + try: + yield await loop.run_in_executor(pool, stream) + except StopAsyncIteration: + break + + async def get_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: + if self.can_generate: + raise ValueError("Cannot get scores using an auto-regressive model.") + + loop = asyncio.get_running_loop() + input_args = (self.model, self.tokenizer, batch_input, input_kwargs) + async with self._semaphore: + with concurrent.futures.ThreadPoolExecutor() as pool: + return await loop.run_in_executor(pool, self._get_scores, *input_args) diff --git a/src/llmtuner/chat/vllm_engine.py b/src/llmtuner/chat/vllm_engine.py new file mode 100644 index 00000000..258accb6 --- /dev/null +++ b/src/llmtuner/chat/vllm_engine.py @@ -0,0 +1,144 @@ +import uuid +from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence + +from transformers.utils.versions import require_version + +from ..data import get_template_and_fix_tokenizer +from ..extras.misc import get_device_count +from ..extras.packages import is_vllm_available +from ..model import load_tokenizer +from .base_engine import BaseEngine, Response + + +if is_vllm_available(): + from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams + +if TYPE_CHECKING: + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +class VllmEngine(BaseEngine): + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + 
) -> None: + require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3") + self.can_generate = finetuning_args.stage == "sft" + engine_args = AsyncEngineArgs( + model=model_args.model_name_or_path, + trust_remote_code=True, + max_model_len=model_args.vllm_maxlen, + tensor_parallel_size=get_device_count(), + disable_log_stats=True, + disable_log_requests=True, + ) + self.model = AsyncLLMEngine.from_engine_args(engine_args) + self.tokenizer = load_tokenizer(model_args) + self.tokenizer.padding_side = "left" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) + self.generating_args = generating_args.to_dict() + + async def _generate( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> AsyncIterator["RequestOutput"]: + request_id = "chatcmpl-{}".format(uuid.uuid4().hex) + paired_messages = messages + [{"role": "assistant", "content": ""}] + prompt_ids, _ = self.template.encode_oneturn( + tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools + ) + prompt_length = len(prompt_ids) + + temperature = input_kwargs.pop("temperature", None) + top_p = input_kwargs.pop("top_p", None) + top_k = input_kwargs.pop("top_k", None) + num_return_sequences = input_kwargs.pop("num_return_sequences", None) + repetition_penalty = input_kwargs.pop("repetition_penalty", None) + max_length = input_kwargs.pop("max_length", None) + max_new_tokens = input_kwargs.pop("max_new_tokens", None) + + generating_args = self.generating_args.copy() + generating_args.update( + dict( + temperature=temperature or generating_args["temperature"], + top_p=top_p or generating_args["top_p"], + top_k=top_k or generating_args["top_k"], + num_return_sequences=num_return_sequences or 1, + repetition_penalty=repetition_penalty or generating_args["repetition_penalty"], + ) + ) + + if max_length: + generating_args["max_new_tokens"] = max_length - prompt_length + + if max_new_tokens: + generating_args["max_new_tokens"] = max_new_tokens + + sampling_params = SamplingParams( + n=generating_args["num_return_sequences"], + repetition_penalty=generating_args["repetition_penalty"], + temperature=generating_args["temperature"], + top_p=generating_args["top_p"], + top_k=generating_args["top_k"], + use_beam_search=generating_args["num_beams"] > 1, + length_penalty=generating_args["length_penalty"], + stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, + max_tokens=generating_args["max_new_tokens"], + skip_special_tokens=True, + ) + result_generator = self.model.generate( + prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids + ) + return result_generator + + async def chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> List["Response"]: + final_output = None + generator = await self._generate(messages, system, tools, **input_kwargs) + async for request_output in generator: + final_output = request_output + + results = [] + for output in final_output.outputs: + results.append( + Response( + response_text=output.text, + response_length=len(output.token_ids), + prompt_length=len(final_output.prompt_token_ids), + finish_reason=output.finish_reason, + ) + ) + + return results + + async def stream_chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> 
AsyncGenerator[str, None]: + generated_text = "" + generator = await self._generate(messages, system, tools, **input_kwargs) + async for result in generator: + delta_text = result.outputs[0].text[len(generated_text) :] + generated_text = result.outputs[0].text + yield delta_text + + async def get_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: + raise NotImplementedError("vLLM engine does not support get_scores.") diff --git a/src/llmtuner/data/__init__.py b/src/llmtuner/data/__init__.py index aee03970..80dbf5ff 100644 --- a/src/llmtuner/data/__init__.py +++ b/src/llmtuner/data/__init__.py @@ -1,6 +1,6 @@ from .loader import get_dataset -from .template import get_template_and_fix_tokenizer, templates +from .template import Template, get_template_and_fix_tokenizer, templates from .utils import Role, split_dataset -__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"] +__all__ = ["get_dataset", "Template", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"] diff --git a/src/llmtuner/data/formatter.py b/src/llmtuner/data/formatter.py index a9632931..0cd3d6c1 100644 --- a/src/llmtuner/data/formatter.py +++ b/src/llmtuner/data/formatter.py @@ -2,7 +2,7 @@ import json import re from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] @@ -72,7 +72,7 @@ def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]: @dataclass class Formatter(ABC): slots: SLOTS = field(default_factory=list) - tool_format: Literal["default"] = "default" + tool_format: Optional[Literal["default"]] = None @abstractmethod def apply(self, **kwargs) -> SLOTS: ... 
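A small self-contained sketch of the delta computation used in `VllmEngine.stream_chat` above: vLLM's `RequestOutput` reports the cumulative text at every step, so only the newly generated suffix is yielded to the caller. The chunk strings below are made up for illustration:

```python
# Cumulative outputs as a vLLM-style generator might report them (made-up values).
chunks = ["Hel", "Hello", "Hello, wor", "Hello, world!"]

generated_text = ""
for cumulative in chunks:
    delta = cumulative[len(generated_text):]  # only the newly generated suffix
    generated_text = cumulative
    print(delta, end="")
print()  # prints "Hello, world!" exactly once, with no repeated prefixes
```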
@@ -83,12 +83,30 @@ class Formatter(ABC): @dataclass class EmptyFormatter(Formatter): + def __post_init__(self): + has_placeholder = False + for slot in filter(lambda s: isinstance(s, str), self.slots): + if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot): + has_placeholder = True + + if has_placeholder: + raise ValueError("Empty formatter should not contain any placeholder.") + def apply(self, **kwargs) -> SLOTS: return self.slots @dataclass class StringFormatter(Formatter): + def __post_init__(self): + has_placeholder = False + for slot in filter(lambda s: isinstance(s, str), self.slots): + if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot): + has_placeholder = True + + if not has_placeholder: + raise ValueError("A placeholder is required in the string formatter.") + def apply(self, **kwargs) -> SLOTS: elements = [] for slot in self.slots: @@ -109,6 +127,17 @@ class StringFormatter(Formatter): @dataclass class FunctionFormatter(Formatter): + def __post_init__(self): + has_name, has_args = False, False + for slot in filter(lambda s: isinstance(s, str), self.slots): + if "{{name}}" in slot: + has_name = True + if "{{arguments}}" in slot: + has_args = True + + if not has_name or not has_args: + raise ValueError("Name and arguments placeholders are required in the function formatter.") + def apply(self, **kwargs) -> SLOTS: content = kwargs.pop("content") try: @@ -133,6 +162,10 @@ class FunctionFormatter(Formatter): @dataclass class ToolFormatter(Formatter): + def __post_init__(self): + if self.tool_format is None: + raise ValueError("Tool format was not found.") + def apply(self, **kwargs) -> SLOTS: content = kwargs.pop("content") try: diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index 9db10b04..f51369bc 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -44,7 +44,7 @@ def load_single_dataset( elif dataset_attr.load_from == "file": data_files = [] - local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) if os.path.isdir(local_path): # is directory for file_name in os.listdir(local_path): data_files.append(os.path.join(local_path, file_name)) diff --git a/src/llmtuner/data/parser.py b/src/llmtuner/data/parser.py index 2aea5842..861396a0 100644 --- a/src/llmtuner/data/parser.py +++ b/src/llmtuner/data/parser.py @@ -19,13 +19,13 @@ class DatasetAttr: """ basic configs """ load_from: Literal["hf_hub", "ms_hub", "script", "file"] - dataset_name: Optional[str] = None + dataset_name: str """ extra configs """ file_sha1: Optional[str] = None subset: Optional[str] = None folder: Optional[str] = None - ranking: Optional[bool] = False - formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca" + ranking: bool = False + formatting: Literal["alpaca", "sharegpt"] = "alpaca" """ columns """ system: Optional[str] = None """ columns for the alpaca format """ diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 9d93487f..21d4b4c6 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -157,6 +157,12 @@ def get_current_device() -> torch.device: def get_device_count() -> int: + r""" + Gets the number of available GPU devices. 
+ """ + if not torch.cuda.is_available(): + return 0 + return torch.cuda.device_count() diff --git a/src/llmtuner/extras/packages.py b/src/llmtuner/extras/packages.py index 29323885..1118b343 100644 --- a/src/llmtuner/extras/packages.py +++ b/src/llmtuner/extras/packages.py @@ -51,3 +51,7 @@ def is_unsloth_available(): def is_uvicorn_available(): return _is_package_available("uvicorn") + + +def is_vllm_available(): + return _is_package_available("vllm") diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index 539e5489..bfc7f3e8 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -16,35 +16,35 @@ class DataArguments: default=None, metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}, ) - dataset_dir: Optional[str] = field( + dataset_dir: str = field( default="data", metadata={"help": "Path to the folder containing the datasets."}, ) - split: Optional[str] = field( + split: str = field( default="train", metadata={"help": "Which dataset split to use for training and evaluation."}, ) - cutoff_len: Optional[int] = field( + cutoff_len: int = field( default=1024, metadata={"help": "The cutoff length of the model inputs after tokenization."}, ) - reserved_label_len: Optional[int] = field( + reserved_label_len: int = field( default=1, metadata={"help": "The minimum cutoff length reserved for label after tokenization."}, ) - train_on_prompt: Optional[bool] = field( + train_on_prompt: bool = field( default=False, metadata={"help": "Whether to disable the mask on the prompt or not."}, ) - streaming: Optional[bool] = field( + streaming: bool = field( default=False, metadata={"help": "Enable dataset streaming."}, ) - buffer_size: Optional[int] = field( + buffer_size: int = field( default=16384, metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}, ) - mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field( + mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field( default="concat", metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}, ) @@ -52,13 +52,13 @@ class DataArguments: default=None, metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}, ) - overwrite_cache: Optional[bool] = field( + overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets."}, ) preprocessing_num_workers: Optional[int] = field( default=None, - metadata={"help": "The number of processes to use for the preprocessing."}, + metadata={"help": "The number of processes to use for the pre-processing."}, ) max_samples: Optional[int] = field( default=None, @@ -68,23 +68,23 @@ class DataArguments: default=None, metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}, ) - ignore_pad_token_for_loss: Optional[bool] = field( + ignore_pad_token_for_loss: bool = field( default=True, metadata={ "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation." 
}, ) - val_size: Optional[float] = field( - default=0, + val_size: float = field( + default=0.0, metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}, ) - sft_packing: Optional[bool] = field( + sft_packing: bool = field( default=False, metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}, ) cache_path: Optional[str] = field( default=None, - metadata={"help": "Path to save or load the preprocessed datasets."}, + metadata={"help": "Path to save or load the pre-processed datasets."}, ) def __post_init__(self): diff --git a/src/llmtuner/hparams/evaluation_args.py b/src/llmtuner/hparams/evaluation_args.py index 4257f47b..5a05f6f6 100644 --- a/src/llmtuner/hparams/evaluation_args.py +++ b/src/llmtuner/hparams/evaluation_args.py @@ -14,23 +14,23 @@ class EvaluationArguments: task: str = field( metadata={"help": "Name of the evaluation task."}, ) - task_dir: Optional[str] = field( + task_dir: str = field( default="evaluation", metadata={"help": "Path to the folder containing the evaluation datasets."}, ) - batch_size: Optional[int] = field( + batch_size: int = field( default=4, metadata={"help": "The batch size per GPU for evaluation."}, ) - seed: Optional[int] = field( + seed: int = field( default=42, metadata={"help": "Random seed to be used with data loaders."}, ) - lang: Optional[Literal["en", "zh"]] = field( + lang: Literal["en", "zh"] = field( default="en", metadata={"help": "Language used at evaluation."}, ) - n_shot: Optional[int] = field( + n_shot: int = field( default=5, metadata={"help": "Number of examplars for few-shot learning."}, ) @@ -38,7 +38,7 @@ class EvaluationArguments: default=None, metadata={"help": "Path to save the evaluation results."}, ) - download_mode: Optional[DownloadMode] = field( + download_mode: DownloadMode = field( default=DownloadMode.REUSE_DATASET_IF_EXISTS, metadata={"help": "Download mode used for the evaluation datasets."}, ) diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py index 59ae6948..be950c30 100644 --- a/src/llmtuner/hparams/finetuning_args.py +++ b/src/llmtuner/hparams/finetuning_args.py @@ -22,7 +22,7 @@ class FreezeArguments: Others choices: the same as LLaMA.""" }, ) - num_layer_trainable: Optional[int] = field( + num_layer_trainable: int = field( default=3, metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}, ) @@ -44,11 +44,11 @@ class LoraArguments: default=None, metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}, ) - lora_dropout: Optional[float] = field( + lora_dropout: float = field( default=0.0, metadata={"help": "Dropout rate for the LoRA fine-tuning."}, ) - lora_rank: Optional[int] = field( + lora_rank: int = field( default=8, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}, ) @@ -66,18 +66,19 @@ class LoraArguments: Others choices: the same as LLaMA.""" }, ) - lora_bf16_mode: Optional[bool] = field( + lora_bf16_mode: bool = field( default=False, metadata={"help": "Whether or not to train lora adapters in bf16 precision."}, ) - use_rslora: Optional[bool] = field( + use_rslora: bool = field( default=False, metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."}, ) - use_dora: Optional[bool] = field( - default=False, metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."} + use_dora: bool = field( + default=False, + metadata={"help": "Whether or 
not to use the weight-decomposed lora method (DoRA)."}, ) - create_new_adapter: Optional[bool] = field( + create_new_adapter: bool = field( default=False, metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}, ) @@ -89,23 +90,23 @@ class RLHFArguments: Arguments pertaining to the PPO and DPO training. """ - dpo_beta: Optional[float] = field( + dpo_beta: float = field( default=0.1, metadata={"help": "The beta parameter for the DPO loss."}, ) - dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto_pair"]] = field( + dpo_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = field( default="sigmoid", metadata={"help": "The type of DPO loss to use."}, ) - dpo_ftx: Optional[float] = field( - default=0, + dpo_ftx: float = field( + default=0.0, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}, ) - ppo_buffer_size: Optional[int] = field( + ppo_buffer_size: int = field( default=1, metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}, ) - ppo_epochs: Optional[int] = field( + ppo_epochs: int = field( default=4, metadata={"help": "The number of epochs to perform in a PPO optimization step."}, ) @@ -113,15 +114,15 @@ class RLHFArguments: default=None, metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'}, ) - ppo_score_norm: Optional[bool] = field( + ppo_score_norm: bool = field( default=False, metadata={"help": "Use score normalization in PPO training."}, ) - ppo_target: Optional[float] = field( + ppo_target: float = field( default=6.0, metadata={"help": "Target KL value for adaptive KL control in PPO training."}, ) - ppo_whiten_rewards: Optional[bool] = field( + ppo_whiten_rewards: bool = field( default=False, metadata={"help": "Whiten the rewards before compute advantages in PPO training."}, ) @@ -149,7 +150,7 @@ class RLHFArguments: default=None, metadata={"help": "The number of bits to quantize the reward model."}, ) - reward_model_type: Optional[Literal["lora", "full", "api"]] = field( + reward_model_type: Literal["lora", "full", "api"] = field( default="lora", metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}, ) @@ -161,19 +162,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments): Arguments pertaining to which techniques we are going to fine-tuning with. 
""" - stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( + stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field( default="sft", metadata={"help": "Which stage will be performed in training."}, ) - finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field( + finetuning_type: Literal["lora", "freeze", "full"] = field( default="lora", metadata={"help": "Which fine-tuning method to use."}, ) - use_llama_pro: Optional[bool] = field( + use_llama_pro: bool = field( default=False, metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."}, ) - plot_loss: Optional[bool] = field( + plot_loss: bool = field( default=False, metadata={"help": "Whether or not to save the training loss curves."}, ) diff --git a/src/llmtuner/hparams/generating_args.py b/src/llmtuner/hparams/generating_args.py index 06b5dfc3..70dabb3e 100644 --- a/src/llmtuner/hparams/generating_args.py +++ b/src/llmtuner/hparams/generating_args.py @@ -1,5 +1,5 @@ from dataclasses import asdict, dataclass, field -from typing import Any, Dict, Optional +from typing import Any, Dict @dataclass @@ -8,41 +8,41 @@ class GeneratingArguments: Arguments pertaining to specify the decoding parameters. """ - do_sample: Optional[bool] = field( + do_sample: bool = field( default=True, metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}, ) - temperature: Optional[float] = field( + temperature: float = field( default=0.95, metadata={"help": "The value used to modulate the next token probabilities."}, ) - top_p: Optional[float] = field( + top_p: float = field( default=0.7, metadata={ "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept." }, ) - top_k: Optional[int] = field( + top_k: int = field( default=50, metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}, ) - num_beams: Optional[int] = field( + num_beams: int = field( default=1, metadata={"help": "Number of beams for beam search. 1 means no beam search."}, ) - max_length: Optional[int] = field( + max_length: int = field( default=512, metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}, ) - max_new_tokens: Optional[int] = field( + max_new_tokens: int = field( default=512, metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}, ) - repetition_penalty: Optional[float] = field( + repetition_penalty: float = field( default=1.0, metadata={"help": "The parameter for repetition penalty. 
1.0 means no penalty."}, ) - length_penalty: Optional[float] = field( + length_penalty: float = field( default=1.0, metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}, ) diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index f3972f66..2798bda7 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -85,6 +85,10 @@ class ModelArguments: default="hf", metadata={"help": "Backend engine used at inference."}, ) + vllm_maxlen: int = field( + default=2048, + metadata={"help": "Maximum input length of the vLLM engine."}, + ) hf_hub_token: Optional[str] = field( default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."}, diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py index cad08b17..17578558 100644 --- a/src/llmtuner/hparams/parser.py +++ b/src/llmtuner/hparams/parser.py @@ -9,8 +9,8 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments from transformers.trainer_utils import get_last_checkpoint from ..extras.logging import get_logger -from ..extras.packages import is_unsloth_available from ..extras.misc import check_dependencies +from ..extras.packages import is_unsloth_available from .data_args import DataArguments from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments @@ -59,6 +59,9 @@ def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None: def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None: + if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora": + raise ValueError("Adapter is only valid for the LoRA method.") + if model_args.quantization_bit is not None: if finetuning_args.finetuning_type != "lora": raise ValueError("Quantization is only compatible with the LoRA method.") @@ -69,8 +72,18 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1: raise ValueError("Quantized model only accepts a single adapter. Merge them first.") - if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora": - raise ValueError("Adapter is only valid for the LoRA method.") + if model_args.infer_backend == "vllm": + if finetuning_args.stage != "sft": + raise ValueError("vLLM engine only supports auto-regressive models.") + + if model_args.adapter_name_or_path is not None: + raise ValueError("vLLM engine does not support LoRA adapters. 
Merge them first.") + + if model_args.quantization_bit is not None: + raise ValueError("vLLM engine does not support quantization.") + + if model_args.rope_scaling is not None: + raise ValueError("vLLM engine does not support RoPE scaling.") def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: diff --git a/src/llmtuner/model/__init__.py b/src/llmtuner/model/__init__.py index 933ffc5b..bb7c4db9 100644 --- a/src/llmtuner/model/__init__.py +++ b/src/llmtuner/model/__init__.py @@ -1,11 +1,10 @@ from .loader import load_model, load_model_and_tokenizer, load_tokenizer -from .utils import dispatch_model, load_valuehead_params +from .utils import load_valuehead_params __all__ = [ "load_model", "load_model_and_tokenizer", "load_tokenizer", - "dispatch_model", "load_valuehead_params", ] diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py index aa8a9a63..4a4ecf2e 100644 --- a/src/llmtuner/model/utils.py +++ b/src/llmtuner/model/utils.py @@ -1,4 +1,3 @@ -import inspect from typing import TYPE_CHECKING, Dict, List import torch @@ -7,7 +6,6 @@ from transformers.utils import cached_file from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..extras.logging import get_logger -from ..extras.misc import get_current_device if TYPE_CHECKING: @@ -19,36 +17,6 @@ if TYPE_CHECKING: logger = get_logger(__name__) -def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": - r""" - Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available. - Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570 - """ - if getattr(model, "quantization_method", None): # already set on current device - return model - - if ( - torch.cuda.device_count() > 1 - and isinstance(model, PreTrainedModel) - and model._no_split_modules is not None - and model.config.model_type != "chatglm" - ): - from accelerate import dispatch_model - from accelerate.utils import get_balanced_memory, infer_auto_device_map - - kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")} - max_memory = get_balanced_memory(model, **kwargs) - # Make sure tied weights are tied before creating the device map. - model.tie_weights() - device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs) - device_map_kwargs = {"device_map": device_map, "offload_dir": "offload"} - if "skip_keys" in inspect.signature(dispatch_model).parameters: - device_map_kwargs["skip_keys"] = model._skip_keys_device_placement - return dispatch_model(model, **device_map_kwargs) - else: - return model.to(device=get_current_device()) - - def find_all_linear_modules(model: "PreTrainedModel") -> List[str]: r""" Finds all available modules to apply lora. 
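A hedged illustration of the backend guards added to `_verify_model_args` earlier in this diff: assuming `get_infer_args` routes through that check (the adapter path below is a placeholder), an incompatible vLLM configuration should fail fast instead of reaching engine construction:

```python
from llmtuner.hparams import get_infer_args

try:
    get_infer_args(
        dict(
            model_name_or_path="meta-llama/Llama-2-7b-hf",
            template="default",
            infer_backend="vllm",
            adapter_name_or_path="path_to_checkpoint",  # LoRA adapters are rejected by the vLLM backend
        )
    )
except ValueError as err:
    # Expected per the check above: "vLLM engine does not support LoRA adapters. Merge them first."
    print(err)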
diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py
index 255f0fd7..b71d434b 100644
--- a/src/llmtuner/webui/chatter.py
+++ b/src/llmtuner/webui/chatter.py
@@ -1,4 +1,5 @@
 import json
+import os
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
 
 import gradio as gr
@@ -7,12 +8,12 @@ from gradio.components import Component  # cannot use TYPE_CHECKING here
 
 from ..chat import ChatModel
 from ..data import Role
 from ..extras.misc import torch_gc
-from ..hparams import GeneratingArguments
 from .common import get_save_dir
 from .locales import ALERTS
 
 
 if TYPE_CHECKING:
+    from ..chat import BaseEngine
     from .manager import Manager
@@ -22,29 +23,19 @@ class WebChatModel(ChatModel):
     ) -> None:
         self.manager = manager
         self.demo_mode = demo_mode
-        self.model = None
-        self.tokenizer = None
-        self.generating_args = GeneratingArguments()
+        self.engine: Optional["BaseEngine"] = None
 
         if not lazy_init:  # read arguments from command line
             super().__init__()
 
-        if demo_mode:  # load demo_config.json if exists
-            import json
-
-            try:
-                with open("demo_config.json", "r", encoding="utf-8") as f:
-                    args = json.load(f)
-                assert args.get("model_name_or_path", None) and args.get("template", None)
-                super().__init__(args)
-            except AssertionError:
-                print("Please provided model name and template in `demo_config.json`.")
-            except Exception:
-                print("Cannot find `demo_config.json` at current directory.")
+        if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"):  # load demo model
+            model_name_or_path = os.environ.get("DEMO_MODEL")
+            template = os.environ.get("DEMO_TEMPLATE")
+            super().__init__(dict(model_name_or_path=model_name_or_path, template=template))
 
     @property
     def loaded(self) -> bool:
-        return self.model is not None
+        return self.engine is not None
 
     def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
         get = lambda name: data[self.manager.get_elem_by_name(name)]
@@ -98,8 +89,7 @@ class WebChatModel(ChatModel):
             return
 
         yield ALERTS["info_unloading"][lang]
-        self.model = None
-        self.tokenizer = None
+        self.engine = None
         torch_gc()
         yield ALERTS["info_unloaded"][lang]
@@ -123,7 +113,7 @@ class WebChatModel(ChatModel):
         ):
             response += new_text
             if tools:
-                result = self.template.format_tools.extract(response)
+                result = self.engine.template.format_tools.extract(response)
             else:
                 result = response
diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py
index 7713bda2..bd900c93 100644
--- a/src/llmtuner/webui/components/chatbot.py
+++ b/src/llmtuner/webui/components/chatbot.py
@@ -28,10 +28,9 @@ def create_chat_box(
             submit_btn = gr.Button(variant="primary")
 
         with gr.Column(scale=1):
-            gen_kwargs = engine.chatter.generating_args
-            max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1)
-            top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
-            temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
+            max_new_tokens = gr.Slider(8, 4096, value=512, step=1)
+            top_p = gr.Slider(0.01, 1.0, value=0.7, step=0.01)
+            temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
             clear_btn = gr.Button()
 
     tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])
diff --git a/src/llmtuner/webui/engine.py b/src/llmtuner/webui/engine.py
index 78a22a21..fe8a02ae 100644
--- a/src/llmtuner/webui/engine.py
+++ b/src/llmtuner/webui/engine.py
@@ -16,8 +16,8 @@ class Engine:
         self.demo_mode = demo_mode
         self.pure_chat = pure_chat
         self.manager = Manager()
-        self.runner = Runner(self.manager, demo_mode=demo_mode)
-        self.chatter = WebChatModel(manager=self.manager, demo_mode=demo_mode, lazy_init=(not pure_chat))
+        self.runner = Runner(self.manager, demo_mode)
+        self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat))
 
     def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
         return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}
diff --git a/tests/test_throughput.py b/tests/test_throughput.py
new file mode 100644
index 00000000..e8048910
--- /dev/null
+++ b/tests/test_throughput.py
@@ -0,0 +1,30 @@
+import os
+import time
+
+from openai import OpenAI
+from transformers.utils.versions import require_version
+
+
+require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
+
+
+def main():
+    client = OpenAI(
+        api_key="0",
+        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+    )
+    messages = [{"role": "user", "content": "Write a long essay about environment protection as long as possible."}]
+    num_tokens = 0
+    start_time = time.time()
+    for _ in range(8):
+        result = client.chat.completions.create(messages=messages, model="test")
+        num_tokens += result.usage.completion_tokens
+
+    elapsed_time = time.time() - start_time
+    print("Throughput: {:.2f} tokens/s".format(num_tokens / elapsed_time))
+    # --infer_backend hf: 27.22 tokens/s (1.0x)
+    # --infer_backend vllm: 73.03 tokens/s (2.7x)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/test_toolcall.py b/tests/test_toolcall.py
similarity index 92%
rename from scripts/test_toolcall.py
rename to tests/test_toolcall.py
index e3351e3e..a54a0053 100644
--- a/scripts/test_toolcall.py
+++ b/tests/test_toolcall.py
@@ -1,4 +1,5 @@
 import json
+import os
 from typing import Sequence
 
 from openai import OpenAI
@@ -17,13 +18,10 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
     return total_score / total_hour
 
 
-tool_map = {"calculate_gpa": calculate_gpa}
-
-
-if __name__ == "__main__":
+def main():
     client = OpenAI(
         api_key="0",
-        base_url="http://localhost:8000/v1",
+        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
     )
     tools = [
         {
@@ -42,6 +40,8 @@
             },
         }
     ]
+    tool_map = {"calculate_gpa": calculate_gpa}
+
     messages = []
     messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
     result = client.chat.completions.create(messages=messages, model="test", tools=tools)
@@ -55,3 +55,7 @@
     result = client.chat.completions.create(messages=messages, model="test", tools=tools)
     print(result.choices[0].message.content)
     # Based on your grades and credit hours, your calculated Grade Point Average (GPA) is 3.4166666666666665.
+
+
+if __name__ == "__main__":
+    main()
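The new `tests/test_throughput.py` above measures aggregate tokens per second over eight sequential requests. As an optional companion (not part of this patch), the sketch below probes time-to-first-token over the same OpenAI-compatible endpoint, reusing the `API_PORT` convention and the `model="test"` placeholder from the tests; it assumes an API server (for example `src/api_demo.py`) is already running.

```python
# Companion sketch, not part of this patch: time-to-first-token via streaming.
import os
import time

from openai import OpenAI

client = OpenAI(
    api_key="0",
    base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
)
messages = [{"role": "user", "content": "Write a long essay about environment protection as long as possible."}]

start_time = time.time()
stream = client.chat.completions.create(messages=messages, model="test", stream=True)

first_token_latency = None
num_chunks = 0
for chunk in stream:  # each chunk carries an incremental content delta
    if first_token_latency is None and chunk.choices and chunk.choices[0].delta.content:
        first_token_latency = time.time() - start_time
    num_chunks += 1

if first_token_latency is not None:
    print("Time to first token: {:.2f} s ({} chunks streamed)".format(first_token_latency, num_chunks))
else:
    print("No content received from the server.")
```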