diff --git a/src/llmtuner/chat/vllm_engine.py b/src/llmtuner/chat/vllm_engine.py
index e924ef6e..67a19b68 100644
--- a/src/llmtuner/chat/vllm_engine.py
+++ b/src/llmtuner/chat/vllm_engine.py
@@ -2,9 +2,9 @@ import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
 
 from ..data import get_template_and_fix_tokenizer
-from ..extras.misc import get_device_count
+from ..extras.misc import get_device_count, infer_optim_dtype
 from ..extras.packages import is_vllm_available
-from ..model import load_tokenizer
+from ..model import load_config, load_tokenizer
 from .base_engine import BaseEngine, Response
 
 
@@ -23,10 +23,20 @@ class VllmEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        config = load_config(model_args)  # may download model from ms hub
+        load_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+
         self.can_generate = finetuning_args.stage == "sft"
+        self.tokenizer = load_tokenizer(model_args)
+        self.tokenizer.padding_side = "left"
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
+        self.generating_args = generating_args.to_dict()
+
         engine_args = AsyncEngineArgs(
             model=model_args.model_name_or_path,
             trust_remote_code=True,
+            download_dir=model_args.cache_dir,
+            dtype=str(load_dtype).split(".")[-1],
             max_model_len=model_args.vllm_maxlen,
             tensor_parallel_size=get_device_count() or 1,
             gpu_memory_utilization=model_args.vllm_gpu_util,
@@ -35,10 +45,6 @@ class VllmEngine(BaseEngine):
             enforce_eager=model_args.vllm_enforce_eager,
         )
         self.model = AsyncLLMEngine.from_engine_args(engine_args)
-        self.tokenizer = load_tokenizer(model_args)
-        self.tokenizer.padding_side = "left"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
-        self.generating_args = generating_args.to_dict()
 
     async def _generate(
         self,
diff --git a/src/llmtuner/model/__init__.py b/src/llmtuner/model/__init__.py
index 1eaf4271..e0b1c9cd 100644
--- a/src/llmtuner/model/__init__.py
+++ b/src/llmtuner/model/__init__.py
@@ -1,8 +1,9 @@
-from .loader import load_model, load_tokenizer
+from .loader import load_config, load_model, load_tokenizer
 from .utils import find_all_linear_modules, load_valuehead_params
 
 
 __all__ = [
+    "load_config",
     "load_model",
     "load_tokenizer",
     "load_valuehead_params",
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 4935dd52..57f5a763 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -12,7 +12,7 @@ from .utils import load_valuehead_params, register_autoclass
 
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, PreTrainedTokenizer
+    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
 
     from ..hparams import FinetuningArguments, ModelArguments
 
@@ -21,6 +21,11 @@ logger = get_logger(__name__)
 
 
 def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
+    r"""
+    Gets arguments to load config/tokenizer/model.
+
+    Note: including inplace operation of model_args.
+    """
     model_args.model_name_or_path = try_download_model_from_ms(model_args)
     return {
         "trust_remote_code": True,
@@ -32,9 +37,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
 
 def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
     r"""
-    Loads pretrained tokenizer. Must before load_model.
-
-    Note: including inplace operation of model_args.
+    Loads pretrained tokenizer.
     """
     init_kwargs = _get_init_kwargs(model_args)
     try:
@@ -57,6 +60,14 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
     return tokenizer
 
 
+def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
+    r"""
+    Loads model config.
+    """
+    init_kwargs = _get_init_kwargs(model_args)
+    return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
+
+
 def load_model(
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
@@ -65,10 +76,10 @@ def load_model(
     add_valuehead: bool = False,
 ) -> "PreTrainedModel":
     r"""
-    Loads pretrained model. Must after load_tokenizer.
+    Loads pretrained model.
     """
     init_kwargs = _get_init_kwargs(model_args)
-    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
+    config = load_config(model_args)
     patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
 
     model = None