fix dispatch

hiyouga 2024-01-03 16:33:16 +08:00
parent 24d8d6f224
commit 1696698eb9
2 changed files with 7 additions and 4 deletions

File 1 of 2:

```diff
@@ -276,7 +276,5 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
     ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
     setattr(model, "_keys_to_ignore_on_save", ignore_modules)
-    setattr(model, "_no_split_modules", getattr(model.pretrained_model, "_no_split_modules", None))
-    setattr(model, "dtype", getattr(model.pretrained_model, "dtype", None))
     setattr(model, "tie_weights", MethodType(tie_weights, model))
     setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
```

File 2 of 2:

```diff
@@ -1,6 +1,7 @@
 import torch
 import inspect
 from typing import TYPE_CHECKING, Any, Dict, List
+from transformers import PreTrainedModel
 from transformers.utils import cached_file
 from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
```
```diff
@@ -8,7 +9,7 @@ from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import get_current_device
 
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
+    from transformers import PretrainedConfig, PreTrainedTokenizer
     from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
```
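Moving `PreTrainedModel` out of the `TYPE_CHECKING` block is required rather than cosmetic: names imported there exist only for static type checkers and are absent at runtime, so the `isinstance` check added in the next hunk would raise `NameError` without a real import. A minimal illustration of that failure mode (the `check` helper is hypothetical):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # visible to type checkers only; this import never executes at runtime
    from transformers import PreTrainedModel

def check(model) -> bool:
    # raises NameError at runtime: 'PreTrainedModel' is not defined
    return isinstance(model, PreTrainedModel)
```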
```diff
@@ -23,7 +24,11 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     if getattr(model, "quantization_method", None): # already set on current device
         return model
 
-    if torch.cuda.device_count() > 1 and getattr(model.config, "model_type", None) != "chatglm":
+    if (
+        torch.cuda.device_count() > 1
+        and isinstance(model, PreTrainedModel)
+        and getattr(model.config, "model_type", None) != "chatglm"
+    ):
         from accelerate import dispatch_model
         from accelerate.utils import infer_auto_device_map, get_balanced_memory
```
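The hunk ends before the function body, but the new `isinstance(model, PreTrainedModel)` guard is the actual fix: trl's `AutoModelForCausalLMWithValueHead` is a plain `torch.nn.Module` wrapper rather than a `PreTrainedModel` subclass, so it now falls through to single-device placement instead of the multi-GPU path, which is why the first file could drop its `_no_split_modules`/`dtype` workaround. A hedged sketch of how the three imported accelerate helpers typically compose (not the repo's exact body; `balanced_multi_gpu_dispatch` is a hypothetical name):

```python
import torch
from accelerate import dispatch_model
from accelerate.utils import get_balanced_memory, infer_auto_device_map

def balanced_multi_gpu_dispatch(model: torch.nn.Module) -> torch.nn.Module:
    # modules that must not be split across devices; this is the attribute
    # the deleted setattr in the first file used to fake on the wrapper
    no_split = getattr(model, "_no_split_modules", None)
    # spread the memory budget evenly over the visible GPUs
    max_memory = get_balanced_memory(model, no_split_module_classes=no_split)
    # derive a per-module device map that fits within those budgets
    device_map = infer_auto_device_map(
        model, max_memory=max_memory, no_split_module_classes=no_split
    )
    # attach hooks that move weights/inputs between devices at forward time
    return dispatch_model(model, device_map=device_map)
```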