diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 563b1827..fb2835e8 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -133,7 +133,9 @@ def _configure_quantization(
         if is_deepspeed_zero3_enabled():
             raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")

-        init_kwargs["device_map"] = {"": get_current_device()}
+        if model_args.quantization_device_map != "auto":
+            init_kwargs["device_map"] = {"": get_current_device()}
+
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")

@@ -268,7 +270,6 @@ def _prepare_model_for_training(
         # According to: https://github.com/huggingface/transformers/issues/28339
         model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
-        # model.enable_input_require_grads()
         setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")
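
A minimal sketch of the behavior the first hunk introduces, for review purposes: the device map is pinned to the current device only when the user has not requested an `"auto"` device map. The names `quantization_device_map` and `get_current_device()` come from the diff; the standalone helper and its signature below are illustrative assumptions, not the actual llmtuner code.

```python
"""Sketch only: conditional device_map logic analogous to the first hunk."""
from typing import Any, Dict, Optional


def configure_quantization_device_map(
    init_kwargs: Dict[str, Any],
    quantization_device_map: Optional[str],
    current_device: str,
) -> None:
    # Pin the quantized model to the current device unless the user explicitly
    # asked for an "auto" device map (leaving placement to the model loader).
    if quantization_device_map != "auto":
        init_kwargs["device_map"] = {"": current_device}


if __name__ == "__main__":
    kwargs: Dict[str, Any] = {}
    configure_quantization_device_map(kwargs, quantization_device_map=None, current_device="cuda:0")
    print(kwargs)  # {'device_map': {'': 'cuda:0'}}

    kwargs = {}
    configure_quantization_device_map(kwargs, quantization_device_map="auto", current_device="cuda:0")
    print(kwargs)  # {} -- device_map left for the loader to resolve
```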