diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 8a067754..e193704a 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Opt
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras.logging import get_logger
-from ..extras.misc import get_device_count, infer_optim_dtype
+from ..extras.misc import get_device_count
 from ..extras.packages import is_vllm_available
 from ..model import load_config, load_tokenizer
 from ..model.utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
@@ -35,8 +35,6 @@ class VllmEngine(BaseEngine):
         generating_args: "GeneratingArguments",
     ) -> None:
         config = load_config(model_args)  # may download model from ms hub
-        infer_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
-        infer_dtype = str(infer_dtype).split(".")[-1]
 
         self.can_generate = finetuning_args.stage == "sft"
         tokenizer_module = load_tokenizer(model_args)
@@ -50,7 +48,7 @@ class VllmEngine(BaseEngine):
             "model": model_args.model_name_or_path,
             "trust_remote_code": True,
             "download_dir": model_args.cache_dir,
-            "dtype": infer_dtype,
+            "dtype": model_args.vllm_dtype,
             "max_model_len": model_args.vllm_maxlen,
             "tensor_parallel_size": get_device_count() or 1,
             "gpu_memory_utilization": model_args.vllm_gpu_util,
@@ -70,7 +68,6 @@ class VllmEngine(BaseEngine):
             engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
             engine_args["image_feature_size"] = self.image_feature_size
         if getattr(config, "is_yi_vl_derived_model", None):
-            # bug in vllm 0.4.2, see: https://github.com/vllm-project/vllm/pull/4828
             import vllm.model_executor.models.llava
 
             logger.info("Detected Yi-VL model, applying projector patch.")
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index a3b5b2a6..0434f426 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -125,6 +125,10 @@ class ModelArguments:
         default=8,
         metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
     )
+    vllm_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
+        default="auto",
+        metadata={"help": "Data type for model weights and activations in the vLLM engine."},
+    )
     offload_folder: str = field(
         default="offload",
         metadata={"help": "Path to offload model weights."},
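For reference, a minimal sketch of what this patch changes: before, the dtype handed to vLLM was derived from the checkpoint's config.torch_dtype via infer_optim_dtype; now it is a user-facing choice, and the "auto" default defers to vLLM's own config-based resolution. The model id and the standalone ModelArguments construction below are illustrative only (and assume the hparams package exports ModelArguments); the engine_args keys and the from_engine_args call mirror vllm_engine.py.

    from vllm import AsyncEngineArgs, AsyncLLMEngine

    from llamafactory.hparams import ModelArguments

    # Hypothetical arguments; only the vllm_dtype field is introduced by this patch.
    model_args = ModelArguments(
        model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model id
        vllm_dtype="bfloat16",  # "auto" (the default) lets vLLM infer from config.torch_dtype
    )

    # Trimmed-down version of the engine_args dict built in VllmEngine.__init__.
    engine_args = {
        "model": model_args.model_name_or_path,
        "trust_remote_code": True,
        "dtype": model_args.vllm_dtype,  # previously the inferred infer_dtype string
    }
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))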