update_from_finetuned
This commit is contained in:
parent 7d256ce288
commit 617955e08e
@@ -37,4 +37,5 @@ log.txt
 **/Delta_Memory/
 **/output/
 **/thunlp/
+**/saved_ckpts/
@@ -343,7 +343,7 @@ class SuperGLUECB(AbstractTask):
                      "validation": "validation",
                      "test": "validation"}
     metric = [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]
-    metric_names = ["f1_multiclass", "accuracy"]
+    metric_names = ["accuracy"]

     verbalizers = {
         "0":{"0": "yes",
@@ -47,20 +47,20 @@ def spearman_corrcoef(predictions, targets) -> dict:


-def spearman_corrcoef(predictions, targets) -> dict:
-    """Computes Spearman correlation coefficient."""
-    # TODO: we need to do postprocessors in a clean way for each dataset.
-    from examples_seq2seq.data_processors.postprocessors import string_to_float
-    targets = [string_to_float(target) for target in targets]
-    predictions= [string_to_float(prediction) for prediction in predictions]
-    spearman_corrcoef = 100 * scipy.stats.spearmanr(targets, predictions)[0]
+# def spearman_corrcoef(predictions, targets) -> dict:
+#     """Computes Spearman correlation coefficient."""
+#     # TODO: we need to do postprocessors in a clean way for each dataset.
+#     from examples_seq2seq.data_processors.postprocessors import string_to_float
+#     targets = [string_to_float(target) for target in targets]
+#     predictions= [string_to_float(prediction) for prediction in predictions]
+#     spearman_corrcoef = 100 * scipy.stats.spearmanr(targets, predictions)[0]

-    # Note that if all the predictions will be the same, spearman
-    # correlation is nan, to gaurad against this, we check the output
-    # and return 0 in this case.
-    if math.isnan(spearman_corrcoef):
-        spearman_corrcoef = 0
-    return {"spearmanr": spearman_corrcoef}
+#     # Note that if all the predictions will be the same, spearman
+#     # correlation is nan, to gaurad against this, we check the output
+#     # and return 0 in this case.
+#     if math.isnan(spearman_corrcoef):
+#         spearman_corrcoef = 0
+#     return {"spearmanr": spearman_corrcoef}


 def f1_score_with_invalid(predictions, targets) -> dict:
@@ -26,6 +26,7 @@ You can also adapt this script on your own tasks.

 import os
 import sys

 os.environ['MKL_THREADING_LAYER'] = 'GNU'
 os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
 os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -56,7 +57,7 @@ from transformers.trainer_utils import is_main_process, get_last_checkpoint

 from data_processors import AutoTask #, #TaskDataCollatorForSeq2Seq, AutoPostProcessor, data_collator
 from utils import read_json, save_json
-from utils.args import ModelArguments, TrainingArguments, DataTrainingArguments, RemainArgHfArgumentParser
+from utils.args import ModelArguments, TrainingArguments, DataTrainingArguments, DeltaArguments, RemainArgHfArgumentParser


 logger = logging.getLogger(__name__)

@@ -66,16 +67,14 @@ def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
-    parser = RemainArgHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        # If we pass only one argument to the script and it's the path to a json file,
-        # let's parse it to get our arguments.
-        model_args, data_args, training_args, delta_args, remain_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
-    else:
-        model_args, data_args, training_args, delta_args, remain_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+    parser = RemainArgHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, DeltaArguments))
+    # You can provide a json file with contains the arguments and use the --argument some_arg to override or append to the json file.
+    json_file, cmd_args = (os.path.abspath(sys.argv[1]), sys.argv[2:]) if sys.argv[1].endswith(".json") else (None, sys.argv[1:])
+    model_args, data_args, training_args, delta_args, remain_args = parser.parse_json_file_with_cmd_args(json_file=json_file, command_line_args=cmd_args)
+    logger.warning("The following arguments not used! {}".format(remain_args))


-    print(f"{training_args.output_dir}/results.json")
+    logger.info(f"The results will be used in {training_args.output_dir}/results.json")
     # exit()
     # Detecting last checkpoint.
     last_checkpoint = None
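The reworked entry point accepts a json config, plain command-line flags, or both. A minimal usage sketch (the config path and override flag below are illustrative, not taken from this commit):

    # hypothetical invocation of the new parsing path
    parser = RemainArgHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, DeltaArguments))
    model_args, data_args, training_args, delta_args, remain_args = parser.parse_json_file_with_cmd_args(
        json_file="configs/example.json",                # assumed example path
        command_line_args=["--learning_rate", "3e-4"])   # any --flag given here overrides the json entry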
@@ -161,7 +160,8 @@ def main():

     if delta_args.delta_type.lower() != "none":
         from opendelta import AutoDeltaConfig,AutoDeltaModel
-        delta_config = AutoDeltaConfig.from_dict(vars(delta_args))
+        from dataclasses import asdict
+        delta_config = AutoDeltaConfig.from_dict(asdict(delta_args))
         delta_model = AutoDeltaModel.from_config(delta_config, backbone_model=model)
         delta_model.freeze_module(set_state_dict = True)
         delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=True)
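A plausible reading of the vars-to-asdict switch (the commit itself does not say): fields merged into DeltaArguments by merge_arguments via dataclasses.make_dataclass end up as class-level defaults, so dataclasses.asdict(delta_args) still collects them, whereas vars(delta_args) only exposes the instance __dict__ and would drop the delta-type-specific options.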
@@ -278,14 +278,9 @@ def main():

     if torch.cuda.is_available() and training_args.compute_memory:
         peak_memory = (torch.cuda.max_memory_allocated() / 1024 ** 2)/1000
-        print(
-            "Memory utilization",
-            peak_memory,
-            "GB"
-        )
         performance_metrics.update({"peak_memory": peak_memory})
     if training_args.compute_memory or training_args.compute_time:
-        print("Efficiency Statistics {}".format(performance_metrics))
+        logger.info("Efficiency Statistics {}".format(performance_metrics))
         trainer.save_metrics("performance", performance_metrics)

     # Evaluation
@@ -326,10 +321,13 @@ def main():
     # all_results['repo_name'] = repo_name


-    delta_model.save_finetuned(push_to_hf=training_args.push_to_hf,
+    delta_model.save_finetuned(finetuned_delta_path=delta_args.finetuned_delta_path,
                                push_to_dc=training_args.push_to_dc,
-                               center_args={},
+                               center_args={"test_performance":all_results['test'][data_args.task_name]['test_average_metrics'],
+                                            },
                                center_args_pool = {**vars(model_args), **vars(data_args), **vars(training_args), **vars(delta_args)},
+                               list_tags = ['NLI'],
+                               dict_tags = {'purpose':'for testing'},
                                delay_push=True,
                               )
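Because delay_push=True, nothing is uploaded at save time: the checkpoint is written under delta_args.finetuned_delta_path and can be pushed afterwards with

    python -m DeltaCenter upload <path-to-the-saved-delta>

which is the command echoed by the save-time log message added further down in this commit.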
@@ -63,40 +63,32 @@ logger = logging.getLogger(__name__)


 def main():
-    # See all possible arguments in src/transformers/training_args.py
-    # or by passing the --help flag to this script.
-    # We now keep distinct sets of args, for a cleaner separation of concerns.
     parser = RemainArgHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, DeltaArguments))
-    print("here1")
-    if sys.argv[1].endswith(".json"):
-        print("here2", sys.argv[2:])
-        # If we pass only one argument to the script and it's the path to a json file,
-        # let's parse it to get our arguments.
-        model_args, data_args, training_args, delta_args, remain_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]), command_line_args=sys.argv[2:])
-    else:
-        model_args, data_args, training_args, delta_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
-
-
-    print(f"{training_args.output_dir}/results.json")
-    # exit()
-    # Detecting last checkpoint.
-    last_checkpoint = None
-    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
-        last_checkpoint = get_last_checkpoint(training_args.output_dir)
-        print("#### last_checkpoint ", last_checkpoint)
-        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
-            '''
-            raise ValueError(
-                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
-                "Use --overwrite_output_dir to overcome."
-            )
-            '''
-            pass
-        elif last_checkpoint is not None:
-            logger.info(
-                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
-                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
-            )
+    # You can provide a json file with contains the arguments and use the --argument some_arg to override or append to the json file.
+    json_file, cmd_args = (os.path.abspath(sys.argv[1]), sys.argv[2:]) if sys.argv[1].endswith(".json") else (None, sys.argv[1:])
+    model_args, data_args, training_args, delta_args, remain_args = parser.parse_json_file_with_cmd_args(json_file=json_file, command_line_args=cmd_args)
+    logger.warning("The following arguments not used! {}".format(remain_args))
+
+    # # exit()
+    # # Detecting last checkpoint.
+    # last_checkpoint = None
+    # if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+    #     last_checkpoint = get_last_checkpoint(training_args.output_dir)
+    #     print("#### last_checkpoint ", last_checkpoint)
+    #     if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+    #         '''
+    #         raise ValueError(
+    #             f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+    #             "Use --overwrite_output_dir to overcome."
+    #         )
+    #         '''
+    #         pass
+    #     elif last_checkpoint is not None:
+    #         logger.info(
+    #             f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+    #             "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+    #         )


     # Setup logging
     logging.basicConfig(

@@ -149,8 +141,6 @@ def main():


-
-
     config, tokenizer, model = get_backbone(model_args=model_args)

     # model parallelize

@@ -164,14 +154,11 @@ def main():
     if delta_args.delta_type.lower() != "none":
         from opendelta import AutoDeltaConfig,AutoDeltaModel
         # delta_config = AutoDeltaConfig.from_dict(vars(delta_args))
-        delta_model = AutoDeltaModel.from_finetuned(delta_args.finetuned_model_path, backbone_model=model)
+        delta_model = AutoDeltaModel.from_finetuned(finetuned_model_path=delta_args.finetuned_model_path, cache_dir="saved_ckpts", backbone_model=model)
         # delta_model.freeze_module(set_state_dict = True)
         delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=True)

-
-
-
     performance_metrics = {}

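A minimal loading sketch for this evaluation path (the checkpoint id is illustrative; model is the backbone returned by get_backbone):

    # hypothetical sketch: if finetuned_model_path is not a local directory, the loader now
    # also looks under cache_dir ("saved_ckpts", matching the new .gitignore entry) before
    # falling back to the hub
    delta_model = AutoDeltaModel.from_finetuned(
        finetuned_model_path="DeltaHub/lora_t5-base_mrpc",   # example id reused from the docstrings
        cache_dir="saved_ckpts",
        backbone_model=model)
    # from_finetuned leaves the backbone in eval mode, so no freeze_module() call is needed for inference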
@@ -1,7 +1,10 @@
-from calendar import c
 from dataclasses import dataclass, field
 from typing import Optional, List
 from transformers import HfArgumentParser
+from pathlib import Path
+import sys


 @dataclass
 class ModelArguments:

@@ -215,6 +218,10 @@ class DataTrainingArguments:
         if self.test_max_target_length is None:
             self.test_max_target_length = self.max_target_length


+import dataclasses

 @dataclass
 class DeltaArguments:
     """

@@ -233,47 +240,152 @@ class DeltaArguments:
     unfrozen_modules: Optional[List[str]] = field(
         default_factory=lambda:["deltas"], metadata={"help": "the modules inside the backbone or in the delta modules that need to be unfrozen"}
     )
-    # Delta-type-specific arguments
-    # Adapter:
-    bottleneck_dim: Optional[int] = field(
-        default=24, metadata={"help": "the dimension of the bottleneck layer"}
-    )
-    finetuned_model_path: Optional[str] = field(
+    finetuned_delta_path: Optional[str] = field(
         default=None, metadata={"help": "the path of the finetuned delta model"}
     )

+    def merge_arguments(self, objb):
+        print(objb)
+        self.__class__ = dataclasses.make_dataclass('DeltaArgument', fields=[(s.name, s.type, getattr(objb, s.name)) for s in dataclasses.fields(objb)], bases=(DeltaArguments,))

-# ['--backbone_model', 't5', '--bottleneck_dim', '24', '--delta_type', 'adapter', '--model_path_public', 't5-base', '--unfrozen_modules', "['deltas',", "'layer_norm',", "'final_layer_norm']"]
+
+@dataclass
+class AdapterArguments:
+    bottleneck_dim: Optional[int] = field(
+        default=24, metadata={"help": "the dimension of the bottleneck layer"}
+    )
+
+
+DELTAARGMAP = {
+    "adapter": AdapterArguments
+}
+
+# TODO: add more specific delta arguments


 class RemainArgHfArgumentParser(HfArgumentParser):
-    def parse_json_file(self, json_file: str, command_line_args=None, return_remaining_args=True ):
+    '''This is a more powerful version of argument parser.
+    It can receiven both command line arguments and json file arguments.
+    The command line arguments will override the json file arguments.
+    The parser will load the specific delta arguments (e.g. Adapter's)
+    according to the delta_type argument. And merge the specific delta arguments
+    with the common delta arguments.
+    '''
+    def parse_json_file_with_cmd_args(self, json_file: str, command_line_args=None, return_remaining_args=True ):
         """
         Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
         dataclass types.
         """
-        import argparse
         import json
         from pathlib import Path
-        import dataclasses
-
-        print("Here", command_line_args)

         data = json.loads(Path(json_file).read_text())

         data_str = ""
+        if command_line_args is None:
+            command_line_args = []
         for key in data:
             if "--"+key not in command_line_args:
-                data_str+= "--" + key + " " + str(data[key]) + " "
+                if isinstance(data[key], list):
+                    data_str += "--"+key
+                    for elem in data[key]:
+                        data_str+=" "+ str(elem)
+                    data_str += " "
+                else:
+                    data_str+= "--" + key + " " + str(data[key]) + " "

         data_list = data_str.split()
         data_list += command_line_args

-        return self.parse_args_into_dataclasses(args=data_list, return_remaining_strings=return_remaining_args)
+        if return_remaining_args:
+            outputs, remain_args = self.parse_args_into_dataclasses(args=data_list, return_remaining_strings=return_remaining_args)
+            for d in outputs:
+                if isinstance(d, DeltaArguments): # merge the specific delta arguments
+                    d.merge_arguments(outputs[-1])
+            return *(outputs[:-1]), remain_args
+        else:
+            outputs = self.parse_args_into_dataclasses(args=data_list, return_remaining_strings=return_remaining_args)
+            for d in outputs:
+                if isinstance(d, DeltaArguments):
+                    d.merge_arguments(outputs[-1])
+            return (*(outputs[:-1]),)
+
+    def parse_args_into_dataclasses(
+        self, args=None, return_remaining_strings=False, look_for_args_file=True, args_filename=None
+    ):
+        """
+        Parse command-line args into instances of the specified dataclass types.
+
+        This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
+        docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args
+
+        Args:
+            args:
+                List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
+            return_remaining_strings:
+                If true, also return a list of remaining argument strings.
+            look_for_args_file:
+                If true, will look for a ".args" file with the same base name as the entry point script for this
+                process, and will append its potential content to the command line args.
+            args_filename:
+                If not None, will uses this file instead of the ".args" file specified in the previous argument.
+
+        Returns:
+            Tuple consisting of:
+
+                - the dataclass instances in the same order as they were passed to the initializer.abspath
+                - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
+                  after initialization.
+                - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
+        """
+        if args_filename or (look_for_args_file and len(sys.argv)):
+            if args_filename:
+                args_file = Path(args_filename)
+            else:
+                args_file = Path(sys.argv[0]).with_suffix(".args")
+
+            if args_file.exists():
+                fargs = args_file.read_text().split()
+                args = fargs + args if args is not None else fargs + sys.argv[1:]
+                # in case of duplicate arguments the first one has precedence
+                # so we append rather than prepend.
+        namespace, remaining_args = self.parse_known_args(args=args)
+
+        # conditionally add delta arguments
+        deltatype_args = DELTAARGMAP[namespace.delta_type]
+        self.dataclass_types.append(deltatype_args)
+        self._add_dataclass_arguments(deltatype_args)
+
+        # parse the arguments again, this time with the specific delta type's arguments
+        namespace, remaining_args = self.parse_known_args(args=args)
+
+        outputs = []
+        for dtype in self.dataclass_types:
+            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
+            inputs = {k: v for k, v in vars(namespace).items() if k in keys}
+            for k in keys:
+                delattr(namespace, k)
+            obj = dtype(**inputs)
+            outputs.append(obj)
+        if len(namespace.__dict__) > 0:
+            # additional namespace.
+            outputs.append(namespace)
+        if return_remaining_strings:
+            return (outputs, remaining_args)
+        else:
+            if remaining_args:
+                raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
+
+            return outputs

         # namespace, remaining_args = self.parse_known_args(args=data_list)
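Since parse_args_into_dataclasses resolves the delta-type-specific dataclass through DELTAARGMAP, supporting another delta type amounts to adding a dataclass and a map entry. A sketch, not part of this commit, reusing the LoraArguments dataclass introduced in the lora.py hunk further down (assuming it is importable here):

    # hypothetical registration, mirroring the existing "adapter" entry
    DELTAARGMAP = {
        "adapter": AdapterArguments,
        "lora": LoraArguments,
    }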
@@ -292,4 +404,6 @@ class RemainArgHfArgumentParser(HfArgumentParser):
         # if return_remaining_args:
         #     return (*outputs, remain_args)
         # else:
         #     return (*outputs,)


@@ -11,7 +11,7 @@ logger = get_logger(__name__)


 DELTA_CONFIG_MAPPING = {
     "lora": "LoraConfig",
     "low_rank_adapter": "LowRankAdapterConfig",
     "bitfit": "BitFitConfig",
     "adapter":"AdapterConfig",

@@ -91,18 +91,18 @@ class AutoDeltaConfig:
             "AutoConfig is designed to be instantiated "
             "using the ``AutoConfig.from_pretrained(pretrained_model_name_or_path)`` method."
         )

     @classmethod
     def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
         r""" Instantiate a DeltaConfig according to the dict. Automatically load the config specified by
         :obj:`delta_type`.

         Args:
             config_dict (:obj:`dict`): The dict of configs of delta model.
             kwargs: Other keyword argument pass to initialize the config.

         >>> config = AutoDeltaConfig.from_dict({"delta_type":"lora"}) # This will load the dault lora config.
         >>> config = AutoDeltaConfig.from_dict({"delta_type":"lora", "lora_r":5}) # Will load the default lora config, with lora_r = 5

         """
         config_dict = deepcopy(config_dict)

@@ -114,7 +114,7 @@ class AutoDeltaConfig:


     @classmethod
-    def from_finetuned(cls, finetuned_model_name_or_path, **kwargs):
+    def from_finetuned(cls, finetuned_model_path, **kwargs):
         r"""
         Instantiate one of the configuration classes of the library from a finetuned delta model configuration.
         The configuration class to instantiate is selected based on the ``delta_type`` property of the config object that

@@ -122,18 +122,18 @@ class AutoDeltaConfig:

         Parameters:

-            finetuned_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
+            finetuned_model_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
                 Can be either:

                     - A string, the *model id* of a finetuned delta model configuration hosted inside a model repo on
                       huggingface.co. Valid model ids can be located at the root-level, like ``Davin/lora``, or
                       namespaced under a user or organization name, like ``DeltaHub/lora_t5-base_mrpc``.
                     - A path to a *directory* containing a configuration file saved using the
                       :py:meth:`DeltaBase.save_finetuned` method,
                       e.g., ``./my_model_directory/``.
                     - A path or url to a saved configuration JSON *file*, e.g.,
                       ``./my_model_directory/configuration.json``.
                 The last two option are not tested but inherited from huggingface.
             cache_dir (:obj:`str` or :obj:`os.PathLike`, *optional*):
                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
                 standard cache should not be used.

@@ -163,9 +163,9 @@ class AutoDeltaConfig:
             The values in kwargs of any keys which are configuration attributes will be used to override the loaded
             values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
             by the ``return_unused_kwargs`` keyword parameter.

         Examples:

         .. code-block:: python

             from transformers import AutoConfig

@@ -173,25 +173,25 @@ class AutoDeltaConfig:

         """

-        kwargs["name_or_path"] = finetuned_model_name_or_path
+        kwargs["name_or_path"] = finetuned_model_path

-        config_dict, _ = BaseDeltaConfig.get_config_dict(finetuned_model_name_or_path, **kwargs)
+        config_dict, _ = BaseDeltaConfig.get_config_dict(finetuned_model_path, **kwargs)
         if "delta_type" in config_dict:
             config_class = LAZY_CONFIG_MAPPING[config_dict["delta_type"]]
             return config_class.from_dict(config_dict, **kwargs)
         else:
             # Fallback: use pattern matching on the string.
             for pattern, config_class in LAZY_CONFIG_MAPPING.items():
-                if pattern in str(finetuned_model_name_or_path):
+                if pattern in str(finetuned_model):
                     return config_class.from_dict(config_dict, **kwargs)

         raise ValueError(
-            f"Unrecognized model in {finetuned_model_name_or_path}. "
+            f"Unrecognized model in {finetuned_model_path}. "
             f"Should have a `delta_type` key in the loaded config, or contain one of the following strings "
             f"in its name: {', '.join(LAZY_CONFIG_MAPPING.keys())}"
         )

 ### AutoModels below

 class _LazyAutoMapping(OrderedDict):
     """

@@ -323,20 +323,20 @@ class AutoDeltaModel:
             f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
             f"`{self.__class__.__name__}.from_config(config)` methods."
         )

     @classmethod
     def from_config(cls, config, backbone_model, **kwargs): #-> "DeltaBase":
         r"""Automatically instantiates a delta model based on the :obj:`config`. The delta model correspond to the delta
         :obj:`config` will be loaded and initialized using the arguments in :obj:`config`.

         .. note::
             Only using :meth:`from_config` method will not load the finetuned weight file (e.g., pytorch_model.bin).
             Please use from_finetuned directly.

         Args:
             config (:obj:`BaseDeltaConfig`):
             backbone_model (:obj:`nn.Module`):

         Examples:

         .. code-block:: python

@@ -355,31 +355,31 @@ class AutoDeltaModel:
         )

     @classmethod
-    def from_finetuned(cls, finetuned_model_name_or_path, backbone_model, *model_args, **kwargs):
+    def from_finetuned(cls, finetuned_model_path, backbone_model, *model_args, **kwargs):
         r""" Automatically instantiated a delta model and load the finetuned checkpoints based on the
-        :obj:`finetuned_model_name_or_path`, which can either be a string pointing to a local path or a url pointint to
+        :obj:`finetuned_model_path`, which can either be a string pointing to a local path or a url pointint to
         the delta hub. It will check the hash after loading the delta model to see whether the correct backbone and
         delta checkpoint are used.

         Args:
-            finetuned_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
+            finetuned_model_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
                 Can be either:

                     - A string, the *model id* of a finetuned delta model configuration hosted inside a model repo on
                       huggingface.co. Valid model ids can be located at the root-level, like ``Davin/lora``, or
                       namespaced under a user or organization name, like ``DeltaHub/lora_t5-base_mrpc``.
                     - A path to a *directory* containing a configuration file saved using the
                       :py:meth:`DeltaBase.save_finetuned` method,
                       e.g., ``./my_model_directory/``.
                     - A path or url to a saved configuration JSON *file*, e.g.,
                       ``./my_model_directory/configuration.json``.
                 The last two option are not tested but inherited from huggingface.

             backbone_model (:obj:`nn.Module`): The backbone model to be modified.
             model_args: Other argument for initialize the model.

         Example:

         .. code-block:: python

             delta_model = AutoDeltaModel.from_finetuned("DeltaHub/lora_t5-base-mrpc", backbone_model)

@@ -389,19 +389,19 @@ class AutoDeltaModel:

         if not isinstance(config, BaseDeltaConfig):
             config, kwargs = AutoDeltaConfig.from_finetuned(
-                finetuned_model_name_or_path, return_unused_kwargs=True, **kwargs
+                finetuned_model_path, return_unused_kwargs=True, **kwargs
             )
         if type(config) in cls._delta_model_mapping.keys():
             model_class = cls._delta_model_mapping[type(config)]
-            return model_class.from_finetuned(finetuned_model_name_or_path, backbone_model, *model_args, **kwargs)
+            return model_class.from_finetuned(finetuned_model_path, backbone_model, *model_args, **kwargs)
         raise ValueError(
             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
             f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
         )


 if __name__ == "__main__":
@@ -108,7 +108,7 @@ class BaseDeltaConfig(PushToHubMixin):


     @classmethod
-    def from_finetuned(cls, finetuned_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "BaseDeltaConfig":
+    def from_finetuned(cls, finetuned_model_path: Union[str, os.PathLike], **kwargs) -> "BaseDeltaConfig":
         r"""
         Instantiate a :obj:`BaseDeltaConfig` (or a derived class) from a finetined delta module configuration.

@@ -132,7 +132,7 @@ class BaseDeltaConfig(PushToHubMixin):
             delta_config = LoraConfig.from_finetuned("DeltaHub/lora_t5-base_mrpc")

         """
-        config_dict, kwargs = cls.get_config_dict(finetuned_model_name_or_path, **kwargs)
+        config_dict, kwargs = cls.get_config_dict(finetuned_model_path, **kwargs)
         if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
             logger.warn(
                 f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "

@@ -254,24 +254,29 @@ class BaseDeltaConfig(PushToHubMixin):
             local_files_only = True

         pretrained_model_name_or_path = str(pretrained_model_name_or_path)


         if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
             config_file = pretrained_model_name_or_path
         else:
-            configuration_file = get_configuration_file(
-                pretrained_model_name_or_path,
-                revision=revision,
-                use_auth_token=use_auth_token,
-                local_files_only=local_files_only,
-            )
+            # configuration_file = get_configuration_file(
+            #     pretrained_model_name_or_path,
+            #     revision=revision,
+            #     use_auth_token=use_auth_token,
+            #     local_files_only=local_files_only,
+            # )
+            print("cache_dir", cache_dir, "|||", "pretrained_model_name_or_path", pretrained_model_name_or_path)

             if os.path.isdir(pretrained_model_name_or_path):
-                config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
+                config_file = os.path.join(pretrained_model_name_or_path, "config.json")
+            elif os.path.isdir(os.path.join(cache_dir, pretrained_model_name_or_path)):
+                config_file = os.path.join(cache_dir, pretrained_model_name_or_path, "config.json")
             else:
                 config_file = hf_bucket_url(
                     pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
                 )

+        print("config file!!", config_file)
         try:
             # Load from URL or cache if already cached
             resolved_config_file = cached_path(
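The resulting local resolution order for a delta config can be sketched as follows (the names are illustrative, not fixed by the commit):

    # hypothetical walk-through of the branches above
    # 1. an explicit file path or URL is used as-is
    # 2. an existing directory: <path>/config.json
    # 3. otherwise, if <cache_dir>/<name>/ exists: <cache_dir>/<name>/config.json
    # 4. otherwise the name is treated as a hub id and resolved via hf_bucket_url
    config_file = os.path.join("saved_ckpts", "lora_t5-base_mrpc", "config.json")  # example of case 3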
@@ -11,6 +11,8 @@ from opendelta import BaseDeltaConfig
 import opendelta.utils.logging as logging
 import numpy as np
 from opendelta import global_setting
+from dataclasses import dataclass, field

 logger = logging.get_logger(__name__)


@@ -139,7 +141,7 @@ class AdapterConfig(BaseDeltaConfig):
         self,
         bottleneck_dim: Optional[int]=24,
         non_linearity: Optional[str]='gelu_new',
-        sequential: Optional[str] = True,
+        sequential: Optional[bool] = True,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -3,10 +3,10 @@ from typing import Optional, Union
 from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
 from opendelta.utils.name_based_addressing import *
 from opendelta.basemodel import DeltaBase
-from transformers.models.t5 import T5ForConditionalGeneration
 import torch.nn as nn
 from opendelta import BaseDeltaConfig
 import math
+from dataclasses import dataclass, field

 class LowRankLinear(nn.Module):
     # ------------------------------------------------------------------------------------------

@@ -40,6 +40,11 @@ class LowRankLinear(nn.Module):
     def forward(self, x):
         return (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling

+@dataclass
+class LoraArguments:
+    r: int = 8
+    lora_alpha: int = 16
+    lora_dropout: float = 0.0

 class LoraConfig(BaseDeltaConfig):
     r"""
@@ -71,6 +71,10 @@ class DeltaCenterArguments:
         default=None,
         metadata={"help": "the performance of the model on the test set"}
     )
+    test_metrics: Optional[str] = field(
+        default=None,
+        metadata={"help": "the metrics used by the model"}
+    )
     trainable_ratio: Optional[float] = field(
         default=None,
         metadata={"help": "the ratio of trainable parameters in the model"}

@@ -93,15 +97,15 @@ class SaveLoadMixin(PushToHubMixin):

     def save_finetuned(
         self,
-        save_directory: Optional[Union[str, os.PathLike]] = "./output/",
+        finetuned_delta_path: Optional[Union[str, os.PathLike]] = "./delta_checkpoints/",
         save_config: bool = True,
         state_dict: Optional[dict] = None,
         save_function: Callable = torch.save,
         push_to_dc: bool = True,
         center_args: Optional[Union[DeltaCenterArguments, dict]] = None,
         center_args_pool: Optional[dict] = None,
-        center_value_only_tags: Optional[List] = None,
-        center_key_value_tags: Optional[Dict] = None,
+        list_tags: Optional[List] = None,
+        dict_tags: Optional[Dict] = None,
         delay_push: bool = False,
     ):
         r"""

@@ -129,9 +133,9 @@ class SaveLoadMixin(PushToHubMixin):
             center_args_pool (:obj:`dict`, *optional*, defaults to :obj:`None`): The arguments's pool for DeltaCenter
                 Together with center_args, they are are used to distinguish between different delta models on the DeltaCenter.
                 It will be used to group delta models.
-            center_value_only_tags (:obj:`List`, *optional*, defaults to :obj:`None`): The tags in the form of list for the delta model, it is the
+            list_tags (:obj:`List`, *optional*, defaults to :obj:`None`): The tags in the form of list for the delta model, it is the
                 optional identifiers that are not expected by `DeltaCenterArgument`. It will not be used to group delta models in the delta center
-            center_key_value_tags (:obj:`Dict`, *optional*, defaults to :obj:`None`): The tags in the form of dictionary for the delta model, it is the
+            dict_tags (:obj:`Dict`, *optional*, defaults to :obj:`None`): The tags in the form of dictionary for the delta model, it is the
                 optional identifiers that are not expected by `DeltaCenterArgument`. It will not be used to group delta models in the delta center.
             delay_push (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether or not to delay the push to the DeltaCenter. When set to True,
                 the delta object will be saved locally to save_directory, you can push it later using

@@ -142,6 +146,7 @@ class SaveLoadMixin(PushToHubMixin):


         """
+        save_directory = finetuned_delta_path
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return

@@ -169,26 +174,36 @@ class SaveLoadMixin(PushToHubMixin):
         final_center_args = self.create_delta_center_args(center_args=center_args,
                                     center_args_pool=center_args_pool)

-        print("final_center_args", final_center_args)

         if push_to_dc:
-            self.create_yml(save_directory, final_center_args, center_value_only_tags, center_key_value_tags)
+            self.create_yml(save_directory, final_center_args, list_tags, dict_tags)

             if not delay_push:
                 OssClient.upload(base_dir=save_directory)
+            else:
+                logger.info("\n"+"*"*30+f"\nYou delta models has been saved locally to:\n\t\t{os.path.abspath(save_directory)}\
+                    \nyou can push it to the delta center later using \n\t\tpython -m DeltaCenter upload {os.path.abspath(save_directory)}\n"
+                    +"*"*30)

+        # get absolute path of saved_directory,

     def create_yml(self, save_dir, config, list_tags=None, dict_tags=None):
         f = open("{}/config.yml".format(save_dir), 'w')
-        yaml.safe_dump(vars(config), f)
+        config_dict = vars(config)
+        config_dict['dict_tags'] = dict_tags if dict_tags is not None else {}
+        config_dict['list_tags'] = list_tags if list_tags is not None else []
+        yaml.safe_dump(config_dict, f)
         f.close()

     @classmethod
     def from_finetuned(cls,
-                       finetuned_model_name_or_path: Optional[Union[str, os.PathLike]],
+                       finetuned_delta_path: Optional[Union[str, os.PathLike]],
                        backbone_model: nn.Module,
-                       *model_args,
-                       check_hash: Optional[bool] = True,
-                       **kwargs):
+                       delta_config = None,
+                       cache_dir: Optional[Union[str, os.PathLike]] = None,
+                       *model_args,
+                       check_hash: Optional[bool] = True,
+                       **kwargs):
         r"""
         Instantiate a finetuned delta model from a path.
         The backbone_model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
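Put together, a minimal call against the new save signature looks roughly like this (values are illustrative); list_tags and dict_tags also end up as keys in the config.yml written by create_yml:

    # hypothetical sketch of the new save path
    delta_model.save_finetuned(
        finetuned_delta_path="./delta_checkpoints/my_adapter",   # assumed output directory
        push_to_dc=True,
        center_args={"test_performance": 0.9},                   # illustrative value
        center_args_pool={**vars(model_args), **vars(data_args), **vars(training_args), **vars(delta_args)},
        list_tags=["NLI"],
        dict_tags={"purpose": "for testing"},
        delay_push=True)                                         # save locally now, push later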
@@ -196,7 +211,7 @@ class SaveLoadMixin(PushToHubMixin):

         Parameters:

-            finetuned_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
+            finetuned_model_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
                 Can be either:

                     - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.

@@ -297,9 +312,9 @@ class SaveLoadMixin(PushToHubMixin):


         """
-        config = kwargs.pop("config", None)
+        # config = kwargs.pop("config", None)
         state_dict = kwargs.pop("state_dict", None)
-        cache_dir = kwargs.pop("cache_dir", None)
+        # cache_dir = kwargs.pop("cache_dir", None)

         # ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
         force_download = kwargs.pop("force_download", False)

@@ -323,10 +338,10 @@ class SaveLoadMixin(PushToHubMixin):
             local_files_only = True

         # Load config if we don't provide a configuration
-        if not isinstance(config, BaseDeltaConfig):
-            config_path = config if config is not None else finetuned_model_name_or_path
-            config, model_kwargs = cls.config_class.from_finetuned(
-                config_path,
+        if not isinstance(delta_config, BaseDeltaConfig):
+            # config_path = delta_config if delta_config is not None else finetuned_model_path # Todo check
+            delta_config, model_kwargs = cls.config_class.from_finetuned(
+                finetuned_model_path,
                 cache_dir=cache_dir,
                 return_unused_kwargs=True,
                 force_download=force_download,

@@ -343,23 +358,24 @@ class SaveLoadMixin(PushToHubMixin):
         else:
             model_kwargs = kwargs

+        print("delta_config", delta_config)
         # Load model
-        if finetuned_model_name_or_path is not None:
-            finetuned_model_name_or_path = str(finetuned_model_name_or_path)
-            if os.path.isdir(finetuned_model_name_or_path):
-                if os.path.isfile(os.path.join(finetuned_model_name_or_path, WEIGHTS_NAME)):
+        if finetuned_model_path is not None:
+            finetuned_model_path = str(finetuned_model_path)
+            if os.path.isdir(finetuned_model_path):
+                if os.path.isfile(os.path.join(finetuned_model_path, WEIGHTS_NAME)):
                     # Load from a PyTorch checkpoint
-                    archive_file = os.path.join(finetuned_model_name_or_path, WEIGHTS_NAME)
+                    archive_file = os.path.join(finetuned_model_path, WEIGHTS_NAME)
                 else:
                     raise EnvironmentError(
                         f"Error no file named {WEIGHTS_NAME} found in "
-                        f"directory {finetuned_model_name_or_path}."
+                        f"directory {finetuned_model_path}."
                     )
-            elif os.path.isfile(finetuned_model_name_or_path) or is_remote_url(finetuned_model_name_or_path):
-                archive_file = finetuned_model_name_or_path
+            elif os.path.isfile(finetuned_model_path) or is_remote_url(finetuned_model_path):
+                archive_file = finetuned_model_path
             else:
                 archive_file = hf_bucket_url(
-                    finetuned_model_name_or_path,
+                    finetuned_model_path,
                     filename=WEIGHTS_NAME,
                     revision=revision,
                     mirror=mirror,

@@ -381,7 +397,7 @@ class SaveLoadMixin(PushToHubMixin):
             except EnvironmentError as err:
                 logger.error(err)
                 msg = (
-                    f"Can't load weights for '{finetuned_model_name_or_path}'. Make sure that:\n\n"
+                    f"Can't load weights for '{finetuned_model_path}'. Make sure that:\n\n"
                 )

                 if revision is not None:

@@ -414,7 +430,7 @@ class SaveLoadMixin(PushToHubMixin):
                 raise ValueError from e
             except (UnicodeDecodeError, ValueError):
                 raise OSError(
-                    f"Unable to load weights from pytorch checkpoint file for '{finetuned_model_name_or_path}' "
+                    f"Unable to load weights from pytorch checkpoint file for '{finetuned_model_path}' "
                     f"at '{resolved_archive_file}'. "
                     "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
                 )

@@ -469,7 +485,6 @@ class SaveLoadMixin(PushToHubMixin):
         mdict = {}
         field = fields(DeltaCenterArguments)

-        print("center_args_pool", center_args_pool)

         for f in field:
             exist = False
@@ -4,10 +4,10 @@ from collections import namedtuple
 def signature(f):
     r"""Get the function f 's input arguments. A useful gadget
     when some function slot might be instantiated into multiple functions.

     Args:
         f (:obj:`function`) : the function to get the input arguments.

     Returns:
         namedtuple : of args, default, varargs, keywords, respectively.s

@@ -34,7 +34,7 @@ def signature(f):
     ] or None
     argspec = namedtuple('Signature', ['args', 'defaults',
                          'varargs', 'keywords'])
     return argspec(args, defaults, varargs, keywords)

 def get_arg_names(f):
     r""" Get a functions argument name, remove the ``self`` argument

@@ -45,6 +45,7 @@ def get_arg_names(f):
     return args


+
 def get_arg_names_inside_func(func):
     r""" Get the functions argument name inside the function itself. Remove ``self`` argument.
     """