diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index f0d23676..91ac2949 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -13,13 +13,14 @@
 # limitations under the License.
 
 import uuid
-from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
 from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
 from ..model import load_config, load_tokenizer
+from ..model.model_utils.quantization import QuantizationMethod
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response
 
@@ -53,6 +54,11 @@ class VllmEngine(BaseEngine):
         generating_args: "GeneratingArguments",
     ) -> None:
         config = load_config(model_args)  # may download model from ms hub
+        if getattr(config, "quantization_config", None):  # gptq models should use float16
+            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+            quant_method = quantization_config.get("quant_method", "")
+            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
+                model_args.infer_dtype = "float16"
         self.can_generate = finetuning_args.stage == "sft"
         tokenizer_module = load_tokenizer(model_args)
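
For reference only, a minimal standalone sketch of the behavior the new branch adds: when the checkpoint's quantization_config reports GPTQ and infer_dtype was left at "auto", inference falls back to float16. The config and model_args below are hypothetical SimpleNamespace stand-ins, not the real PretrainedConfig / ModelArguments; the real code compares against QuantizationMethod.GPTQ (a str-based enum), while the plain string "gptq" is used here to keep the sketch dependency-free.

# Sketch of the GPTQ dtype fallback, using stand-in objects.
from types import SimpleNamespace
from typing import Any, Dict

config = SimpleNamespace(quantization_config={"quant_method": "gptq", "bits": 4})
model_args = SimpleNamespace(infer_dtype="auto")

if getattr(config, "quantization_config", None):
    quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
    quant_method = quantization_config.get("quant_method", "")
    # GPTQ checkpoints are served in float16 when the user left infer_dtype on "auto".
    if quant_method == "gptq" and model_args.infer_dtype == "auto":
        model_args.infer_dtype = "float16"

print(model_args.infer_dtype)  # -> float16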