fix #4677
commit 1e27e8c776
parent 636bb9c1e6
@@ -13,13 +13,14 @@
 # limitations under the License.
 
 import uuid
-from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
 from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
 from ..model import load_config, load_tokenizer
+from ..model.model_utils.quantization import QuantizationMethod
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response
 
@@ -53,6 +54,11 @@ class VllmEngine(BaseEngine):
         generating_args: "GeneratingArguments",
     ) -> None:
         config = load_config(model_args)  # may download model from ms hub
+        if getattr(config, "quantization_config", None):  # gptq models should use float16
+            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+            quant_method = quantization_config.get("quant_method", "")
+            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
+                model_args.infer_dtype = "float16"
 
         self.can_generate = finetuning_args.stage == "sft"
         tokenizer_module = load_tokenizer(model_args)
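For context: the added block reads the model config's quantization_config and, per the in-code comment, pins infer_dtype to "float16" for GPTQ checkpoints when the user left it on "auto". Below is a minimal, self-contained sketch of that dtype-resolution logic; the resolve_infer_dtype helper and the SimpleNamespace stand-ins for the transformers config and ModelArguments are illustrative assumptions, not code from the repository.

# Minimal sketch of the dtype-resolution logic added in this commit.
# resolve_infer_dtype is a hypothetical helper; SimpleNamespace objects
# stand in for the real PretrainedConfig and ModelArguments types.
from types import SimpleNamespace
from typing import Any, Dict, Optional


def resolve_infer_dtype(config: Any, model_args: Any) -> None:
    quantization_config: Optional[Dict[str, Any]] = getattr(config, "quantization_config", None)
    if quantization_config:  # gptq models should use float16
        quant_method = quantization_config.get("quant_method", "")
        # QuantizationMethod is a str enum in the repo, so GPTQ compares
        # equal to the plain string "gptq" used here.
        if quant_method == "gptq" and model_args.infer_dtype == "auto":
            model_args.infer_dtype = "float16"


# A GPTQ checkpoint with infer_dtype left on "auto" gets pinned to float16;
# any explicit user choice other than "auto" is left untouched.
config = SimpleNamespace(quantization_config={"quant_method": "gptq", "bits": 4})
args = SimpleNamespace(infer_dtype="auto")
resolve_infer_dtype(config, args)
assert args.infer_dtype == "float16"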