diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index e63d9477..5b91cb47 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -276,7 +276,5 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
 
     ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
     setattr(model, "_keys_to_ignore_on_save", ignore_modules)
-    setattr(model, "_no_split_modules", getattr(model.pretrained_model, "_no_split_modules", None))
-    setattr(model, "dtype", getattr(model.pretrained_model, "dtype", None))
     setattr(model, "tie_weights", MethodType(tie_weights, model))
     setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index b98cdca0..374808ee 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -1,6 +1,7 @@
 import torch
 import inspect
 from typing import TYPE_CHECKING, Any, Dict, List
+from transformers import PreTrainedModel
 from transformers.utils import cached_file
 from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 
@@ -8,7 +9,7 @@ from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import get_current_device
 
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
+    from transformers import PretrainedConfig, PreTrainedTokenizer
     from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 
 
@@ -23,7 +24,11 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     if getattr(model, "quantization_method", None):  # already set on current device
         return model
 
-    if torch.cuda.device_count() > 1 and getattr(model.config, "model_type", None) != "chatglm":
+    if (
+        torch.cuda.device_count() > 1
+        and isinstance(model, PreTrainedModel)
+        and getattr(model.config, "model_type", None) != "chatglm"
+    ):
         from accelerate import dispatch_model
         from accelerate.utils import infer_auto_device_map, get_balanced_memory
 