forked from p04798526/LLaMA-Factory-Mirror
fix #1715
parent 438dea679b
commit c9b166615c
@@ -68,6 +68,20 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param
 
 
+def get_current_device() -> torch.device:
+    import accelerate
+    if accelerate.utils.is_xpu_available():
+        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif accelerate.utils.is_npu_available():
+        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif torch.cuda.is_available():
+        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    else:
+        device = "cpu"
+
+    return torch.device(device)
+
+
 def get_logits_processor() -> "LogitsProcessorList":
     r"""
     Gets logits processor that removes NaN and Inf logits.
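Reviewer note: a minimal usage sketch of the new helper, not part of the commit. It assumes the package is importable as llmtuner and that the launcher (torchrun or accelerate) normally injects LOCAL_RANK for each process.

# Illustrative only: resolve the per-process device the way the loader would.
import os
import torch
from llmtuner.extras.misc import get_current_device

os.environ.setdefault("LOCAL_RANK", "0")  # normally set by the launcher
device = get_current_device()             # e.g. torch.device("cuda:0") or "cpu"
x = torch.zeros(2, 4).to(device)
print(device, x.device)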
@@ -2,7 +2,7 @@ import os
 import math
 import torch
 from types import MethodType
-from typing import TYPE_CHECKING, Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from transformers import (
     AutoConfig,
@@ -23,7 +23,7 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
     from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import count_parameters, infer_optim_dtype, try_download_model_from_ms
+from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype, try_download_model_from_ms
 from llmtuner.extras.packages import is_flash_attn2_available
 from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.hparams import FinetuningArguments
@@ -151,7 +151,7 @@ def load_model_and_tokenizer(
     if getattr(config, "quantization_config", None):
         if model_args.quantization_bit is not None: # remove bnb quantization
             model_args.quantization_bit = None
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
+        config_kwargs["device_map"] = {"": get_current_device()}
         quantization_config = getattr(config, "quantization_config", None)
         logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1)))
 
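Reviewer note: the device_map change routes device selection through the new helper, so NPU and XPU backends are handled alongside CUDA. A hedged sketch of how such a device_map is typically passed to from_pretrained; the model id is a placeholder and this code is not taken from the commit.

# Sketch only: load a model entirely onto the device chosen by get_current_device().
# The "" key maps the whole model to one device; the value may be a torch.device.
from transformers import AutoModelForCausalLM
from llmtuner.extras.misc import get_current_device

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",                                 # placeholder model id, assumption
    device_map={"": get_current_device()},  # same pattern as the diff above
    low_cpu_mem_usage=True,
)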
@@ -173,7 +173,7 @@ def load_model_and_tokenizer(
                 bnb_4bit_quant_type=model_args.quantization_type
             )
 
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
+        config_kwargs["device_map"] = {"": get_current_device()}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
     # Load pre-trained models (without valuehead)
@@ -209,7 +209,8 @@ def load_model_and_tokenizer(
     # Prepare model with valuehead for RLHF
     if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-        setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
+        ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
+        setattr(model, "_keys_to_ignore_on_save", ignore_modules)
         setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method
         vhead_path = (
             model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
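Reviewer note: this last hunk only splits a long setattr call across two lines, but the purpose of the name filter is easy to miss. A toy illustration of my own (not repo code) showing how "_keys_to_ignore_on_save" ends up listing only the frozen backbone parameters, so just the value head is written at checkpoint time:

# Toy module whose attribute names mirror AutoModelForCausalLMWithValueHead.
import torch.nn as nn

class ToyValueHeadModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.pretrained_model = nn.Linear(8, 8)  # stands in for the base LM
        self.v_head = nn.Linear(8, 1)            # stands in for the value head

model = ToyValueHeadModel()
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
print(ignore_modules)  # ['pretrained_model.weight', 'pretrained_model.bias'] -- skipped on save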