This commit is contained in:
hiyouga 2024-01-09 14:49:13 +08:00
parent 3ae735ffe8
commit ebee4f6a2a
1 changed files with 2 additions and 4 deletions

View File

@ -27,14 +27,12 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
if (
torch.cuda.device_count() > 1
and isinstance(model, PreTrainedModel)
and getattr(model.config, "model_type", None) != "chatglm"
and model._no_split_modules is not None
and model.config.model_type != "chatglm"
):
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if getattr(model, "_no_split_modules", None) is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.