From 259af60d28985b919911587716c24a3ac7f7de64 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 5 Mar 2024 20:49:50 +0800
Subject: [PATCH] improve aqlm optim

---
 src/llmtuner/eval/evaluator.py     | 3 +--
 src/llmtuner/hparams/model_args.py | 3 +++
 src/llmtuner/hparams/parser.py     | 2 ++
 src/llmtuner/model/loader.py       | 2 +-
 4 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/llmtuner/eval/evaluator.py b/src/llmtuner/eval/evaluator.py
index 7e8b064a..4969561f 100644
--- a/src/llmtuner/eval/evaluator.py
+++ b/src/llmtuner/eval/evaluator.py
@@ -14,7 +14,7 @@ from transformers.utils import cached_file
 from ..data import get_template_and_fix_tokenizer
 from ..extras.constants import CHOICES, SUBJECTS
 from ..hparams import get_eval_args
-from ..model import dispatch_model, load_model_and_tokenizer
+from ..model import load_model_and_tokenizer
 from .template import get_eval_template
 
 
@@ -23,7 +23,6 @@ class Evaluator:
         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
         self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
         self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
-        self.model = dispatch_model(self.model)
         self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
         self.eval_template = get_eval_template(self.eval_args.lang)
         self.choice_inputs = [
diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index 52cd973f..573efb21 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -121,6 +121,9 @@ class ModelArguments:
         default=False,
         metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
     )
+    aqlm_optimization: Optional[bool] = field(
+        default=False, metadata={"help": "Whether or not to optimize the training performance of AQLM models."}
+    )
 
     def __post_init__(self):
         self.compute_dtype = None
diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py
index 6b55e03d..8f9d81e3 100644
--- a/src/llmtuner/hparams/parser.py
+++ b/src/llmtuner/hparams/parser.py
@@ -226,6 +226,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
     )
     model_args.model_max_length = data_args.cutoff_len
+    model_args.aqlm_optimization = not training_args.predict_with_generate
 
     # Log on each process the small summary:
     logger.info(
@@ -262,6 +263,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     _set_transformers_logging()
     _verify_model_args(model_args, finetuning_args)
     _check_dependencies(disabled=finetuning_args.disable_version_checking)
+    model_args.aqlm_optimization = True
 
     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 45260310..e5b3bdd1 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -88,7 +88,7 @@ def load_model(
     if model is None:
         model_init_context = nullcontext()
 
-        if is_trainable and getattr(config, "quantization_config", None):
+        if model_args.aqlm_optimization and getattr(config, "quantization_config", None):
             quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
             if quantization_config.get("quant_method", None) == "aqlm":
                 import aqlm  # type: ignore
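
Note for reviewers: a minimal sketch of the flag plumbing this patch introduces.
The resolver below mirrors the two parser hunks, and the gate mirrors the
loader.py hunk. The call to aqlm.optimize_for_training() is an assumption:
it sits past the last context line of the loader hunk (`import aqlm`), so it
is not visible in this diff; it is the entry point the aqlm package documents
for finetuning, but the actual body may differ.

from contextlib import nullcontext
from typing import Any, Dict, Optional


def resolve_aqlm_optimization(is_training: bool, predict_with_generate: bool) -> bool:
    # Mirrors get_train_args / get_eval_args: training runs enable the AQLM
    # optimization unless they generate during prediction; evaluation runs
    # always enable it.
    if is_training:
        return not predict_with_generate
    return True


def aqlm_init_context(aqlm_optimization: bool, config: Any):
    # Mirrors the loader.py hunk: the branch is now gated by the explicit
    # model_args.aqlm_optimization flag instead of is_trainable.
    quantization_config: Optional[Dict[str, Any]] = getattr(config, "quantization_config", None)
    if aqlm_optimization and quantization_config is not None:
        if quantization_config.get("quant_method", None) == "aqlm":
            import aqlm  # type: ignore

            # Assumption: the loader enters AQLM's training-optimization
            # context here, as documented by the aqlm package for finetuning.
            return aqlm.optimize_for_training()
    return nullcontext()

Net effect: evaluation (where get_eval_args now sets the flag to True) takes
the AQLM fast path that the old is_trainable gate blocked, while training runs
with predict_with_generate skip it.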