fix #1097
This commit is contained in:
parent
f9769cff8a
commit
a683c5b797
|
@ -1,4 +1,3 @@
|
||||||
import os
|
|
||||||
import torch
|
import torch
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
@ -8,7 +7,6 @@ from peft import (
|
||||||
LoraConfig,
|
LoraConfig,
|
||||||
get_peft_model
|
get_peft_model
|
||||||
)
|
)
|
||||||
from peft.utils import CONFIG_NAME, WEIGHTS_NAME
|
|
||||||
|
|
||||||
from llmtuner.extras.logging import get_logger
|
from llmtuner.extras.logging import get_logger
|
||||||
from llmtuner.tuner.core.utils import find_all_linear_modules
|
from llmtuner.tuner.core.utils import find_all_linear_modules
|
||||||
|
@ -63,11 +61,6 @@ def init_adapter(
|
||||||
latest_checkpoint = None
|
latest_checkpoint = None
|
||||||
|
|
||||||
if model_args.checkpoint_dir is not None:
|
if model_args.checkpoint_dir is not None:
|
||||||
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
|
|
||||||
"Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
|
|
||||||
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
|
|
||||||
"The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
|
|
||||||
|
|
||||||
if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
|
if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
|
||||||
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue