Fix #294

This commit is contained in:
parent b9cdff41bb
commit e6a3894b99
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer
 
-from llmtuner.extras.misc import get_logits_processor
+from llmtuner.extras.misc import dispatch_model, get_logits_processor
 from llmtuner.extras.template import get_template
 from llmtuner.tuner import load_model_and_tokenizer
 
@@ -21,15 +21,7 @@ class ChatModel:
         generating_args: "GeneratingArguments"
     ) -> None:
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-
-        if torch.cuda.device_count() > 1:
-            from accelerate import dispatch_model
-            from accelerate.utils import infer_auto_device_map, get_balanced_memory
-            device_map = infer_auto_device_map(self.model, max_memory=get_balanced_memory(self.model))
-            self.model = dispatch_model(self.model, device_map)
-        else:
-            self.model = self.model.cuda()
-
+        self.model = dispatch_model(self.model)
         self.template = get_template(data_args.template)
         self.source_prefix = data_args.source_prefix
         self.generating_args = generating_args
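The refactor is more than a move: the inline version removed above inferred the device map without no_split_module_classes, dtype, or weight tying, so accelerate was free to split a transformer block across GPUs. Below is a standalone, hedged sketch of the difference; the "gpt2" checkpoint and a multi-GPU machine are illustrative assumptions, not part of this commit.

    # Hedged comparison of the old and new device-map inference patterns.
    # Assumes transformers and accelerate are installed and more than one GPU
    # is visible; the "gpt2" checkpoint is only an example.
    from accelerate.utils import get_balanced_memory, infer_auto_device_map
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Old pattern (removed above): no no_split_module_classes, no dtype, no
    # tie_weights(), so a GPT2Block may land partly on one device, partly on another.
    naive_map = infer_auto_device_map(model, max_memory=get_balanced_memory(model))

    # New pattern (added in the next hunk): keep each block whole, account for
    # the real dtype, and tie weights before computing the map.
    kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
    model.tie_weights()
    safe_map = infer_auto_device_map(model, max_memory=get_balanced_memory(model, **kwargs), **kwargs)

    print(naive_map)
    print(safe_map)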
@@ -117,3 +117,25 @@ def torch_gc() -> None:
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
+
+
+def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
+    r"""
+    Dispatches a pre-trained model to GPUs with balanced memory.
+    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
+    """
+    if torch.cuda.device_count() > 1:
+        from accelerate import dispatch_model
+        from accelerate.utils import infer_auto_device_map, get_balanced_memory
+
+        if model._no_split_modules is None:
+            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
+
+        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
+        max_memory = get_balanced_memory(model, **kwargs)
+        # Make sure tied weights are tied before creating the device map.
+        model.tie_weights()
+        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
+        return dispatch_model(model, device_map)
+    else:
+        return model.cuda()
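For completeness, a minimal usage sketch of the helper added above. It assumes llmtuner is installed at this revision and at least one CUDA device is visible (both branches move the model to GPU); the checkpoint name is an illustrative stand-in, not part of the commit.

    # Usage sketch for the dispatch_model helper introduced in this commit.
    # Assumes >= 1 CUDA device; "gpt2" is only an example checkpoint.
    from transformers import AutoModelForCausalLM

    from llmtuner.extras.misc import dispatch_model

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model = dispatch_model(model)  # balanced placement across GPUs, or .cuda() on one

    # Show the set of devices the parameters ended up on.
    print({p.device for p in model.parameters()})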