diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index 2c659993..b8424d62 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -68,6 +68,20 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param
 
 
+def get_current_device() -> torch.device:
+    import accelerate
+    if accelerate.utils.is_xpu_available():
+        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif accelerate.utils.is_npu_available():
+        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif torch.cuda.is_available():
+        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    else:
+        device = "cpu"
+
+    return torch.device(device)
+
+
 def get_logits_processor() -> "LogitsProcessorList":
     r"""
     Gets logits processor that removes NaN and Inf logits.
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 1f29abb2..faba1ee2 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -2,7 +2,7 @@ import os
 import math
 import torch
 from types import MethodType
-from typing import TYPE_CHECKING, Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from transformers import (
     AutoConfig,
@@ -23,7 +23,7 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
     from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import count_parameters, infer_optim_dtype, try_download_model_from_ms
+from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype, try_download_model_from_ms
 from llmtuner.extras.packages import is_flash_attn2_available
 from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.hparams import FinetuningArguments
@@ -151,7 +151,7 @@ def load_model_and_tokenizer(
     if getattr(config, "quantization_config", None):
         if model_args.quantization_bit is not None: # remove bnb quantization
             model_args.quantization_bit = None
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
+        config_kwargs["device_map"] = {"": get_current_device()}
         quantization_config = getattr(config, "quantization_config", None)
         logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1)))
 
@@ -173,7 +173,7 @@ def load_model_and_tokenizer(
                 bnb_4bit_quant_type=model_args.quantization_type
             )
 
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
+        config_kwargs["device_map"] = {"": get_current_device()}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
     # Load pre-trained models (without valuehead)
@@ -209,7 +209,8 @@ def load_model_and_tokenizer(
     # Prepare model with valuehead for RLHF
     if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-        setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
+        ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
+        setattr(model, "_keys_to_ignore_on_save", ignore_modules)
         setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method
         vhead_path = (
             model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
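
A minimal usage sketch, not part of the patch above: it shows how the new get_current_device helper could be exercised on its own. It assumes llmtuner, torch and accelerate are installed and importable; LOCAL_RANK is normally set by the launcher (torchrun or accelerate), and the setdefault call here is only a fallback for single-process runs.

    # Usage sketch only -- not part of the diff. Assumes llmtuner/torch/accelerate
    # are installed; LOCAL_RANK is normally provided by torchrun or accelerate.
    import os
    import torch

    from llmtuner.extras.misc import get_current_device

    os.environ.setdefault("LOCAL_RANK", "0")   # fallback for single-process runs

    device = get_current_device()              # e.g. cuda:0, xpu:0, npu:0 or cpu
    # Mirrors the loader.py change: map the whole model onto this process's device.
    device_map = {"": device}
    print(device, device_map)

Presumably the point of returning a torch.device (rather than a bare int) is that device_map values may also name XPU/NPU devices, which is why the loader hunks can drop the int(...) cast on LOCAL_RANK.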