From f86857bd9ef456e77ad79a584f1fa08a129e5270 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 20 Dec 2023 15:11:15 +0800
Subject: [PATCH] fix mixtral inference #1821

---
 src/llmtuner/model/utils.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index e8aa164d..e7596d88 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -1,5 +1,6 @@
 import math
 import torch
+import inspect
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
 
 from transformers.utils import cached_file
@@ -20,7 +21,7 @@ logger = get_logger(__name__)
 def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     r"""
     Dispatches a pre-trained model to GPUs with balanced memory.
-    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
+    Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
     """
     if getattr(model, "quantization_method", None):  # already set on current device
         return model
@@ -32,12 +33,15 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
         if model._no_split_modules is None:
             raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
 
-        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
+        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
         max_memory = get_balanced_memory(model, **kwargs)
         # Make sure tied weights are tied before creating the device map.
         model.tie_weights()
         device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
-        return dispatch_model(model, device_map)
+        device_map_kwargs = {"device_map": device_map}
+        if "skip_keys" in inspect.signature(dispatch_model).parameters:
+            device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
+        return dispatch_model(model, **device_map_kwargs)
     else:
         return model.cuda()
 
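Note: the `skip_keys` guard in the hunk above exists because older accelerate
releases do not accept a `skip_keys` argument to `dispatch_model`, while newer
ones use it to leave cached tensors (for Mixtral, `_skip_keys_device_placement`
is the key-value cache) out of per-device placement. Probing the function
signature at runtime keeps the call compatible with both. Below is a minimal,
self-contained sketch of the same feature-detection pattern; the two stand-in
functions are hypothetical and only imitate the old and new accelerate
signatures, they are not accelerate's actual API.

    import inspect

    # Hypothetical stand-ins for two versions of a dispatch function,
    # for illustration only.
    def dispatch_model_old(model, device_map):
        return f"dispatched {model!r} with {device_map}"

    def dispatch_model_new(model, device_map, skip_keys=None):
        return f"dispatched {model!r} with {device_map}, skipping {skip_keys!r}"

    def safe_dispatch(dispatch_fn, model, device_map, skip_keys):
        # Only pass `skip_keys` when the installed version accepts it,
        # mirroring the guard added in this patch.
        kwargs = {"device_map": device_map}
        if "skip_keys" in inspect.signature(dispatch_fn).parameters:
            kwargs["skip_keys"] = skip_keys
        return dispatch_fn(model, **kwargs)

    print(safe_dispatch(dispatch_model_old, "model", {"": 0}, "past_key_values"))
    print(safe_dispatch(dispatch_model_new, "model", {"": 0}, "past_key_values"))

Inspecting the signature at runtime, rather than pinning a minimum accelerate
version, lets the same code run against whichever release the user has
installed.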