diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 11cb4f72..8cdf85bf 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -87,11 +87,11 @@ def load_model_and_tokenizer(
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
+        torch_dtype=model_args.compute_dtype,
         low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
         **config_kwargs
     )
-    model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
     patch_model(model, tokenizer, model_args, is_trainable)
     register_autoclass(config, model, tokenizer)
@@ -113,6 +113,7 @@ def load_model_and_tokenizer(
     if not is_trainable:
         model.requires_grad_(False)
+        model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
         model.eval()
     else:
         model.train()
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 5b91cb47..788f6e60 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -276,5 +276,6 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
     ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
     setattr(model, "_keys_to_ignore_on_save", ignore_modules)
+    setattr(model, "_no_split_modules", getattr(model.pretrained_model, "_no_split_modules", None))
     setattr(model, "tie_weights", MethodType(tie_weights, model))
     setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index 3870f56f..b98cdca0 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -27,7 +27,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     from accelerate import dispatch_model
     from accelerate.utils import infer_auto_device_map, get_balanced_memory
-    if model._no_split_modules is None:
+    if getattr(model, "_no_split_modules", None) is None:
         raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
     kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
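
For context, a minimal sketch (not this project's exact code path) of how the patched attributes are consumed downstream: accelerate's infer_auto_device_map and get_balanced_memory take no_split_module_classes, which _get_no_split_modules derives from the _no_split_modules attribute that patch_valuehead_model now copies from the wrapped model. The "gpt2" checkpoint and the multi-GPU assumption below are illustrative only.

# Sketch of multi-GPU dispatch via accelerate; assumes at least two visible CUDA
# devices and uses "gpt2" purely as a stand-in checkpoint.
import torch
from transformers import AutoModelForCausalLM
from accelerate import dispatch_model
from accelerate.utils import get_balanced_memory, infer_auto_device_map

model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)

# Resolve the transformer block classes that must stay on a single device.
# This is the step that fails when a wrapper does not expose `_no_split_modules`.
no_split = model._get_no_split_modules("auto")

max_memory = get_balanced_memory(model, dtype=model.dtype, no_split_module_classes=no_split)
device_map = infer_auto_device_map(
    model, max_memory=max_memory, dtype=model.dtype, no_split_module_classes=no_split
)
model = dispatch_model(model, device_map=device_map)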