hiyouga 2024-06-20 22:56:05 +08:00
parent a459624474
commit 8d4f5093cf
4 changed files with 15 additions and 13 deletions


@@ -199,6 +199,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         if not is_torch_bf16_gpu_available():
             raise ValueError("This device does not support `pure_bf16`.")

         if training_args.deepspeed:
             raise ValueError("`pure_bf16` is incompatible with DeepSpeed.")
+
+        if training_args.fp16 or training_args.bf16:
+            raise ValueError("Turn off mixed precision training when using `pure_bf16`.")

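For context, `pure_bf16` keeps every trainable parameter in bfloat16 end to end, which conflicts with DeepSpeed's own precision management and with the `fp16`/`bf16` autocast flags, hence the guards above. Below is a minimal standalone sketch of that validation order; the `Args` container is a hypothetical stand-in for the parsed training arguments, not the project's actual classes.

from dataclasses import dataclass

from transformers.utils import is_torch_bf16_gpu_available


@dataclass
class Args:  # hypothetical stand-in for training_args / finetuning_args
    pure_bf16: bool = False
    deepspeed: str = ""  # path to a DeepSpeed config; empty means unused
    fp16: bool = False
    bf16: bool = False


def validate_pure_bf16(args: Args) -> None:
    # Same ordering as the hunk above: hardware first, then engine conflicts.
    if not args.pure_bf16:
        return

    if not is_torch_bf16_gpu_available():
        raise ValueError("This device does not support `pure_bf16`.")

    if args.deepspeed:
        raise ValueError("`pure_bf16` is incompatible with DeepSpeed.")

    if args.fp16 or args.bf16:
        raise ValueError("Turn off mixed precision training when using `pure_bf16`.")


validate_pure_bf16(Args(pure_bf16=True, bf16=True))  # raises ValueError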

@@ -289,16 +289,15 @@ def init_adapter(
             raise ValueError("Cannot initialize PiSSA adapter on quantized models.")

     # cast trainable parameters to float32 if:
-    # 1. is_trainable and quantization_bit is not None (qlora)
-    # 2. is_trainable and not deepspeed zero3 and not fsdp (zero3 or fsdp already in float32)
-    # 3. is_trainable and not pure_bf16 and not badam
+    # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
+    # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
+    cast_trainable_params_to_fp32 = False
     if not is_trainable:
-        cast_trainable_params_to_fp32 = False
-    elif model_args.quantization_bit is None and (
-        is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam
-    ):
-        logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
-        cast_trainable_params_to_fp32 = False
+        pass
+    elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
+        logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
+    elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
+        logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
     else:
         logger.info("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True

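The branch order in the new version is the point of the fix: `pure_bf16`/BAdam is now checked before ZeRO-3/FSDP, so those runs keep trainable parameters in half precision even for quantized (QLoRA) models, which the old code upcast to float32. A condensed sketch of the decision as a pure function, with illustrative parameter names:

def should_cast_to_fp32(
    is_trainable: bool,
    quantized: bool,  # quantization_bit is not None
    zero3_or_fsdp: bool,  # DeepSpeed ZeRO-3 or FSDP enabled
    pure_bf16_or_badam: bool,
) -> bool:
    # Mirrors the post-commit branch order in init_adapter (sketch only).
    if not is_trainable:
        return False
    if pure_bf16_or_badam:  # now takes precedence over every other case
        return False
    if not quantized and zero3_or_fsdp:  # ZeRO-3/FSDP already hold params in fp32
        return False
    return True  # qlora or plain half-precision weights: upcast for stability


# Behavioral change: QLoRA combined with pure_bf16/BAdam is no longer upcast.
assert should_cast_to_fp32(True, quantized=True, zero3_or_fsdp=False, pure_bf16_or_badam=True) is False
assert should_cast_to_fp32(True, quantized=True, zero3_or_fsdp=False, pure_bf16_or_badam=False) is True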

@@ -91,8 +91,8 @@ def patch_config(

     # cast data type of the model if:
     # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
-    # 2. fsdp + qlora
-    if model_args.quantization_bit is not None or (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()):
+    # 2. quantization_bit is not None (qlora)
+    if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
         init_kwargs["torch_dtype"] = model_args.compute_dtype

     if init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True

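Reordering the condition also lets the comment match the code: `torch_dtype` is set whenever ZeRO-3/FSDP is absent (those engines keep the model in float32 themselves) or whenever quantized weights are loaded, since the quantization backend needs the compute dtype even under FSDP. A hedged sketch of that decision, with an illustrative kwargs dict and parameter names:

from typing import Any, Dict, Optional

import torch


def set_torch_dtype(
    init_kwargs: Dict[str, Any],
    compute_dtype: torch.dtype,
    quantization_bit: Optional[int],
    zero3_enabled: bool,
    fsdp_enabled: bool,
) -> None:
    # Sketch of the condition in the hunk above; names are illustrative.
    if (not zero3_enabled and not fsdp_enabled) or quantization_bit is not None:
        init_kwargs["torch_dtype"] = compute_dtype


kwargs: Dict[str, Any] = {}
set_torch_dtype(kwargs, torch.bfloat16, quantization_bit=4, zero3_enabled=False, fsdp_enabled=True)
assert kwargs["torch_dtype"] is torch.bfloat16  # FSDP + QLoRA still gets the compute dtype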

@@ -49,10 +49,10 @@ def create_top() -> Dict[str, "Component"]:
         booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
         visual_inputs = gr.Checkbox(scale=1)

-    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False)
-    model_name.input(save_config, inputs=[lang, model_name], queue=False).then(
+    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     )
+    model_name.input(save_config, inputs=[lang, model_name], queue=False)
     model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
     finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
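The Web UI hunk relies on Gradio's event chaining: `.then()` schedules a follow-up callback after the first one finishes. Moving `list_checkpoints` onto the `change` event (which also fires on programmatic updates) while leaving `save_config` on `input` (user edits only) keeps the checkpoint list in sync whenever the model name changes. A minimal sketch of the pattern under Gradio 4, independent of the project's actual components:

import gradio as gr


def get_model_path(name: str) -> str:
    return f"models/{name}"  # placeholder lookup


def list_ckpts(name: str) -> "gr.Dropdown":
    # Returning a component updates it in place (Gradio 4 convention).
    return gr.Dropdown(choices=[f"{name}-ckpt-100", f"{name}-ckpt-200"])


with gr.Blocks() as demo:
    model_name = gr.Textbox(label="Model name")
    model_path = gr.Textbox(label="Model path")
    checkpoints = gr.Dropdown(label="Checkpoints")

    # change fires on user and programmatic updates; chaining via .then()
    # refreshes the checkpoint list only after the path has been resolved.
    model_name.change(get_model_path, [model_name], [model_path], queue=False).then(
        list_ckpts, [model_name], [checkpoints], queue=False
    )

if __name__ == "__main__":
    demo.launch()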