From 0b5f970c05c524670d66c810e9f081a52a1fb5e6 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 15 Sep 2023 20:58:28 +0800
Subject: [PATCH] fix #913

---
 src/llmtuner/tuner/core/parser.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index e51acf7a..2bb21325 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -144,11 +144,15 @@ def get_train_args(
         raise ValueError("Quantization is only compatible with the LoRA method.")
 
     if model_args.checkpoint_dir is not None:
-        if finetuning_args.finetuning_type != "lora":
+        if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+
+        if model_args.quantization_bit is not None:
             if len(model_args.checkpoint_dir) != 1:
-                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
-        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
-            raise ValueError("Quantized model only accepts a single checkpoint.")
+                raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
+
+            if not finetuning_args.resume_lora_training:
+                raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")
 
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -196,8 +200,10 @@ def get_train_args(
         if not torch.cuda.is_bf16_supported():
             raise ValueError("Current device does not support bf16 training.")
         model_args.compute_dtype = torch.bfloat16
-    else:
+    elif training_args.fp16:
         model_args.compute_dtype = torch.float16
+    else:
+        model_args.compute_dtype = torch.float32
 
     model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
 
@@ -231,10 +237,10 @@ def get_infer_args(
         raise ValueError("Quantization is only compatible with the LoRA method.")
 
     if model_args.checkpoint_dir is not None:
-        if finetuning_args.finetuning_type != "lora":
-            if len(model_args.checkpoint_dir) != 1:
-                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
-        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
-            raise ValueError("Quantized model only accepts a single checkpoint.")
+        if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+
+        if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
 
     return model_args, data_args, finetuning_args, generating_args