diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py
index 773dfc4e..626e07e4 100644
--- a/src/llmtuner/chat/stream_chat.py
+++ b/src/llmtuner/chat/stream_chat.py
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer
 
-from llmtuner.extras.misc import get_logits_processor
+from llmtuner.extras.misc import dispatch_model, get_logits_processor
 from llmtuner.extras.template import get_template
 from llmtuner.tuner import load_model_and_tokenizer
 
@@ -21,15 +21,7 @@ class ChatModel:
         generating_args: "GeneratingArguments"
     ) -> None:
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-
-        if torch.cuda.device_count() > 1:
-            from accelerate import dispatch_model
-            from accelerate.utils import infer_auto_device_map, get_balanced_memory
-            device_map = infer_auto_device_map(self.model, max_memory=get_balanced_memory(self.model))
-            self.model = dispatch_model(self.model, device_map)
-        else:
-            self.model = self.model.cuda()
-
+        self.model = dispatch_model(self.model)
         self.template = get_template(data_args.template)
         self.source_prefix = data_args.source_prefix
         self.generating_args = generating_args
diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index 82e695dc..93b65aa6 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -117,3 +117,25 @@ def torch_gc() -> None:
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
+
+
+def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
+    r"""
+    Dispatches a pre-trained model to GPUs with balanced memory.
+    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
+    """
+    if torch.cuda.device_count() > 1:
+        from accelerate import dispatch_model
+        from accelerate.utils import infer_auto_device_map, get_balanced_memory
+
+        if model._no_split_modules is None:
+            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
+
+        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
+        max_memory = get_balanced_memory(model, **kwargs)
+        # Make sure tied weights are tied before creating the device map.
+        model.tie_weights()
+        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
+        return dispatch_model(model, device_map)
+    else:
+        return model.cuda()
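
Note: for reference, a minimal sketch of how the relocated helper would be used (illustrative only; the checkpoint name and the AutoModelForCausalLM loading path are assumptions for this example, while ChatModel itself obtains the model via load_model_and_tokenizer as shown in the diff above):

    import torch
    from transformers import AutoModelForCausalLM

    from llmtuner.extras.misc import dispatch_model

    # Hypothetical checkpoint for illustration; any causal LM whose class defines
    # `_no_split_modules` (e.g. LLaMA-style models) satisfies the helper's check.
    model = AutoModelForCausalLM.from_pretrained(
        "huggyllama/llama-7b", torch_dtype=torch.float16
    )

    # With more than one visible GPU the helper builds a balanced device map via
    # accelerate and dispatches the layers; otherwise it moves the model to the
    # single GPU with .cuda().
    model = dispatch_model(model)
    model.eval()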