add max_memory for gptq #1923

This commit is contained in:
hiyouga 2023-12-20 18:15:17 +08:00
parent 31165a9822
commit c4a3977ad7
4 changed files with 26 additions and 24 deletions

View File

@ -63,8 +63,8 @@ def get_dataset(
if dataset_attr.load_from == "ms_hub":
try:
from modelscope import MsDataset # type: ignore
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
from modelscope import MsDataset
from modelscope.utils.config_ds import MS_DATASETS_CACHE
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
dataset = MsDataset.load(
@ -75,7 +75,7 @@ def get_dataset(
split=data_args.split,
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
).to_hf_dataset()
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")

View File

@ -3,25 +3,22 @@ import os
import torch
from typing import TYPE_CHECKING, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from transformers.utils import (
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_npu_available,
is_torch_xpu_available
)
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
from transformers.utils import (
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_npu_available
)
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
_is_fp16_available = torch.cuda.is_available()
try:
_is_bf16_available = torch.cuda.is_bf16_supported()
except:
_is_bf16_available = False
except:
_is_bf16_available = False
if TYPE_CHECKING:
from transformers import HfArgumentParser
from llmtuner.hparams import ModelArguments
@ -68,12 +65,14 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
def get_current_device() -> torch.device:
import accelerate
if accelerate.utils.is_xpu_available():
r"""
Gets the current available device.
"""
if is_torch_xpu_available():
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif accelerate.utils.is_npu_available():
elif is_torch_npu_available():
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif torch.cuda.is_available():
elif is_torch_cuda_available():
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
else:
device = "cpu"
@ -117,7 +116,7 @@ def try_download_model_from_ms(model_args: "ModelArguments") -> None:
return
try:
from modelscope import snapshot_download # type: ignore
from modelscope import snapshot_download
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
model_args.model_name_or_path = snapshot_download(
model_args.model_name_or_path,

View File

@ -76,6 +76,7 @@ def configure_quantization(
if finetuning_args.export_quantization_bit is not None: # gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported.")
@ -86,6 +87,7 @@ def configure_quantization(
dataset=get_quantization_dataset(tokenizer, model_args, finetuning_args)
)
config_kwargs["device_map"] = "auto"
config_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit))

View File

@ -8,6 +8,7 @@ from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from llmtuner.extras.constants import LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import get_current_device
from llmtuner.hparams import ModelArguments, FinetuningArguments
if TYPE_CHECKING:
@ -20,7 +21,7 @@ logger = get_logger(__name__)
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r"""
Dispatches a pre-trained model to GPUs with balanced memory.
Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
"""
if getattr(model, "quantization_method", None): # already set on current device
@ -43,7 +44,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
return dispatch_model(model, **device_map_kwargs)
else:
return model.cuda()
return model.to(device=get_current_device())
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]: