hiyouga 2023-08-01 18:13:03 +08:00
parent b9cdff41bb
commit e6a3894b99
2 changed files with 24 additions and 10 deletions


@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer
-from llmtuner.extras.misc import get_logits_processor
+from llmtuner.extras.misc import dispatch_model, get_logits_processor
 from llmtuner.extras.template import get_template
 from llmtuner.tuner import load_model_and_tokenizer
@ -21,15 +21,7 @@ class ChatModel:
generating_args: "GeneratingArguments" generating_args: "GeneratingArguments"
) -> None: ) -> None:
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args) self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
self.model = dispatch_model(self.model)
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
device_map = infer_auto_device_map(self.model, max_memory=get_balanced_memory(self.model))
self.model = dispatch_model(self.model, device_map)
else:
self.model = self.model.cuda()
self.template = get_template(data_args.template) self.template = get_template(data_args.template)
self.source_prefix = data_args.source_prefix self.source_prefix = data_args.source_prefix
self.generating_args = generating_args self.generating_args = generating_args
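Note: the old inline branch planned the device map without `no_split_module_classes` or `dtype` and never re-tied weights; the shared helper added to llmtuner/extras/misc.py below passes all of these through. As a standalone illustration of that planning step, here is a minimal sketch; `sshleifer/tiny-gpt2` is an arbitrary small public checkpoint chosen only so the snippet runs quickly, not anything llmtuner uses.

# A minimal sketch of the planning step only, assuming a transformers model
# that defines `_no_split_modules` (GPT-2 does); intended to run on CPU or GPU.
from transformers import AutoModelForCausalLM
from accelerate.utils import get_balanced_memory, infer_auto_device_map

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
model.tie_weights()  # tie shared weights before planning placement
max_memory = get_balanced_memory(model, **kwargs)
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
print(device_map)  # e.g. {"": 0} when a single device fits the whole model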


@@ -117,3 +117,25 @@ def torch_gc() -> None:
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
+
+
+def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
+    r"""
+    Dispatches a pre-trained model to GPUs with balanced memory.
+    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
+    """
+    if torch.cuda.device_count() > 1:
+        from accelerate import dispatch_model
+        from accelerate.utils import infer_auto_device_map, get_balanced_memory
+
+        if model._no_split_modules is None:
+            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
+
+        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
+        max_memory = get_balanced_memory(model, **kwargs)
+        # Make sure tied weights are tied before creating the device map.
+        model.tie_weights()
+        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
+        return dispatch_model(model, device_map)
+    else:
+        return model.cuda()
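End to end, the new helper can be exercised outside of ChatModel. A minimal sketch, assuming llmtuner from this commit is importable, at least one CUDA device is present (the single-GPU fallback calls `.cuda()`), and the same arbitrary tiny checkpoint as above:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmtuner.extras.misc import dispatch_model

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = dispatch_model(AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2"))

# Inputs go to the device of the first parameter; accelerate's hooks move
# intermediate tensors between shards when the model spans several GPUs.
inputs = tokenizer("Hello", return_tensors="pt").to(next(model.parameters()).device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output[0]))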