fix #3273
commit efc345c4b0
parent d1fb6c72b5
@@ -129,6 +129,10 @@ class ModelArguments:
         default=1,
         metadata={"help": "The file shard size (in GB) of the exported model."},
     )
+    export_device: str = field(
+        default="cpu",
+        metadata={"help": "The device used in model export."},
+    )
     export_quantization_bit: Optional[int] = field(
         default=None,
         metadata={"help": "The number of bits to quantize the exported model."},
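For context, a minimal sketch of how the new argument could be supplied in a dict-style export config. Only export_size, export_device, export_quantization_bit, and export_dir appear in this diff; the other keys and all paths are assumptions for illustration.

# Hypothetical export configuration (paths are placeholders).
export_args = {
    "model_name_or_path": "path/to/base-model",  # assumed key, not part of this change
    "export_dir": "path/to/exported-model",
    "export_size": 1,                 # shard size in GB, per the default above
    "export_device": "cuda",          # new argument; defaults to "cpu"
    "export_quantization_bit": None,  # keep the export unquantized
}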
@@ -10,7 +10,7 @@ from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import is_torch_bf16_gpu_available
 
 from ..extras.logging import get_logger
-from ..extras.misc import check_dependencies
+from ..extras.misc import check_dependencies, get_current_device
 from ..extras.packages import is_unsloth_available
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
@@ -235,6 +235,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     elif training_args.fp16:
         model_args.compute_dtype = torch.float16
 
+    model_args.device_map = {"": get_current_device()}
     model_args.model_max_length = data_args.cutoff_len
     data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
 
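get_current_device() is imported above but its body is not part of this diff. A hedged sketch of the idea it presumably implements, resolving the per-process accelerator (e.g. via LOCAL_RANK under torchrun); this is an illustration, not the project's actual code:

import os

import torch

def get_current_device() -> torch.device:
    # Sketch only: choose this process's GPU when one is available,
    # otherwise fall back to the CPU.
    if torch.cuda.is_available():
        return torch.device("cuda:" + os.environ.get("LOCAL_RANK", "0"))
    return torch.device("cpu")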
@@ -278,8 +279,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     _verify_model_args(model_args, finetuning_args)
 
     if model_args.export_dir is not None:
-        model_args.device_map = {"": "cpu"}
-        model_args.compute_dtype = torch.float32
+        model_args.device_map = {"": torch.device(model_args.export_device)}
     else:
         model_args.device_map = "auto"
 
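The single-entry map {"": device} places every module of the model on one device, so export now honors export_device instead of always forcing CPU and float32. A hedged sketch of the effect with plain transformers; the model path is a placeholder and the real loading code lives elsewhere in the project:

import torch
from transformers import AutoModelForCausalLM

# "" matches every module, so the whole model lands on the export device.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/base-model",                   # placeholder path
    device_map={"": torch.device("cuda")},  # e.g. export_device="cuda"
    low_cpu_mem_usage=True,                 # consistent with the patch_config hunk below
)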
@@ -32,6 +32,9 @@ def init_adapter(
         logger.info("Adapter is not found at evaluation, load the base model.")
         return model
 
+    if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
+        raise ValueError("You can only use lora for quantized models.")
+
     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
         if not finetuning_args.pure_bf16:
@@ -129,9 +132,12 @@ def init_adapter(
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
 
-        if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None:
-            if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES:
-                raise ValueError("DoRA is not compatible with PTQ-quantized models.")
+        if (
+            finetuning_args.use_dora
+            and getattr(model, "quantization_method", None) is not None
+            and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
+        ):
+            raise ValueError("DoRA is not compatible with PTQ-quantized models.")
 
         peft_kwargs = {
             "r": finetuning_args.lora_rank,
@@ -323,8 +323,8 @@ def patch_config(
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
         if init_kwargs["low_cpu_mem_usage"]:
-            if "device_map" not in init_kwargs:
-                init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
+            if "device_map" not in init_kwargs and model_args.device_map:
+                init_kwargs["device_map"] = model_args.device_map
 
             if init_kwargs["device_map"] == "auto":
                 init_kwargs["offload_folder"] = model_args.offload_folder
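Taken together with the parser changes, the kwargs reaching from_pretrained now look roughly like the sketch below when a device map is set and DeepSpeed ZeRO-3 is disabled; the model path and exact call site are assumptions, and offload_folder is only added when device_map == "auto".

import torch
from transformers import AutoModelForCausalLM

init_kwargs = {
    "low_cpu_mem_usage": True,
    "device_map": {"": torch.device("cpu")},  # e.g. the default export_device
}
model = AutoModelForCausalLM.from_pretrained("path/to/base-model", **init_kwargs)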