hiyouga 2023-09-15 20:58:28 +08:00
parent 8632bff811
commit 0b5f970c05
1 changed file with 16 additions and 10 deletions

@@ -144,11 +144,15 @@ def get_train_args(
         raise ValueError("Quantization is only compatible with the LoRA method.")
 
     if model_args.checkpoint_dir is not None:
-        if finetuning_args.finetuning_type != "lora":
-            if len(model_args.checkpoint_dir) != 1:
-                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
-        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
-            raise ValueError("Quantized model only accepts a single checkpoint.")
+        if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+
+        if model_args.quantization_bit is not None:
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
+
+            if not finetuning_args.resume_lora_training:
+                raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")
 
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
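
The restructured checks add one genuinely new rule: with a quantized base model, resume_lora_training must stay enabled. A minimal runnable sketch of the resulting validation flow, assuming hypothetical stand-in dataclasses ModelArgs and FinetuningArgs in place of the repository's ModelArguments and FinetuningArguments:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ModelArgs:  # hypothetical stand-in for the repo's ModelArguments
    checkpoint_dir: Optional[List[str]] = None
    quantization_bit: Optional[int] = None

@dataclass
class FinetuningArgs:  # hypothetical stand-in for FinetuningArguments
    finetuning_type: str = "lora"
    resume_lora_training: bool = True

def check_checkpoints(model_args: ModelArgs, finetuning_args: FinetuningArgs) -> None:
    if model_args.checkpoint_dir is None:
        return
    # Non-LoRA methods must resume from exactly one checkpoint.
    if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
        raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
    if model_args.quantization_bit is not None:
        # A quantized base cannot merge several adapter checkpoints on the fly ...
        if len(model_args.checkpoint_dir) != 1:
            raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
        # ... and cannot stack a brand-new LoRA weight on top of a loaded one.
        if not finetuning_args.resume_lora_training:
            raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")

# The new rejection: starting a fresh LoRA on a quantized model that
# already carries an adapter checkpoint.
# check_checkpoints(ModelArgs(["ckpt"], 4), FinetuningArgs(resume_lora_training=False))  -> ValueError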
@@ -196,8 +200,10 @@ def get_train_args(
         if not torch.cuda.is_bf16_supported():
             raise ValueError("Current device does not support bf16 training.")
         model_args.compute_dtype = torch.bfloat16
-    else:
+    elif training_args.fp16:
         model_args.compute_dtype = torch.float16
+    else:
+        model_args.compute_dtype = torch.float32
 
     model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
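
The dtype branch is now three-way: fp16 becomes an explicit opt-in and plain fp32 is the new fallback. A self-contained sketch of the equivalent resolution, assuming a hypothetical resolve_compute_dtype helper in place of the inline assignments (the flags mirror training_args.bf16 and training_args.fp16):

import torch

def resolve_compute_dtype(bf16: bool, fp16: bool) -> torch.dtype:
    # Mirrors the branch above: bf16 wins, fp16 is an explicit opt-in,
    # and passing neither flag now means full-precision training.
    if bf16:
        if not torch.cuda.is_bf16_supported():
            raise ValueError("Current device does not support bf16 training.")
        return torch.bfloat16
    elif fp16:
        return torch.float16
    else:
        return torch.float32

assert resolve_compute_dtype(bf16=False, fp16=False) is torch.float32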
@@ -231,10 +237,10 @@ def get_infer_args(
         raise ValueError("Quantization is only compatible with the LoRA method.")
 
     if model_args.checkpoint_dir is not None:
-        if finetuning_args.finetuning_type != "lora":
-            if len(model_args.checkpoint_dir) != 1:
-                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
-        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
-            raise ValueError("Quantized model only accepts a single checkpoint.")
+        if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+
+        if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
 
     return model_args, data_args, finetuning_args, generating_args
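
The inference-side check mirrors the training one but drops the resume_lora_training guard, which is meaningless outside training. A standalone restatement, assuming a hypothetical check_infer_checkpoints helper with plain parameters instead of the parsed dataclasses:

from typing import List, Optional

def check_infer_checkpoints(
    checkpoint_dir: Optional[List[str]],
    finetuning_type: str,
    quantization_bit: Optional[int],
) -> None:
    # Same flattened rules as training, minus the resume_lora_training guard.
    if checkpoint_dir is not None:
        if finetuning_type != "lora" and len(checkpoint_dir) != 1:
            raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
        if quantization_bit is not None and len(checkpoint_dir) != 1:
            raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")

# e.g. two adapters on a 4-bit base are now rejected at argument parsing time:
# check_infer_checkpoints(["adapter_a", "adapter_b"], "lora", 4)  -> ValueError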