parent
5b78e269b6
commit
03d05991f8
|
@ -68,18 +68,6 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
|||
return trainable_params, all_param
|
||||
|
||||
|
||||
def get_current_device() -> str:
|
||||
import accelerate
|
||||
if accelerate.utils.is_xpu_available():
|
||||
return "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
elif accelerate.utils.is_npu_available():
|
||||
return "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
elif torch.cuda.is_available():
|
||||
return "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
else:
|
||||
return "cpu"
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
r"""
|
||||
Gets logits processor that removes NaN and Inf logits.
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import math
|
||||
import torch
|
||||
from types import MethodType
|
||||
|
@ -22,7 +23,7 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
|
|||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype, try_download_model_from_ms
|
||||
from llmtuner.extras.misc import count_parameters, infer_optim_dtype, try_download_model_from_ms
|
||||
from llmtuner.extras.packages import is_flash_attn2_available
|
||||
from llmtuner.extras.patches import llama_patch as LlamaPatches
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
@ -150,7 +151,7 @@ def load_model_and_tokenizer(
|
|||
if getattr(config, "quantization_config", None):
|
||||
if model_args.quantization_bit is not None: # remove bnb quantization
|
||||
model_args.quantization_bit = None
|
||||
config_kwargs["device_map"] = {"": get_current_device()}
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
||||
quantization_config = getattr(config, "quantization_config", None)
|
||||
logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1)))
|
||||
|
||||
|
@ -172,7 +173,7 @@ def load_model_and_tokenizer(
|
|||
bnb_4bit_quant_type=model_args.quantization_type
|
||||
)
|
||||
|
||||
config_kwargs["device_map"] = {"": get_current_device()}
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
# Load pre-trained models (without valuehead)
|
||||
|
|
Loading…
Reference in New Issue