parent
b2f7cb4465
commit
dc1e8b7181
|
@ -97,7 +97,7 @@ def _init_adapter(
|
|||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
logger.info("Fine-tuning method: LoRA")
|
||||
lastest_checkpoint = None
|
||||
latest_checkpoint = None
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
|
||||
|
@ -106,7 +106,7 @@ def _init_adapter(
|
|||
"The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
|
||||
|
||||
if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
|
||||
checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||
else:
|
||||
checkpoints_to_merge = model_args.checkpoint_dir
|
||||
|
||||
|
@ -117,10 +117,10 @@ def _init_adapter(
|
|||
if len(checkpoints_to_merge) > 0:
|
||||
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
|
||||
|
||||
if lastest_checkpoint is not None: # resume lora training or quantized inference
|
||||
model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=is_trainable)
|
||||
if latest_checkpoint is not None: # resume lora training or quantized inference
|
||||
model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
|
||||
|
||||
if is_trainable and lastest_checkpoint is None: # create new lora weights while training
|
||||
if is_trainable and latest_checkpoint is None: # create new lora weights while training
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
|
|
Loading…
Reference in New Issue