update_from_finetuned

This commit is contained in:
shengdinghu 2022-07-01 22:23:02 +08:00
parent 7d256ce288
commit 617955e08e
12 changed files with 294 additions and 166 deletions

.gitignore vendored
View File

@ -37,4 +37,5 @@ log.txt
**/Delta_Memory/
**/output/
**/thunlp/
**/saved_ckpts/

View File

@ -343,7 +343,7 @@ class SuperGLUECB(AbstractTask):
"validation": "validation",
"test": "validation"}
metric = [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]
metric_names = ["f1_multiclass", "accuracy"]
metric_names = ["accuracy"]
verbalizers = {
"0":{"0": "yes",

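For reference, the metrics.mean_multiclass_f1(num_classes=3) entry kept in the metric list above macro-averages F1 over the three CB classes. A minimal sketch of such a helper, assuming an sklearn-backed implementation (the repository's actual helper may differ in details):

from sklearn.metrics import fbeta_score

def mean_multiclass_f1(num_classes, **metric_fn_kwargs):
    # Build a metric fn that macro-averages F1 over `num_classes` labels
    # and reports it on a 0-100 scale, matching the other metrics here.
    def metric_fn(targets, predictions):
        return {"f1_multiclass": 100 * fbeta_score(
            targets, predictions, beta=1,
            labels=range(num_classes), average="macro", **metric_fn_kwargs)}
    return metric_fn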
View File

@ -47,20 +47,20 @@ def spearman_corrcoef(predictions, targets) -> dict:
def spearman_corrcoef(predictions, targets) -> dict:
"""Computes Spearman correlation coefficient."""
# TODO: we need to do postprocessors in a clean way for each dataset.
from examples_seq2seq.data_processors.postprocessors import string_to_float
targets = [string_to_float(target) for target in targets]
predictions= [string_to_float(prediction) for prediction in predictions]
spearman_corrcoef = 100 * scipy.stats.spearmanr(targets, predictions)[0]
# def spearman_corrcoef(predictions, targets) -> dict:
# """Computes Spearman correlation coefficient."""
# # TODO: we need to do postprocessors in a clean way for each dataset.
# from examples_seq2seq.data_processors.postprocessors import string_to_float
# targets = [string_to_float(target) for target in targets]
# predictions= [string_to_float(prediction) for prediction in predictions]
# spearman_corrcoef = 100 * scipy.stats.spearmanr(targets, predictions)[0]
# Note that if all the predictions will be the same, spearman
# correlation is nan, to gaurad against this, we check the output
# and return 0 in this case.
if math.isnan(spearman_corrcoef):
spearman_corrcoef = 0
return {"spearmanr": spearman_corrcoef}
# # Note that if all the predictions will be the same, spearman
# # correlation is nan, to gaurad against this, we check the output
# # and return 0 in this case.
# if math.isnan(spearman_corrcoef):
# spearman_corrcoef = 0
# return {"spearmanr": spearman_corrcoef}
def f1_score_with_invalid(predictions, targets) -> dict:

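The hunk above comments out the copy of spearman_corrcoef that first converts targets and predictions to floats. For reference, a minimal guarded Spearman metric of the kind shown here, assuming scipy is available:

import math
import scipy.stats

def spearman_corrcoef(predictions, targets) -> dict:
    # Spearman correlation on a 0-100 scale; constant predictions give NaN,
    # which is mapped to 0 so downstream averaging still works.
    corr = 100 * scipy.stats.spearmanr(targets, predictions)[0]
    if math.isnan(corr):
        corr = 0
    return {"spearmanr": corr}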
View File

@ -26,6 +26,7 @@ You can also adapt this script on your own tasks.
import os
import sys
os.environ['MKL_THREADING_LAYER'] = 'GNU'
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@ -56,7 +57,7 @@ from transformers.trainer_utils import is_main_process, get_last_checkpoint
from data_processors import AutoTask #, #TaskDataCollatorForSeq2Seq, AutoPostProcessor, data_collator
from utils import read_json, save_json
from utils.args import ModelArguments, TrainingArguments, DataTrainingArguments, RemainArgHfArgumentParser
from utils.args import ModelArguments, TrainingArguments, DataTrainingArguments, DeltaArguments, RemainArgHfArgumentParser
logger = logging.getLogger(__name__)
@ -66,16 +67,14 @@ def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = RemainArgHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args, delta_args, remain_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args, delta_args, remain_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
parser = RemainArgHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, DeltaArguments))
# You can provide a json file that contains the arguments and use --argument some_arg to override or append to the json file.
json_file, cmd_args = (os.path.abspath(sys.argv[1]), sys.argv[2:]) if sys.argv[1].endswith(".json") else (None, sys.argv[1:])
model_args, data_args, training_args, delta_args, remain_args = parser.parse_json_file_with_cmd_args(json_file=json_file, command_line_args=cmd_args)
logger.warning("The following arguments were not used: {}".format(remain_args))
print(f"{training_args.output_dir}/results.json")
logger.info(f"The results will be used in {training_args.output_dir}/results.json")
# exit()
# Detecting last checkpoint.
last_checkpoint = None
@ -161,7 +160,8 @@ def main():
if delta_args.delta_type.lower() != "none":
from opendelta import AutoDeltaConfig,AutoDeltaModel
delta_config = AutoDeltaConfig.from_dict(vars(delta_args))
from dataclasses import asdict
delta_config = AutoDeltaConfig.from_dict(asdict(delta_args))
delta_model = AutoDeltaModel.from_config(delta_config, backbone_model=model)
delta_model.freeze_module(set_state_dict = True)
delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=True)
@ -278,14 +278,9 @@ def main():
if torch.cuda.is_available() and training_args.compute_memory:
peak_memory = (torch.cuda.max_memory_allocated() / 1024 ** 2)/1000
print(
"Memory utilization",
peak_memory,
"GB"
)
performance_metrics.update({"peak_memory": peak_memory})
if training_args.compute_memory or training_args.compute_time:
print("Efficiency Statistics {}".format(performance_metrics))
logger.info("Efficiency Statistics {}".format(performance_metrics))
trainer.save_metrics("performance", performance_metrics)
# Evaluation
@ -326,10 +321,13 @@ def main():
# all_results['repo_name'] = repo_name
delta_model.save_finetuned(push_to_hf=training_args.push_to_hf,
delta_model.save_finetuned(finetuned_delta_path=delta_args.finetuned_delta_path,
push_to_dc=training_args.push_to_dc,
center_args={},
center_args={"test_performance":all_results['test'][data_args.task_name]['test_average_metrics'],
},
center_args_pool = {**vars(model_args), **vars(data_args), **vars(training_args), **vars(delta_args)},
list_tags = ['NLI'],
dict_tags = {'purpose':'for testing'},
delay_push=True,
)

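After the new parser returns the DeltaArguments dataclass, the training script turns it into a plain dict and builds the delta model from it before freezing the backbone. A condensed sketch of that flow, using the same opendelta API names as the hunk (the attach_delta wrapper is illustrative, not part of the script):

from dataclasses import asdict
from opendelta import AutoDeltaConfig, AutoDeltaModel

def attach_delta(model, delta_args):
    # Build a delta config from the parsed DeltaArguments and wrap the
    # backbone; only the delta parameters remain trainable afterwards.
    delta_config = AutoDeltaConfig.from_dict(asdict(delta_args))
    delta_model = AutoDeltaModel.from_config(delta_config, backbone_model=model)
    delta_model.freeze_module(set_state_dict=True)
    delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=True)
    return delta_model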
View File

@ -63,40 +63,32 @@ logger = logging.getLogger(__name__)
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = RemainArgHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, DeltaArguments))
print("here1")
if sys.argv[1].endswith(".json"):
print("here2", sys.argv[2:])
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args, delta_args, remain_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]), command_line_args=sys.argv[2:])
else:
model_args, data_args, training_args, delta_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
# You can provide a json file that contains the arguments and use --argument some_arg to override or append to the json file.
json_file, cmd_args = (os.path.abspath(sys.argv[1]), sys.argv[2:]) if sys.argv[1].endswith(".json") else (None, sys.argv[1:])
model_args, data_args, training_args, delta_args, remain_args = parser.parse_json_file_with_cmd_args(json_file=json_file, command_line_args=cmd_args)
logger.warning("The following arguments were not used: {}".format(remain_args))
print(f"{training_args.output_dir}/results.json")
# exit()
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
print("#### last_checkpoint ", last_checkpoint)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
'''
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
'''
pass
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# # exit()
# # Detecting last checkpoint.
# last_checkpoint = None
# if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
# last_checkpoint = get_last_checkpoint(training_args.output_dir)
# print("#### last_checkpoint ", last_checkpoint)
# if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
# '''
# raise ValueError(
# f"Output directory ({training_args.output_dir}) already exists and is not empty. "
# "Use --overwrite_output_dir to overcome."
# )
# '''
# pass
# elif last_checkpoint is not None:
# logger.info(
# f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
# "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
# )
# Setup logging
logging.basicConfig(
@ -149,8 +141,6 @@ def main():
config, tokenizer, model = get_backbone(model_args=model_args)
# model parallelize
@ -164,14 +154,11 @@ def main():
if delta_args.delta_type.lower() != "none":
from opendelta import AutoDeltaConfig,AutoDeltaModel
# delta_config = AutoDeltaConfig.from_dict(vars(delta_args))
delta_model = AutoDeltaModel.from_finetuned(delta_args.finetuned_model_path, backbone_model=model)
delta_model = AutoDeltaModel.from_finetuned(finetuned_model_path=delta_args.finetuned_model_path, cache_dir="saved_ckpts", backbone_model=model)
# delta_model.freeze_module(set_state_dict = True)
delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=True)
performance_metrics = {}

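The evaluation script now reloads delta weights straight from a finetuned checkpoint instead of rebuilding them from a config. A minimal sketch of that loading step, reusing the names from the hunk (model and delta_args are assumed to come from get_backbone and the parser, as in the script):

from opendelta import AutoDeltaModel

# `model` is the backbone returned by get_backbone(model_args=model_args).
delta_model = AutoDeltaModel.from_finetuned(
    finetuned_model_path=delta_args.finetuned_model_path,
    cache_dir="saved_ckpts",   # matches the **/saved_ckpts/ entry added to .gitignore
    backbone_model=model,
)
delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=True)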
View File

@ -1,7 +1,10 @@
from calendar import c
from dataclasses import dataclass, field
from typing import Optional, List
from transformers import HfArgumentParser
from pathlib import Path
import sys
@dataclass
class ModelArguments:
@ -215,6 +218,10 @@ class DataTrainingArguments:
if self.test_max_target_length is None:
self.test_max_target_length = self.max_target_length
import dataclasses
@dataclass
class DeltaArguments:
"""
@ -233,47 +240,152 @@ class DeltaArguments:
unfrozen_modules: Optional[List[str]] = field(
default_factory=lambda:["deltas"], metadata={"help": "the modules inside the backbone or in the delta modules that need to be unfrozen"}
)
# Delta-type-specific arguments
# Adapter:
bottleneck_dim: Optional[int] = field(
default=24, metadata={"help": "the dimension of the bottleneck layer"}
)
finetuned_model_path: Optional[str] = field(
finetuned_delta_path: Optional[str] = field(
default=None, metadata={"help": "the path of the finetuned delta model"}
)
def merge_arguments(self, objb):
print(objb)
self.__class__ = dataclasses.make_dataclass('DeltaArgument', fields=[(s.name, s.type, getattr(objb, s.name)) for s in dataclasses.fields(objb)], bases=(DeltaArguments,))
# ['--backbone_model', 't5', '--bottleneck_dim', '24', '--delta_type', 'adapter', '--model_path_public', 't5-base', '--unfrozen_modules', "['deltas',", "'layer_norm',", "'final_layer_norm']"]
@dataclass
class AdapterArguments:
bottleneck_dim: Optional[int] = field(
default=24, metadata={"help": "the dimension of the bottleneck layer"}
)
DELTAARGMAP = {
"adapter": AdapterArguments
}
# TODO: add more specific delta arguments
class RemainArgHfArgumentParser(HfArgumentParser):
def parse_json_file(self, json_file: str, command_line_args=None, return_remaining_args=True ):
'''This is a more powerful version of the argument parser.
It can receive both command line arguments and json file arguments.
The command line arguments will override the json file arguments.
The parser will load the specific delta arguments (e.g. Adapter's)
according to the delta_type argument and merge them with the
common delta arguments.
'''
def parse_json_file_with_cmd_args(self, json_file: str, command_line_args=None, return_remaining_args=True ):
"""
Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
dataclass types.
"""
import argparse
import json
from pathlib import Path
import dataclasses
print("Here", command_line_args)
data = json.loads(Path(json_file).read_text())
data_str = ""
if command_line_args is None:
command_line_args = []
for key in data:
if "--"+key not in command_line_args:
data_str+= "--" + key + " " + str(data[key]) + " "
if isinstance(data[key], list):
data_str += "--"+key
for elem in data[key]:
data_str+=" "+ str(elem)
data_str += " "
else:
data_str+= "--" + key + " " + str(data[key]) + " "
data_list = data_str.split()
data_list += command_line_args
return self.parse_args_into_dataclasses(args=data_list, return_remaining_strings=return_remaining_args)
if return_remaining_args:
outputs, remain_args = self.parse_args_into_dataclasses(args=data_list, return_remaining_strings=return_remaining_args)
for d in outputs:
if isinstance(d, DeltaArguments): # merge the specific delta arguments
d.merge_arguments(outputs[-1])
return *(outputs[:-1]), remain_args
else:
outputs = self.parse_args_into_dataclasses(args=data_list, return_remaining_strings=return_remaining_args)
for d in outputs:
if isinstance(d, DeltaArguments):
d.merge_arguments(outputs[-1])
return (*(outputs[:-1]),)
def parse_args_into_dataclasses(
self, args=None, return_remaining_strings=False, look_for_args_file=True, args_filename=None
):
"""
Parse command-line args into instances of the specified dataclass types.
This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args
Args:
args:
List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
return_remaining_strings:
If true, also return a list of remaining argument strings.
look_for_args_file:
If true, will look for a ".args" file with the same base name as the entry point script for this
process, and will append its potential content to the command line args.
args_filename:
If not None, will use this file instead of the ".args" file specified in the previous argument.
Returns:
Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
- if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
after initialization.
- The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
"""
if args_filename or (look_for_args_file and len(sys.argv)):
if args_filename:
args_file = Path(args_filename)
else:
args_file = Path(sys.argv[0]).with_suffix(".args")
if args_file.exists():
fargs = args_file.read_text().split()
args = fargs + args if args is not None else fargs + sys.argv[1:]
# in case of duplicate arguments the first one has precedence
# so we append rather than prepend.
namespace, remaining_args = self.parse_known_args(args=args)
# conditionally add delta arguments
deltatype_args = DELTAARGMAP[namespace.delta_type]
self.dataclass_types.append(deltatype_args)
self._add_dataclass_arguments(deltatype_args)
# parse the arguments again, this time with the specific delta type's arguments
namespace, remaining_args = self.parse_known_args(args=args)
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in vars(namespace).items() if k in keys}
for k in keys:
delattr(namespace, k)
obj = dtype(**inputs)
outputs.append(obj)
if len(namespace.__dict__) > 0:
# additional namespace.
outputs.append(namespace)
if return_remaining_strings:
return (outputs, remaining_args)
else:
if remaining_args:
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
return outputs
# namespace, remaining_args = self.parse_known_args(args=data_list)
@ -292,4 +404,6 @@ class RemainArgHfArgumentParser(HfArgumentParser):
# if return_remaining_args:
# return (*outputs, remain_args)
# else:
# return (*outputs,)
# return (*outputs,)

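The parser's key trick is a two-pass parse: parse_known_args first reads the common arguments to discover delta_type, then the matching type-specific dataclass (e.g. AdapterArguments from DELTAARGMAP) is registered and the same argument list is parsed again. A self-contained sketch of the same idea with plain argparse, independent of HfArgumentParser (the flag table below is illustrative):

import argparse

DELTA_SPECIFIC = {
    # delta_type -> extra flags that type understands (mirrors DELTAARGMAP)
    "adapter": [("--bottleneck_dim", int, 24)],
    "lora":    [("--lora_r", int, 8), ("--lora_alpha", int, 16)],
}

def parse_with_delta_type(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--delta_type", default="none")
    # First pass: only to find out which delta type was requested.
    known, _ = parser.parse_known_args(argv)
    # Second pass: register that type's flags, then parse everything strictly.
    for flag, ftype, default in DELTA_SPECIFIC.get(known.delta_type, []):
        parser.add_argument(flag, type=ftype, default=default)
    return parser.parse_args(argv)

# parse_with_delta_type(["--delta_type", "adapter", "--bottleneck_dim", "32"])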
View File

@ -11,7 +11,7 @@ logger = get_logger(__name__)
DELTA_CONFIG_MAPPING = {
"lora": "LoraConfig",
"lora": "LoraConfig",
"low_rank_adapter": "LowRankAdapterConfig",
"bitfit": "BitFitConfig",
"adapter":"AdapterConfig",
@ -91,18 +91,18 @@ class AutoDeltaConfig:
"AutoConfig is designed to be instantiated "
"using the ``AutoConfig.from_pretrained(pretrained_model_name_or_path)`` method."
)
@classmethod
def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
r""" Instantiate a DeltaConfig according to the dict. Automatically load the config specified by
r""" Instantiate a DeltaConfig according to the dict. Automatically load the config specified by
:obj:`delta_type`.
Args:
config_dict (:obj:`dict`): The dict of configs of delta model.
kwargs: Other keyword argument pass to initialize the config.
kwargs: Other keyword arguments passed to initialize the config.
>>> config = AutoDeltaConfig.from_dict({"delta_type":"lora"}) # This will load the default lora config.
>>> config = AutoDeltaConfig.from_dict({"delta_type":"lora", "lora_r":5}) # Will load the default lora config, with lora_r = 5
>>> config = AutoDeltaConfig.from_dict({"delta_type":"lora", "lora_r":5}) # Will load the default lora config, with lora_r = 5
"""
config_dict = deepcopy(config_dict)
@ -114,7 +114,7 @@ class AutoDeltaConfig:
@classmethod
def from_finetuned(cls, finetuned_model_name_or_path, **kwargs):
def from_finetuned(cls, finetuned_model_path, **kwargs):
r"""
Instantiate one of the configuration classes of the library from a finetuned delta model configuration.
The configuration class to instantiate is selected based on the ``delta_type`` property of the config object that
@ -122,18 +122,18 @@ class AutoDeltaConfig:
Parameters:
finetuned_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
finetuned_model_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
Can be either:
- A string, the *model id* of a finetuned delta model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like ``Davin/lora``, or
namespaced under a user or organization name, like ``DeltaHub/lora_t5-base_mrpc``.
- A path to a *directory* containing a configuration file saved using the
:py:meth:`DeltaBase.save_finetuned` method,
:py:meth:`DeltaBase.save_finetuned` method,
e.g., ``./my_model_directory/``.
- A path or url to a saved configuration JSON *file*, e.g.,
``./my_model_directory/configuration.json``.
The last two option are not tested but inherited from huggingface.
The last two options are not tested but inherited from huggingface.
cache_dir (:obj:`str` or :obj:`os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
@ -163,9 +163,9 @@ class AutoDeltaConfig:
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
by the ``return_unused_kwargs`` keyword parameter.
Examples:
.. code-block:: python
from transformers import AutoConfig
@ -173,25 +173,25 @@ class AutoDeltaConfig:
"""
kwargs["name_or_path"] = finetuned_model_name_or_path
kwargs["name_or_path"] = finetuned_model_path
config_dict, _ = BaseDeltaConfig.get_config_dict(finetuned_model_name_or_path, **kwargs)
config_dict, _ = BaseDeltaConfig.get_config_dict(finetuned_model_path, **kwargs)
if "delta_type" in config_dict:
config_class = LAZY_CONFIG_MAPPING[config_dict["delta_type"]]
return config_class.from_dict(config_dict, **kwargs)
else:
# Fallback: use pattern matching on the string.
for pattern, config_class in LAZY_CONFIG_MAPPING.items():
if pattern in str(finetuned_model_name_or_path):
if pattern in str(finetuned_model_path):
return config_class.from_dict(config_dict, **kwargs)
raise ValueError(
f"Unrecognized model in {finetuned_model_name_or_path}. "
f"Unrecognized model in {finetuned_model_path}. "
f"Should have a `delta_type` key in the loaded config, or contain one of the following strings "
f"in its name: {', '.join(LAZY_CONFIG_MAPPING.keys())}"
)
### AutoModels below
### AutoModels below
class _LazyAutoMapping(OrderedDict):
"""
@ -323,20 +323,20 @@ class AutoDeltaModel:
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
f"`{self.__class__.__name__}.from_config(config)` methods."
)
@classmethod
def from_config(cls, config, backbone_model, **kwargs): #-> "DeltaBase":
r"""Automatically instantiates a delta model based on the :obj:`config`. The delta model correspond to the delta
:obj:`config` will be loaded and initialized using the arguments in :obj:`config`.
:obj:`config` will be loaded and initialized using the arguments in :obj:`config`.
.. note::
Only using :meth:`from_config` method will not load the finetuned weight file (e.g., pytorch_model.bin).
Please use from_finetuned directly.
Using the :meth:`from_config` method alone will not load the finetuned weight file (e.g., pytorch_model.bin).
Please use from_finetuned directly.
Args:
config (:obj:`BaseDeltaConfig`):
backbone_model (:obj:`nn.Module`):
Examples:
.. code-block:: python
@ -355,31 +355,31 @@ class AutoDeltaModel:
)
@classmethod
def from_finetuned(cls, finetuned_model_name_or_path, backbone_model, *model_args, **kwargs):
r""" Automatically instantiated a delta model and load the finetuned checkpoints based on the
:obj:`finetuned_model_name_or_path`, which can either be a string pointing to a local path or a url pointint to
the delta hub. It will check the hash after loading the delta model to see whether the correct backbone and
delta checkpoint are used.
def from_finetuned(cls, finetuned_model_path, backbone_model, *model_args, **kwargs):
r""" Automatically instantiated a delta model and load the finetuned checkpoints based on the
:obj:`finetuned_model_path`, which can either be a string pointing to a local path or a url pointint to
the delta hub. It will check the hash after loading the delta model to see whether the correct backbone and
delta checkpoint are used.
Args:
finetuned_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
finetuned_model_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
Can be either:
- A string, the *model id* of a finetuned delta model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like ``Davin/lora``, or
namespaced under a user or organization name, like ``DeltaHub/lora_t5-base_mrpc``.
- A path to a *directory* containing a configuration file saved using the
:py:meth:`DeltaBase.save_finetuned` method,
:py:meth:`DeltaBase.save_finetuned` method,
e.g., ``./my_model_directory/``.
- A path or url to a saved configuration JSON *file*, e.g.,
``./my_model_directory/configuration.json``.
The last two option are not tested but inherited from huggingface.
The last two options are not tested but inherited from huggingface.
backbone_model (:obj:`nn.Module`): The backbone model to be modified.
model_args: Other arguments for initializing the model.
Example:
.. code-block:: python
delta_model = AutoDeltaModel.from_finetuned("DeltaHub/lora_t5-base-mrpc", backbone_model)
@ -389,19 +389,19 @@ class AutoDeltaModel:
if not isinstance(config, BaseDeltaConfig):
config, kwargs = AutoDeltaConfig.from_finetuned(
finetuned_model_name_or_path, return_unused_kwargs=True, **kwargs
finetuned_model_path, return_unused_kwargs=True, **kwargs
)
if type(config) in cls._delta_model_mapping.keys():
model_class = cls._delta_model_mapping[type(config)]
return model_class.from_finetuned(finetuned_model_name_or_path, backbone_model, *model_args, **kwargs)
return model_class.from_finetuned(finetuned_model_path, backbone_model, *model_args, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
)
if __name__ == "__main__":

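AutoDeltaConfig resolves the concrete config class from the delta_type key and, when that key is absent, falls back to substring matching on the checkpoint name. A stripped-down sketch of that dispatch (LAZY_CONFIG_MAPPING here is a plain dict of names standing in for the library's lazy mapping):

LAZY_CONFIG_MAPPING = {"lora": "LoraConfig", "adapter": "AdapterConfig",
                       "bitfit": "BitFitConfig"}  # names only, for illustration

def resolve_config_class(config_dict, finetuned_model_path):
    # Prefer the explicit delta_type key, as AutoDeltaConfig.from_finetuned does.
    if "delta_type" in config_dict:
        return LAZY_CONFIG_MAPPING[config_dict["delta_type"]]
    # Fallback: substring match on the checkpoint path or model id.
    for pattern, config_class in LAZY_CONFIG_MAPPING.items():
        if pattern in str(finetuned_model_path):
            return config_class
    raise ValueError(f"Unrecognized model in {finetuned_model_path}.")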
View File

@ -108,7 +108,7 @@ class BaseDeltaConfig(PushToHubMixin):
@classmethod
def from_finetuned(cls, finetuned_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "BaseDeltaConfig":
def from_finetuned(cls, finetuned_model_path: Union[str, os.PathLike], **kwargs) -> "BaseDeltaConfig":
r"""
Instantiate a :obj:`BaseDeltaConfig` (or a derived class) from a finetuned delta module configuration.
@ -132,7 +132,7 @@ class BaseDeltaConfig(PushToHubMixin):
delta_config = LoraConfig.from_finetuned("DeltaHub/lora_t5-base_mrpc")
"""
config_dict, kwargs = cls.get_config_dict(finetuned_model_name_or_path, **kwargs)
config_dict, kwargs = cls.get_config_dict(finetuned_model_path, **kwargs)
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warn(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
@ -254,24 +254,29 @@ class BaseDeltaConfig(PushToHubMixin):
local_files_only = True
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
else:
configuration_file = get_configuration_file(
pretrained_model_name_or_path,
revision=revision,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
)
# configuration_file = get_configuration_file(
# pretrained_model_name_or_path,
# revision=revision,
# use_auth_token=use_auth_token,
# local_files_only=local_files_only,
# )
print("cache_dir", cache_dir, "|||", "pretrained_model_name_or_path", pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
config_file = os.path.join(pretrained_model_name_or_path, "config.json")
elif os.path.isdir(os.path.join(cache_dir, pretrained_model_name_or_path)):
config_file = os.path.join(cache_dir, pretrained_model_name_or_path, "config.json")
else:
config_file = hf_bucket_url(
pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
)
print("config file!!", config_file)
try:
# Load from URL or cache if already cached
resolved_config_file = cached_path(

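The new lookup order tries the checkpoint directory itself, then the same path under cache_dir (where from_finetuned now caches downloads), and only then falls back to a hub URL. A simplified sketch of that resolution, with hf_bucket_url passed in to stand for the HuggingFace helper used above:

import os

def resolve_config_file(pretrained_model_name_or_path, cache_dir, hf_bucket_url):
    # 1) a local checkpoint directory containing config.json
    if os.path.isdir(pretrained_model_name_or_path):
        return os.path.join(pretrained_model_name_or_path, "config.json")
    # 2) a previously cached copy under cache_dir (e.g. "saved_ckpts")
    if cache_dir and os.path.isdir(os.path.join(cache_dir, pretrained_model_name_or_path)):
        return os.path.join(cache_dir, pretrained_model_name_or_path, "config.json")
    # 3) otherwise build a download URL for the hub
    return hf_bucket_url(pretrained_model_name_or_path, filename="config.json")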
View File

@ -11,6 +11,8 @@ from opendelta import BaseDeltaConfig
import opendelta.utils.logging as logging
import numpy as np
from opendelta import global_setting
from dataclasses import dataclass, field
logger = logging.get_logger(__name__)
@ -139,7 +141,7 @@ class AdapterConfig(BaseDeltaConfig):
self,
bottleneck_dim: Optional[int]=24,
non_linearity: Optional[str]='gelu_new',
sequential: Optional[str] = True,
sequential: Optional[bool] = True,
**kwargs
):
super().__init__(**kwargs)

View File

@ -3,10 +3,10 @@ from typing import Optional, Union
from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.basemodel import DeltaBase
from transformers.models.t5 import T5ForConditionalGeneration
import torch.nn as nn
from opendelta import BaseDeltaConfig
import math
from dataclasses import dataclass, field
class LowRankLinear(nn.Module):
# ------------------------------------------------------------------------------------------
@ -40,6 +40,11 @@ class LowRankLinear(nn.Module):
def forward(self, x):
return (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
@dataclass
class LoraArguments:
r: int = 8
lora_alpha: int = 16
lora_dropout: float = 0.0
class LoraConfig(BaseDeltaConfig):
r"""

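LowRankLinear above is the standard LoRA update: dropout on the input, a low-rank pair A and B, and a lora_alpha / r scaling, matching the LoraArguments defaults (r=8, lora_alpha=16). A minimal self-contained version for reference; the initialization choices are assumptions, since the hunk only shows the forward pass:

import math
import torch
import torch.nn as nn

class LowRankLinear(nn.Module):
    def __init__(self, in_features, out_features, r=8, lora_alpha=16, lora_dropout=0.0):
        super().__init__()
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))   # down-projection
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # up-projection
        self.scaling = lora_alpha / r
        self.lora_dropout = nn.Dropout(lora_dropout) if lora_dropout > 0.0 else nn.Identity()
        # Kaiming init for A, zeros for B, so the initial delta is zero
        # (a common LoRA choice; the repository's exact init is not shown here).
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

    def forward(self, x):
        # Same computation as the forward in the hunk: dropout -> A^T -> B^T -> scale.
        return (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling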
View File

@ -71,6 +71,10 @@ class DeltaCenterArguments:
default=None,
metadata={"help": "the performance of the model on the test set"}
)
test_metrics: Optional[str] = field(
default=None,
metadata={"help": "the metrics used by the model"}
)
trainable_ratio: Optional[float] = field(
default=None,
metadata={"help": "the ratio of trainable parameters in the model"}
@ -93,15 +97,15 @@ class SaveLoadMixin(PushToHubMixin):
def save_finetuned(
self,
save_directory: Optional[Union[str, os.PathLike]] = "./output/",
finetuned_delta_path: Optional[Union[str, os.PathLike]] = "./delta_checkpoints/",
save_config: bool = True,
state_dict: Optional[dict] = None,
save_function: Callable = torch.save,
push_to_dc: bool = True,
center_args: Optional[Union[DeltaCenterArguments, dict]] = None,
center_args_pool: Optional[dict] = None,
center_value_only_tags: Optional[List] = None,
center_key_value_tags: Optional[Dict] = None,
list_tags: Optional[List] = None,
dict_tags: Optional[Dict] = None,
delay_push: bool = False,
):
r"""
@ -129,9 +133,9 @@ class SaveLoadMixin(PushToHubMixin):
center_args_pool (:obj:`dict`, *optional*, defaults to :obj:`None`): The arguments' pool for DeltaCenter.
Together with center_args, they are used to distinguish between different delta models on the DeltaCenter.
It will be used to group delta models.
center_value_only_tags (:obj:`List`, *optional*, defaults to :obj:`None`): The tags in the form of list for the delta model, it is the
list_tags (:obj:`List`, *optional*, defaults to :obj:`None`): The tags in the form of a list for the delta model; these are the
optional identifiers that are not expected by `DeltaCenterArgument`. It will not be used to group delta models in the delta center
center_key_value_tags (:obj:`Dict`, *optional*, defaults to :obj:`None`): The tags in the form of dictionary for the delta model, it is the
dict_tags (:obj:`Dict`, *optional*, defaults to :obj:`None`): The tags in the form of a dictionary for the delta model; these are the
optional identifiers that are not expected by `DeltaCenterArgument`. It will not be used to group delta models in the delta center.
delay_push (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether or not to delay the push to the DeltaCenter. When set to True,
the delta object will be saved locally to save_directory, you can push it later using
@ -142,6 +146,7 @@ class SaveLoadMixin(PushToHubMixin):
"""
save_directory = finetuned_delta_path
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
@ -169,26 +174,36 @@ class SaveLoadMixin(PushToHubMixin):
final_center_args = self.create_delta_center_args(center_args=center_args,
center_args_pool=center_args_pool)
print("final_center_args", final_center_args)
if push_to_dc:
self.create_yml(save_directory, final_center_args, center_value_only_tags, center_key_value_tags)
self.create_yml(save_directory, final_center_args, list_tags, dict_tags)
if not delay_push:
OssClient.upload(base_dir=save_directory)
else:
logger.info("\n"+"*"*30+f"\nYour delta model has been saved locally to:\n\t\t{os.path.abspath(save_directory)}\
\nYou can push it to the delta center later using \n\t\tpython -m DeltaCenter upload {os.path.abspath(save_directory)}\n"
+"*"*30)
# get absolute path of saved_directory,
def create_yml(self, save_dir, config, list_tags=None, dict_tags=None):
f = open("{}/config.yml".format(save_dir), 'w')
yaml.safe_dump(vars(config), f)
config_dict = vars(config)
config_dict['dict_tags'] = dict_tags if dict_tags is not None else {}
config_dict['list_tags'] = list_tags if list_tags is not None else []
yaml.safe_dump(config_dict, f)
f.close()
@classmethod
def from_finetuned(cls,
finetuned_model_name_or_path: Optional[Union[str, os.PathLike]],
finetuned_delta_path: Optional[Union[str, os.PathLike]],
backbone_model: nn.Module,
*model_args,
check_hash: Optional[bool] = True,
**kwargs):
delta_config = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
*model_args,
check_hash: Optional[bool] = True,
**kwargs):
r"""
Instantiate a finetuned delta model from a path.
The backbone_model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
@ -196,7 +211,7 @@ class SaveLoadMixin(PushToHubMixin):
Parameters:
finetuned_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
finetuned_model_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
@ -297,9 +312,9 @@ class SaveLoadMixin(PushToHubMixin):
"""
config = kwargs.pop("config", None)
# config = kwargs.pop("config", None)
state_dict = kwargs.pop("state_dict", None)
cache_dir = kwargs.pop("cache_dir", None)
# cache_dir = kwargs.pop("cache_dir", None)
# ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False)
@ -323,10 +338,10 @@ class SaveLoadMixin(PushToHubMixin):
local_files_only = True
# Load config if we don't provide a configuration
if not isinstance(config, BaseDeltaConfig):
config_path = config if config is not None else finetuned_model_name_or_path
config, model_kwargs = cls.config_class.from_finetuned(
config_path,
if not isinstance(delta_config, BaseDeltaConfig):
# config_path = delta_config if delta_config is not None else finetuned_model_path # Todo check
delta_config, model_kwargs = cls.config_class.from_finetuned(
finetuned_model_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
@ -343,23 +358,24 @@ class SaveLoadMixin(PushToHubMixin):
else:
model_kwargs = kwargs
print("delta_config", delta_config)
# Load model
if finetuned_model_name_or_path is not None:
finetuned_model_name_or_path = str(finetuned_model_name_or_path)
if os.path.isdir(finetuned_model_name_or_path):
if os.path.isfile(os.path.join(finetuned_model_name_or_path, WEIGHTS_NAME)):
if finetuned_model_path is not None:
finetuned_model_path = str(finetuned_model_path)
if os.path.isdir(finetuned_model_path):
if os.path.isfile(os.path.join(finetuned_model_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
archive_file = os.path.join(finetuned_model_name_or_path, WEIGHTS_NAME)
archive_file = os.path.join(finetuned_model_path, WEIGHTS_NAME)
else:
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in "
f"directory {finetuned_model_name_or_path}."
f"directory {finetuned_model_path}."
)
elif os.path.isfile(finetuned_model_name_or_path) or is_remote_url(finetuned_model_name_or_path):
archive_file = finetuned_model_name_or_path
elif os.path.isfile(finetuned_model_path) or is_remote_url(finetuned_model_path):
archive_file = finetuned_model_path
else:
archive_file = hf_bucket_url(
finetuned_model_name_or_path,
finetuned_model_path,
filename=WEIGHTS_NAME,
revision=revision,
mirror=mirror,
@ -381,7 +397,7 @@ class SaveLoadMixin(PushToHubMixin):
except EnvironmentError as err:
logger.error(err)
msg = (
f"Can't load weights for '{finetuned_model_name_or_path}'. Make sure that:\n\n"
f"Can't load weights for '{finetuned_model_path}'. Make sure that:\n\n"
)
if revision is not None:
@ -414,7 +430,7 @@ class SaveLoadMixin(PushToHubMixin):
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{finetuned_model_name_or_path}' "
f"Unable to load weights from pytorch checkpoint file for '{finetuned_model_path}' "
f"at '{resolved_archive_file}'. "
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
)
@ -469,7 +485,6 @@ class SaveLoadMixin(PushToHubMixin):
mdict = {}
field = fields(DeltaCenterArguments)
print("center_args_pool", center_args_pool)
for f in field:
exist = False

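save_finetuned now writes the DeltaCenterArguments plus the optional list/dict tags into config.yml so that a later `python -m DeltaCenter upload <dir>` can pick them up. A condensed version of the create_yml step shown above, using a with-block:

import os
import yaml

def create_yml(save_dir, center_args, list_tags=None, dict_tags=None):
    # Serialize the DeltaCenterArguments dataclass plus the optional tags.
    config_dict = vars(center_args)
    config_dict["list_tags"] = list_tags if list_tags is not None else []
    config_dict["dict_tags"] = dict_tags if dict_tags is not None else {}
    with open(os.path.join(save_dir, "config.yml"), "w") as f:
        yaml.safe_dump(config_dict, f)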
View File

@ -4,10 +4,10 @@ from collections import namedtuple
def signature(f):
r"""Get the function f 's input arguments. A useful gadget
when some function slot might be instantiated into multiple functions.
Args:
f (:obj:`function`) : the function to get the input arguments.
Returns:
namedtuple: of args, defaults, varargs, keywords, respectively.
@ -34,7 +34,7 @@ def signature(f):
] or None
argspec = namedtuple('Signature', ['args', 'defaults',
'varargs', 'keywords'])
return argspec(args, defaults, varargs, keywords)
return argspec(args, defaults, varargs, keywords)
def get_arg_names(f):
r""" Get a functions argument name, remove the ``self`` argument
@ -45,6 +45,7 @@ def get_arg_names(f):
return args
def get_arg_names_inside_func(func):
r""" Get the functions argument name inside the function itself. Remove ``self`` argument.
"""