fix #913
parent 8632bff811
commit 0b5f970c05
@@ -144,11 +144,15 @@ def get_train_args(
         raise ValueError("Quantization is only compatible with the LoRA method.")
 
     if model_args.checkpoint_dir is not None:
-        if finetuning_args.finetuning_type != "lora":
-            if len(model_args.checkpoint_dir) != 1:
-                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
-        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
-            raise ValueError("Quantized model only accepts a single checkpoint.")
+        if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+
+        if model_args.quantization_bit is not None:
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
+
+            if not finetuning_args.resume_lora_training:
+                raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")
 
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
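For context, a minimal standalone sketch of the reshaped training-time validation; ModelArgs and FinetuningArgs below are hypothetical stand-ins for the project's real argument classes, not its actual API:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ModelArgs:  # hypothetical stand-in for the real model arguments
    checkpoint_dir: Optional[List[str]] = None
    quantization_bit: Optional[int] = None

@dataclass
class FinetuningArgs:  # hypothetical stand-in for the real finetuning arguments
    finetuning_type: str = "lora"
    resume_lora_training: bool = True

def check_train_checkpoints(model_args, finetuning_args):
    if model_args.checkpoint_dir is None:
        return
    # Non-LoRA methods may resume from exactly one checkpoint.
    if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
        raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
    if model_args.quantization_bit is not None:
        # A quantized base model can load at most one adapter checkpoint...
        if len(model_args.checkpoint_dir) != 1:
            raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
        # ...and must resume that adapter rather than create a fresh one.
        if not finetuning_args.resume_lora_training:
            raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")

try:
    check_train_checkpoints(ModelArgs(["ckpt-a", "ckpt-b"], 4), FinetuningArgs())
except ValueError as err:
    print(err)  # -> Quantized model only accepts a single checkpoint. Merge them first.

Passing two checkpoints on a quantized model now fails fast with an actionable message instead of surfacing an error deeper in training.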
@@ -196,8 +200,10 @@ def get_train_args(
         if not torch.cuda.is_bf16_supported():
             raise ValueError("Current device does not support bf16 training.")
         model_args.compute_dtype = torch.bfloat16
-    else:
+    elif training_args.fp16:
         model_args.compute_dtype = torch.float16
+    else:
+        model_args.compute_dtype = torch.float32
 
     model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
 
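The dtype selection now reads as a three-way fallback. A sketch assuming plain bf16/fp16 booleans in place of the real training-arguments flags:

import torch

def select_compute_dtype(bf16: bool, fp16: bool) -> torch.dtype:
    # New branch order: bf16 (with a hardware check), then fp16, else float32.
    if bf16:
        if not torch.cuda.is_bf16_supported():
            raise ValueError("Current device does not support bf16 training.")
        return torch.bfloat16
    elif fp16:
        return torch.float16
    else:
        # Before this commit the final else returned float16; runs launched
        # without --bf16 or --fp16 now get full precision instead.
        return torch.float32

print(select_compute_dtype(bf16=False, fp16=False))  # torch.float32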
@@ -231,10 +237,10 @@ def get_infer_args(
         raise ValueError("Quantization is only compatible with the LoRA method.")
 
     if model_args.checkpoint_dir is not None:
-        if finetuning_args.finetuning_type != "lora":
-            if len(model_args.checkpoint_dir) != 1:
-                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
-        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
-            raise ValueError("Quantized model only accepts a single checkpoint.")
+        if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+
+        if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
 
     return model_args, data_args, finetuning_args, generating_args
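The inference path gets the same flattened shape, minus the resume_lora_training guard, presumably because inference never creates a new adapter. Reusing the hypothetical stand-in classes from the first sketch:

def check_infer_checkpoints(model_args, finetuning_args):
    if model_args.checkpoint_dir is None:
        return
    if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
        raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
    # At inference time only the checkpoint count matters for quantized models.
    if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
        raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")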