finish delta center development
This commit is contained in:
parent 617955e08e
commit 0326d5fb7a

@@ -38,4 +38,5 @@ log.txt
**/output/
**/thunlp/
**/saved_ckpts/

**/delta_checkpoints/
**/outputs/

@@ -72,6 +72,11 @@ python setup.py install
python setup.py develop
```

If you encounter a network error when using setup.py, please first install the dependencies via
```shell
pip install -r requirements.txt && python setup.py develop
```

## Must Try

```python

@@ -1,24 +1,59 @@
# !!!!This example collection is still under develop, please wait for some time to use it.
# Examples of using opendelta together with 🤗 transformers.

## Install the repo
In this repo, we construct a very general pipeline to train and test a PLM using
🤗 transformers.

The pipeline was constructed together with [openpromptu](https://pypi.org/project/openpromptu/), which is a lightweight and
model-agnostic version of [openprompt](https://github.com/thunlp/OpenPrompt).

## Pool of PLMs
We are going to adapt most of the models in 🤗 transformers
in this repo. The different pipelines, processing steps, and configurations are specified
in `./backbones/`. You can add your own model to this file to support customized models.


### An example script to run the repo in offline mode
```bash
cd ../
python setup_seq2seq.py develop
conda activate [YOURENV]
PATHBASE=[YOURPATH]

JOBNAME="adapter_t5-base"
DATASET="superglue-cb"

cd $PATHBASE/OpenDelta/examples/examples_prompt/
python configs/gen_t5.py --job $JOBNAME

export TRANSFORMERS_OFFLINE=1
export HF_DATASETS_OFFLINE=1
python src/run.py configs/$JOBNAME/$DATASET.json \
--model_name_or_path [YOURPATH_TO_T5_BASE] \
--tokenizer_name [YOURPATH_TO_T5_BASE] \
--datasets_saved_path [YOURPATH_TO_CB_DATASETS] \
--finetuned_delta_path ${PATHBASE}/delta_checkpoints/ \
--num_train_epochs 20 \
--bottleneck_dim 24 \
--delay_push True
```
This will add `examples_seq2seq` to the environment path of the Python library.

## Generating the json configuration file
## An example of quickly testing the repo

```shell
python configs/gen_$BACKBONETYPE.py --job $YOURJOB
#e.g. python configs/gen_beit.py --job lora_beit-base-patch16-224
```
The available job configurations (e.g., `--job lora_beit-base-patch16-224`) can be seen from the scripts. You can also
create your own configuration.
```bash
conda activate [YOURENV]
PATHBASE=[YOURPATH]

JOBNAME="adapter_t5-base"
DATASET="superglue-cb"

## Run the code
cd $PATHBASE/OpenDelta/examples/examples_prompt/

```
CUDA_VISIBLE_DEVICES=1 python src/run.py configs/lora_beit-base-patch16-224/beans.json
```
export TRANSFORMERS_OFFLINE=1
export HF_DATASETS_OFFLINE=1
export DELTACENTER_OFFLINE=0
python src/test.py configs/$JOBNAME/$DATASET.json \
--model_name_or_path [YOURPATH_TO_T5_BASE] \
--tokenizer_name [YOURPATH_TO_T5_BASE] \
--datasets_saved_path [YOURPATH_TO_CB_DATASETS] \
--finetuned_delta_path thunlp/t5-base_adapter_superglue-cb_20220701171436c80 \
--delta_cache_dir "./delta_checkpoints/" \
--force_download True
```
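
For context, the offline test invocation above boils down to loading a finetuned delta checkpoint onto a backbone model. A minimal Python sketch of that flow, based on the `AutoDeltaModel.from_finetuned` entry point documented later in this commit (the checkpoint id and backbone are illustrative placeholders):

```python
# Minimal sketch: attach a finetuned delta checkpoint to a T5 backbone.
# "DeltaHub/lora_t5-base_mrpc" is an illustrative checkpoint id.
from transformers import AutoModelForSeq2SeqLM
from opendelta import AutoDeltaModel

t5 = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
delta = AutoDeltaModel.from_finetuned("DeltaHub/lora_t5-base_mrpc", backbone_model=t5)
delta.log()  # visualize which parts of the backbone were modified
```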

@@ -1,59 +0,0 @@
# the final results will be populated here.
{
    "evaluate": {
        "epoch": 20.0,
        "eval_accuracy": 89.2156862745098,
        "eval_average_metrics": 90.76168929110105,
        "eval_f1": 92.3076923076923,
        "eval_loss": 0.16493959724903107,
        "eval_runtime": 1.6391,
        "eval_samples_per_second": 124.455
    },
    "repo_name": "DeltaHub/bitfit_t5-base_mrpc",
    "test": {
        "epoch": 20.0,
        "test_accuracy": 88.23529411764706,
        "test_average_metrics": 89.97971602434077,
        "test_f1": 91.72413793103448,
        "test_loss": 0.14968213438987732,
        "test_runtime": 1.6344,
        "test_samples_per_second": 124.82
    }
}
{
    "evaluate": {
        "epoch": 20.0,
        "eval_average_metrics": 52.10265668831534,
        "eval_loss": 0.3603779077529907,
        "eval_matthews_correlation": 52.10265668831534,
        "eval_runtime": 1.0808,
        "eval_samples_per_second": 482.046
    },
    "repo_name": "DeltaHub/bitfit_t5-base_cola",
    "test": {
        "epoch": 20.0,
        "test_average_metrics": 54.209563471221934,
        "test_loss": 0.2853100299835205,
        "test_matthews_correlation": 54.209563471221934,
        "test_runtime": 1.056,
        "test_samples_per_second": 494.304
    }
}
{
    "evaluate": {
        "epoch": 20.0,
        "eval_average_metrics": 53.80613287067274,
        "eval_loss": 0.25723716616630554,
        "eval_matthews_correlation": 53.80613287067274,
        "eval_runtime": 1.0583,
        "eval_samples_per_second": 492.299
    },
    "repo_name": "DeltaHub/bitfit_t5-base_cola",
    "test": {
        "epoch": 20.0,
        "test_average_metrics": 54.32497579543861,
        "test_loss": 0.22327613830566406,
        "test_matthews_correlation": 54.32497579543861,
        "test_runtime": 1.0556,
        "test_samples_per_second": 494.507
    }
}

@@ -13,12 +13,11 @@ import numpy as np
import torch
import re
import itertools

import os

logger = logging.getLogger(__name__)



from transformers.models.auto.tokenization_auto import tokenizer_class_from_name

from typing import List, Dict

@@ -63,7 +62,8 @@ class COLA(AbstractTask):
    }

    def load_dataset(self, split):
        if self.data_args.datasets_load_from_disk:
        offline = os.environ.get("HF_DATASETS_OFFLINE", "0")
        if offline == '1':
            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.cola")[split]
        else:
            return datasets.load_dataset('glue', 'cola',
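
The same offline-aware loading pattern is repeated for every task class in this file. A generalized sketch of the idea (the helper name and argument layout here are illustrative, not part of the diff):

```python
import os
import datasets

def load_glue_subset(saved_path: str, subset: str, split: str):
    """Illustrative helper: prefer an on-disk copy when HF_DATASETS_OFFLINE=1,
    otherwise fall back to downloading from the hub."""
    if os.environ.get("HF_DATASETS_OFFLINE", "0") == "1":
        # load_from_disk returns a DatasetDict; index it by split
        return datasets.load_from_disk(f"{saved_path}/glue.{subset}")[split]
    return datasets.load_dataset("glue", subset, split=split)
```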

@@ -91,7 +91,8 @@ class SST2(AbstractTask):
    }

    def load_dataset(self, split):
        if self.data_args.datasets_load_from_disk:
        offline = os.environ.get("HF_DATASETS_OFFLINE", "0")
        if offline == '1':
            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.sst2")[split]
        else:
            return datasets.load_dataset('glue', 'sst2',

@@ -118,10 +119,9 @@ class MRPC(AbstractTask):
    }




    def load_dataset(self, split):
        if self.data_args.datasets_load_from_disk:
        offline = os.environ.get("HF_DATASETS_OFFLINE", "0")
        if offline == '1':
            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.mrpc")[split]
        else:
            return datasets.load_dataset('glue', 'mrpc', split=split, script_version="master")

@@ -147,7 +147,8 @@ class QQP(AbstractTask):


    def load_dataset(self, split):
        if self.data_args.datasets_load_from_disk:
        offline = os.environ.get("HF_DATASETS_OFFLINE", "0")
        if offline == '1':
            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.qqp")[split]
        else:
            return datasets.load_dataset('glue', 'qqp',

@@ -203,7 +204,8 @@ class MNLI(AbstractTask):
    }

    def load_dataset(self, split):
        if self.data_args.datasets_load_from_disk:
        offline = os.environ.get("HF_DATASETS_OFFLINE", "0")
        if offline == '1':
            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.mnli")[split]
        else:
            return datasets.load_dataset('glue', 'mnli', split=split, script_version="master")

@@ -238,7 +240,8 @@ class QNLI(AbstractTask):


    def load_dataset(self, split):
        if self.data_args.datasets_load_from_disk:
        offline = os.environ.get("HF_DATASETS_OFFLINE", "0")
        if offline == '1':
            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.qnli")[split]
        else:
            return datasets.load_dataset('glue', 'qnli', split=split, script_version="master")

@@ -274,7 +277,8 @@ class RTE(AbstractTask):
    }

    def load_dataset(self, split):
        if self.data_args.datasets_load_from_disk:
        offline = os.environ.get("HF_DATASETS_OFFLINE", "0")
        if offline == '1':
            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.rte")[split]
        else:
            return datasets.load_dataset('glue', 'rte',

@@ -301,7 +305,8 @@ class WNLI(AbstractTask):


    def load_dataset(self, split):
        if self.data_args.datasets_load_from_disk:
        offline = os.environ.get("HF_DATASETS_OFFLINE", "0")
        if offline == '1':
            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.wnli")[split]
        else:
            return datasets.load_dataset('glue', 'wnli', split=split, script_version="master")

@@ -329,7 +334,8 @@ class SuperGLUEBoolQ(AbstractTask):
    }

    def load_dataset(self, split):
        if self.data_args.datasets_load_from_disk:
        offline = os.environ.get("HF_DATASETS_OFFLINE", "0")
        if offline == '1':
            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/super_glue.boolq")[split]
        else:
            return datasets.load_dataset('super_glue', 'boolq', split=split, script_version="master")

@@ -342,7 +348,7 @@ class SuperGLUECB(AbstractTask):
    split_to_data_split = {"train": "train",
                           "validation": "validation",
                           "test": "validation"}
    metric = [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    verbalizers = {

@@ -356,7 +362,8 @@ class SuperGLUECB(AbstractTask):
    }

    def load_dataset(self, split):
        if self.data_args.datasets_load_from_disk:
        offline = os.environ.get("HF_DATASETS_OFFLINE", "0")
        if offline == '1':
            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/super_glue.cb")[split]
        else:
            return datasets.load_dataset('super_glue', 'cb', split=split, script_version="master")

@@ -382,7 +389,8 @@ class SuperGLUECOPA(AbstractTask):
    }

    def load_dataset(self, split):
        if self.data_args.datasets_load_from_disk:
        offline = os.environ.get("HF_DATASETS_OFFLINE", "0")
        if offline == '1':
            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/super_glue.copa")[split]
        else:
            return datasets.load_dataset('super_glue', 'copa', split=split, script_version="master")

@@ -411,7 +419,8 @@ class SuperGLUEMultiRC(AbstractTask):


    def load_dataset(self, split):
        if self.data_args.datasets_load_from_disk:
        offline = os.environ.get("HF_DATASETS_OFFLINE", "0")
        if offline == '1':
            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/super_glue.multirc")[split]
        else:
            return datasets.load_dataset('super_glue', 'multirc', split=split, script_version="master")

@@ -454,7 +463,8 @@ class SuperGLUEWIC(AbstractTask):
    }

    def load_dataset(self, split):
        if self.data_args.datasets_load_from_disk:
        offline = os.environ.get("HF_DATASETS_OFFLINE", "0")
        if offline == '1':
            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/super_glue.wic")[split]
        else:
            return datasets.load_dataset('super_glue', 'wic', split=split, script_version="master")

@@ -544,7 +554,8 @@ class Beans(AbstractTask):

    def load_dataset(self, split):
        # from IPython import embed; embed(header="beans")
        if self.data_args.datasets_load_from_disk:
        offline = os.environ.get("HF_DATASETS_OFFLINE", "0")
        if offline == '1':
            # note: fixed copy-paste bug that loaded the super_glue.wic path here
            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/beans")[split]
        else:
            return datasets.load_dataset('beans', split=split, script_version="master")

@@ -328,7 +328,7 @@ def main():
            center_args_pool = {**vars(model_args), **vars(data_args), **vars(training_args), **vars(delta_args)},
            list_tags = ['NLI'],
            dict_tags = {'purpose':'for testing'},
            delay_push=True,
            delay_push=delta_args.delay_push,
        )



@@ -152,9 +152,13 @@ def main():
    Visualization(model).structure_graph()

    if delta_args.delta_type.lower() != "none":
        from opendelta import AutoDeltaConfig, AutoDeltaModel
        # delta_config = AutoDeltaConfig.from_dict(vars(delta_args))
        delta_model = AutoDeltaModel.from_finetuned(finetuned_model_path=delta_args.finetuned_model_path, cache_dir="saved_ckpts", backbone_model=model)
        from opendelta.delta_models.adapter import AdapterConfig, AdapterModel
        delta_config = AdapterConfig.from_finetuned(finetuned_delta_path=delta_args.finetuned_delta_path)
        delta_model = AdapterModel.from_finetuned(finetuned_delta_path=delta_args.finetuned_delta_path,
                                                  delta_config=delta_config,
                                                  backbone_model=model,
                                                  force_download=delta_args.force_download,
                                                  cache_dir=delta_args.delta_cache_dir)
        # delta_model.freeze_module(set_state_dict = True)
        delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=True)


@@ -243,6 +243,21 @@ class DeltaArguments:
    finetuned_delta_path: Optional[str] = field(
        default=None, metadata={"help": "the path of the finetuned delta model"}
    )
    force_download: Optional[bool] = field(
        default=False, metadata={"help": "whether to download the checkpoint from delta center regardless of whether a local copy already exists"}
    )
    local_files_only: Optional[bool] = field(
        default=False, metadata={"help": "whether to skip looking for the file in delta center"}
    )
    delta_cache_dir: Optional[str] = field(
        default=None, metadata={"help": "The cache path defined by the user. If not set, we will first look into the"+
        " working directory and then into the default cache path (usually ~/.cache/delta_center)."}
    )
    delay_push: Optional[bool] = field(
        default=True, metadata={
            'help': 'whether to push the checkpoint to delta center later.'
        }
    )

    def merge_arguments(self, objb):
        print(objb)
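
These fields follow the standard dataclass-argument pattern used by the examples, so they can be filled from the command line. A minimal sketch with `transformers.HfArgumentParser` (the trimmed dataclass and the checkpoint id below are illustrative stand-ins, not the real `DeltaArguments`):

```python
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class MiniDeltaArguments:
    # trimmed stand-in for the DeltaArguments fields added above
    finetuned_delta_path: Optional[str] = field(default=None)
    force_download: Optional[bool] = field(default=False)
    delta_cache_dir: Optional[str] = field(default=None)
    delay_push: Optional[bool] = field(default=True)

parser = HfArgumentParser(MiniDeltaArguments)
(args,) = parser.parse_args_into_dataclasses(
    ["--finetuned_delta_path", "thunlp/example_delta", "--force_download", "True"]
)
print(args.finetuned_delta_path, args.force_download)  # thunlp/example_delta True
```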

@@ -1,48 +0,0 @@
"""Install Compacter."""
import os
import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

#os.environ['TORCH_CUDA_ARCH_LIST']="3.5;3.7;6.1;7.0;7.5;8.6+PTX"

def setup_package():
    long_description = "examples_prompt"
    setuptools.setup(
        name='examples_prompt',
        version='0.0.1',
        description='textual prompt example',
        long_description=long_description,
        long_description_content_type='text/markdown',
        author='Shengding Hu',
        license='MIT License',
        packages=setuptools.find_packages(
            exclude=['docs', 'tests', 'scripts']),
        dependency_links=[
            'https://download.pytorch.org/whl/torch_stable.html',
        ],
        classifiers=[
            'Intended Audience :: Developers',
            'Intended Audience :: Science/Research',
            'License :: OSI Approved :: MIT License',
            'Topic :: Scientific/Engineering :: Artificial Intelligence',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.7.10',
        ],
        keywords='text nlp machinelearning',
        # ext_modules=[
        #     CUDAExtension('seq2seq.projections.fwh_cuda',
        #         sources=[
        #             'seq2seq/projections/fwh_cuda/fwh_cpp.cpp',
        #             'seq2seq/projections/fwh_cuda/fwh_cu.cu',
        #         ]
        #     )
        # ]
        # ,
        cmdclass={"build_ext": BuildExtension},
        install_requires=[
        ],
    )


if __name__ == '__main__':
    setup_package()

@@ -2,11 +2,9 @@ from copy import deepcopy
from typing import Any, Dict, OrderedDict
from opendelta.utils.visualization import Visualization
import torch.nn as nn
from transformers.file_utils import PushToHubMixin
from opendelta.utils.logging import get_logger
import importlib
from opendelta.delta_configs import BaseDeltaConfig
from opendelta.basemodel import DeltaBase
logger = get_logger(__name__)



@@ -114,7 +112,7 @@ class AutoDeltaConfig:


    @classmethod
    def from_finetuned(cls, finetuned_model_path, **kwargs):
    def from_finetuned(cls, finetuned_delta_path, **kwargs):
        r"""
        Instantiate one of the configuration classes of the library from a finetuned delta model configuration.
        The configuration class to instantiate is selected based on the ``delta_type`` property of the config object that
|
|||
|
||||
Parameters:
|
||||
|
||||
finetuned_model_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
|
||||
finetuned_delta_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* of a finetuned delta model configuration hosted inside a model repo on
|
||||
|

@@ -173,20 +171,19 @@ class AutoDeltaConfig:

        """

        kwargs["name_or_path"] = finetuned_model_path

        config_dict, _ = BaseDeltaConfig.get_config_dict(finetuned_model_path, **kwargs)
        config_dict, kwargs = BaseDeltaConfig.get_config_dict(finetuned_delta_path, **kwargs)
        if "delta_type" in config_dict:
            config_class = LAZY_CONFIG_MAPPING[config_dict["delta_type"]]
            return config_class.from_dict(config_dict, **kwargs)
        else:
            # Fallback: use pattern matching on the string.
            for pattern, config_class in LAZY_CONFIG_MAPPING.items():
                if pattern in str(finetuned_model):
                if pattern in str(finetuned_delta_path):
                    return config_class.from_dict(config_dict, **kwargs)

        raise ValueError(
            f"Unrecognized model in {finetuned_model_path}. "
            f"Unrecognized model in {finetuned_delta_path}. "
            f"Should have a `delta_type` key in the loaded config, or contain one of the following strings "
            f"in its name: {', '.join(LAZY_CONFIG_MAPPING.keys())}"
        )

@@ -355,14 +352,14 @@ class AutoDeltaModel:
        )

    @classmethod
    def from_finetuned(cls, finetuned_model_path, backbone_model, *model_args, **kwargs):
    def from_finetuned(cls, finetuned_delta_path, backbone_model, *model_args, **kwargs):
        r""" Automatically instantiate a delta model and load the finetuned checkpoints based on
        :obj:`finetuned_model_path`, which can either be a string pointing to a local path or a url pointing to
        :obj:`finetuned_delta_path`, which can either be a string pointing to a local path or a url pointing to
        the delta hub. It will check the hash after loading the delta model to see whether the correct backbone and
        delta checkpoint are used.

        Args:
            finetuned_model_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
            finetuned_delta_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
                Can be either:

                - A string, the *model id* of a finetuned delta model configuration hosted inside a model repo on

@@ -377,6 +374,7 @@ class AutoDeltaModel:

            backbone_model (:obj:`nn.Module`): The backbone model to be modified.
            model_args: Other arguments for initializing the model.
            kwargs: Other kwargs that will be passed into DeltaBase.from_finetuned.

        Example:


@@ -385,15 +383,15 @@ class AutoDeltaModel:
            delta_model = AutoDeltaModel.from_finetuned("DeltaHub/lora_t5-base-mrpc", backbone_model)

        """
        config = kwargs.pop("config", None)
        delta_config = kwargs.pop("delta_config", None)

        if not isinstance(config, BaseDeltaConfig):
            config, kwargs = AutoDeltaConfig.from_finetuned(
                finetuned_model_path, return_unused_kwargs=True, **kwargs
        if not isinstance(delta_config, BaseDeltaConfig):
            delta_config, kwargs = AutoDeltaConfig.from_finetuned(
                finetuned_delta_path, return_unused_kwargs=True, **kwargs
            )
        if type(config) in cls._delta_model_mapping.keys():
            model_class = cls._delta_model_mapping[type(config)]
            return model_class.from_finetuned(finetuned_model_path, backbone_model, *model_args, **kwargs)
        if type(delta_config) in cls._delta_model_mapping.keys():
            model_class = cls._delta_model_mapping[type(delta_config)]
            return model_class.from_finetuned(finetuned_delta_path, backbone_model, *model_args, delta_config=delta_config, **kwargs)
        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."

@@ -5,15 +5,6 @@ from opendelta import __version__ as opendelta_version
from opendelta.utils import logging
from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
import transformers
from transformers.file_utils import (
    PushToHubMixin,
    is_offline_mode,
    cached_path,
    is_remote_url,
    get_list_of_files,
    hf_bucket_url,
)
from packaging import version
import json
import copy


@@ -26,7 +17,7 @@ logger = logging.get_logger(__name__)
FULL_CONFIGURATION_FILE = "config.json"
_re_configuration_file = re.compile(r"config\.(.*)\.json")

class BaseDeltaConfig(PushToHubMixin):
class BaseDeltaConfig:
    r"""Base class for all configuration classes. Handles a few
    parameters common to all delta models' configurations as well as methods for loading/downloading/saving configurations.


@@ -108,7 +99,7 @@ class BaseDeltaConfig(PushToHubMixin):


    @classmethod
    def from_finetuned(cls, finetuned_model_path: Union[str, os.PathLike], **kwargs) -> "BaseDeltaConfig":
    def from_finetuned(cls, finetuned_delta_path: Union[str, os.PathLike], **kwargs) -> "BaseDeltaConfig":
        r"""
        Instantiate a :obj:`BaseDeltaConfig` (or a derived class) from a finetuned delta module configuration.


@@ -132,7 +123,7 @@ class BaseDeltaConfig(PushToHubMixin):
            delta_config = LoraConfig.from_finetuned("DeltaHub/lora_t5-base_mrpc")

        """
        config_dict, kwargs = cls.get_config_dict(finetuned_model_path, **kwargs)
        config_dict, kwargs = cls.get_config_dict(finetuned_delta_path, **kwargs)
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warn(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "

@@ -215,7 +206,7 @@ class BaseDeltaConfig(PushToHubMixin):
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)
        logger.info(f"Model config {config}")
        logger.info(f"Model config\n{config}")

        if return_unused_kwargs:
            return config, kwargs

@@ -224,101 +215,58 @@ class BaseDeltaConfig(PushToHubMixin):

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
        cls, finetuned_delta_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """[NODOC]
        From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a
        From a ``finetuned_delta_path``, resolve to a dictionary of parameters, to be used for instantiating a
        [``PretrainedConfig``] using ``from_dict``.
        Parameters:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
            finetuned_delta_path (:obj:`str` or :obj:`os.PathLike`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
        Returns:
            :obj:`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        cache_dir = kwargs.get("cache_dir", None)
        force_download = kwargs.get("force_download", False)
        # resume_download = kwargs.pop("resume_download", False)
        # proxies = kwargs.pop("proxies", None)
        # use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.get("local_files_only", False)
        # revision = kwargs.pop("revision", None)
        # from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)
        # from_auto_class = kwargs.pop("_from_auto", False)

        user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
        # user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
        # if from_pipeline is not None:
        #     user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
        if os.environ.get("DELTACENTER_OFFLINE", '0') == '1':
            logger.info("Delta Center offline mode!")
            local_files_only = True

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        finetuned_delta_path = str(finetuned_delta_path)


        if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        if cache_dir is not None:
            cached_finetuned_delta_path = os.path.join(cache_dir, finetuned_delta_path)
        else:
            # configuration_file = get_configuration_file(
            #     pretrained_model_name_or_path,
            #     revision=revision,
            #     use_auth_token=use_auth_token,
            #     local_files_only=local_files_only,
            # )
            print("cache_dir", cache_dir, "|||", "pretrained_model_name_or_path", pretrained_model_name_or_path)
            cached_finetuned_delta_path = finetuned_delta_path

            if os.path.isdir(pretrained_model_name_or_path):
                config_file = os.path.join(pretrained_model_name_or_path, "config.json")
            elif os.path.isdir(os.path.join(cache_dir, pretrained_model_name_or_path)):
                config_file = os.path.join(cache_dir, pretrained_model_name_or_path, "config.json")
            else:
                config_file = hf_bucket_url(
                    pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
                )
        if os.path.isfile(cached_finetuned_delta_path):
            local_files_only = True
        elif os.path.isdir(cached_finetuned_delta_path):
            # cached_finetuned_delta_path = os.path.join(cached_finetuned_delta_path, 'config.json')
            local_files_only = True

        print("config file!!", config_file)
        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
            )
            # Load config dict
            config_dict = cls._dict_from_json_file(resolved_config_file)

        except EnvironmentError as err:
            logger.error(err)
            msg = (
                f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
                f"  (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
            )

            if revision is not None:
                msg += f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"

            raise EnvironmentError(msg)

        except (json.JSONDecodeError, UnicodeDecodeError):
            msg = (
                f"Couldn't reach server at '{config_file}' to download configuration file or "
                "configuration file is not a valid JSON file. "
                f"Please check network or file content here: {resolved_config_file}."
            )
            raise EnvironmentError(msg)

        if resolved_config_file == config_file:
            logger.info(f"loading configuration file {config_file}")
        else:
            logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
        # if local_files_only:
        #     config_dict = cls._dict_from_json_file(cached_finetuned_delta_path)
        if not local_files_only or force_download:
            from .utils.delta_center import download as dcdownload
            # try to download from DeltaCenter
            cached_finetuned_delta_path = dcdownload(finetuned_delta_path, force_download=force_download, cache_dir=cache_dir)
            kwargs['force_download'] = False  # has been downloaded; no more forcing

        cached_finetuned_delta_path = os.path.join(cached_finetuned_delta_path, 'config.json')
        config_dict = cls._dict_from_json_file(cached_finetuned_delta_path)
        return config_dict, kwargs

    @classmethod
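
The resolution order implemented above (explicit cache dir, then local path, then DeltaCenter download, gated by `DELTACENTER_OFFLINE`) can be exercised directly. A hedged usage sketch, mirroring the `AdapterConfig.from_finetuned` call from the test script earlier in this commit (the local path is a placeholder):

```python
import os
from opendelta.delta_models.adapter import AdapterConfig

# Force offline resolution: only local files / the local cache are consulted.
os.environ["DELTACENTER_OFFLINE"] = "1"
delta_config = AdapterConfig.from_finetuned(
    finetuned_delta_path="./delta_checkpoints/t5-base_adapter_superglue-cb_20220701171436c80",  # placeholder path
)
```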

@@ -432,53 +380,6 @@ class BaseDeltaConfig(PushToHubMixin):



def get_configuration_file(
    path_or_repo: Union[str, os.PathLike],
    revision: Optional[str] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    local_files_only: bool = False,
) -> str:
    """
    Get the configuration file to use for this version of transformers.
    Args:
        path_or_repo (`:obj:str` or `:obj:os.PathLike`):
            Can be either the id of a repo on huggingface.co or a path to a *directory*.
        revision (`:obj:str`, *optional*, defaults to ``"main"``):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
            identifier allowed by git.
        use_auth_token (:obj:`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token generated
            when running ``transformers-cli login`` (stored in ``~/.huggingface``).
        local_files_only (:obj:`bool`, *optional*, defaults to :obj:`False`):
            Whether or not to only rely on local files and not to attempt to download any files.
    Returns:
        :obj:`str`: The configuration file to use.
    """
    # Inspect all files from the repo/folder.
    all_files = get_list_of_files(
        path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
    )
    configuration_files_map = {}
    for file_name in all_files:
        search = _re_configuration_file.search(file_name)
        if search is not None:
            v = search.groups()[0]
            configuration_files_map[v] = os.path.split(file_name)[-1]
    available_versions = sorted(configuration_files_map.keys())
    # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
    configuration_file = FULL_CONFIGURATION_FILE
    # transformers_version_ = version.parse(transformers_version)
    for v in available_versions:
        # if version.parse(v) <= transformers_version_:
        configuration_file = configuration_files_map[v]
        # else:
        #     # No point going further since the versions are sorted.
        #     break

    return configuration_file


if __name__ == "__main__":
    myconfig = BaseDeltaConfig.from_pretrained("../ckpts/lora/")
    myconfig.save_pretrained("../ckpts/lora.1/")

@@ -1,103 +1,10 @@

# from dataclasses import dataclass, field, fields
# from tkinter.messagebox import NO
# from typing import Optional, List, Union
# from xml.sax import default_parser_list
# from opendelta.utils.logging import get_logger

# logger = get_logger(__name__)
from DeltaCenter import OssClient
from .file_utils import default_cache_path


# alternative_names = {
#     "train_tasks": ["train_tasks", "train_task", "task_name"],
# }
def download(finetuned_delta_path, cache_dir=None, force_download=False):
    if cache_dir is None:
        cache_dir = default_cache_path
    path_to_unzip_file = OssClient.download(finetuned_delta_path, dest=cache_dir, force_download=force_download)
    return path_to_unzip_file


# @dataclass
# class DeltaCenterArguments:
#     """
#     The arguments that are used to distinguish between different delta models on the DeltaCenter
#     """
#     name: str = field(default="",
#         metadata={"help": "The name of the delta model checkpoint"}
#     )
#     backbone_model: str = field(default="",
#         metadata={"help": "The backbone model of the delta model"}
#     )
#     model_name_or_path: str = field(
#         default=None,
#         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
#     )
#     model_revision: str = field(
#         default="main",
#         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
#     )
#     delta_type: str = field(
#         default=None,
#         metadata={"help": "the type of delta model, e.g., adapter, lora, etc."}
#     )
#     train_tasks: Optional[Union[List[str], str]] = field(
#         default=None,
#         metadata={"help": "the task(s) that the delta is trained on"}
#     )
#     checkpoint_size: Optional[float] = field(
#         default=None,
#         metadata={"help": "the size of the checkpoint, in MB"}
#     )
#     test_tasks: Optional[Union[List[str], str]] = field(
#         default=None,
#         metadata={"help": "the task(s) that the delta is tested on"}
#     )
#     test_performance: Optional[float] = field(
#         default=None,
#         metadata={"help": "the performance of the model on the test set"}
#     )
#     trainable_ratio: Optional[float] = field(
#         default=None,
#         metadata={"help": "the ratio of trainable parameters in the model"}
#     )
#     delta_ratio: Optional[float] = field(
#         default=None,
#         metadata={"help": "the ratio of delta parameters in the model"}
#     )






# def create_repo_name(prefix="", center_args=None, **kwargs
#     ):
#     r"""Currently, it's only a simple concatenation of the arguments.
#     """
#     if center_args is None:
#         center_args = create_delta_center_args(**kwargs)
#     reponame = prefix+"_"
#     reponame += center_args.model_name_or_path.split()[-1]+"_" if center_args.model_name_or_path is not None else ""
#     reponame += center_args.delta_type+"_" if center_args.delta_type is not None else ""

#     # tasks

#     if isinstance(center_args.train_tasks, list):
#         train_tasks = "+".join(center_args.train_tasks)
#     elif center_args.train_tasks is not None:
#         train_tasks = center_args.train_tasks
#     else:
#         logger.warning("train_tasks are not found in all arguments. Do you miss them?")
#         train_tasks = None
#     reponame += train_tasks+"_" if train_tasks is not None else ""
#     reponame = reponame.strip("_")
#     return reponame

# def create_delta_center_args(**kwargs):
#     mdict = {}
#     field = fields(DeltaCenterArguments)
#     for f in field:
#         if f.name in kwargs:
#             mdict[f.name] = kwargs[f.name]
#         else:
#             for altername in alternative_names[f.name]:
#                 if altername in kwargs:
#                     mdict[f.name] = kwargs[altername]
#                     break
#     center_args = DeltaCenterArguments(**mdict)
#     return center_args
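
A short sketch of how this `download` helper is meant to be used (the checkpoint id is the one from the test script earlier; `OssClient.download` performs the actual fetch into the cache):

```python
from opendelta.utils.delta_center import download

# Resolves into ~/.cache/delta_center/ unless cache_dir is given;
# force_download=True re-fetches even when a local copy exists.
local_dir = download(
    "thunlp/t5-base_adapter_superglue-cb_20220701171436c80",
    cache_dir=None,
    force_download=False,
)
print(local_dir)  # directory holding config.json and the delta weights
```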

@@ -1,24 +0,0 @@


def create_hub_repo_name(root="DeltaHub",
                         dataset=None,
                         delta_type=None,
                         model_name_or_path=None,
                         ):
    r"""Currently, it's only a simple concatenation of the arguments.
    """
    repo_name = []

    repo_name.append(f"{delta_type}")
    model_name_or_path = model_name_or_path.split("/")[-1]
    repo_name.append(f"{model_name_or_path}")
    repo_name.append(f"{dataset}")

    repo_name = "_".join(repo_name)

    repo_name = root+"/"+repo_name
    return repo_name


@@ -0,0 +1,3 @@
import os
default_cache_path = "{}/.cache/delta_center/".format(os.path.expanduser('~'))
WEIGHTS_NAME = 'pytorch_model.bin'
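
These constants are what the loading code elsewhere in this commit joins into concrete file paths. A small sketch (the checkpoint id is a placeholder):

```python
import os
from opendelta.utils.file_utils import default_cache_path, WEIGHTS_NAME

checkpoint_id = "thunlp/example_delta"  # placeholder id
# e.g. ~/.cache/delta_center/thunlp/example_delta/pytorch_model.bin
weight_file = os.path.join(default_cache_path, checkpoint_id, WEIGHTS_NAME)
print(weight_file)
```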

@@ -1,7 +1,4 @@

from io import RawIOBase
import re
from tarfile import HeaderError
from typing import Dict, List, Union, Optional, Callable
from opendelta.delta_configs import BaseDeltaConfig
from opendelta.utils.model_md5 import gen_model_hash, gen_parameter_hash

@@ -9,24 +6,16 @@ import torch
import os
from opendelta import logging
import torch.nn as nn
from transformers.file_utils import (
    WEIGHTS_NAME,
    PushToHubMixin,
    is_offline_mode,
    is_remote_url,
    hf_bucket_url,
    cached_path,
)
from transformers.utils.dummy_pt_objects import PreTrainedModel
import hashlib
from DeltaCenter import OssClient
import yaml
from dataclasses import dataclass, field, fields
import datetime
from .file_utils import WEIGHTS_NAME

logger = logging.get_logger(__name__)



alternative_names = {
    "train_tasks": ["train_tasks", "train_task", "task_name"],
}

@@ -86,7 +75,7 @@ class DeltaCenterArguments:



class SaveLoadMixin(PushToHubMixin):
class SaveLoadMixin:
    def add_configs_when_saving(self):
        self.config.backbone_class = self.backbone_model.__class__.__name__
        self.config.backbone_checkpoint_name = os.path.split(self.backbone_model.config._name_or_path.strip("/"))[-1]

@@ -110,48 +99,49 @@ class SaveLoadMixin(PushToHubMixin):
    ):
        r"""
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :py:meth:`~DeltaBase.from_finetuned` class method.
        :py:meth:`~DeltaBase.save_finetuned` class method.

        Arguments:
            save_directory (:obj:`str` or :obj:`os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            save_config (:obj:`bool`, *optional*, defaults to :obj:`True`):
                Whether or not to save the config of the model. Useful when in distributed training like TPUs and need
                to call this function on all processes. In this case, set ``save_config=True`` only on the main process
                to avoid race conditions.
            state_dict (nested dictionary of :obj:`torch.Tensor`):
                The state dictionary of the model to save. Will default to ``self.state_dict()``, but can be used to only
                save parts of the model or if special precautions need to be taken when recovering the state dictionary
                of a model (like when using model parallelism).
            save_function (:obj:`Callable`):
                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
                needs to replace ``torch.save`` by another method.
            push_to_dc (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether or not to push the model to the DeltaCenter.
            center_args (:obj:`Union[DeltaCenterArguments, dict]`, *optional*, defaults to :obj:`None`): The arguments
                that are used to distinguish between different delta models on the DeltaCenter. It has higher priority than the `center_args_pool`.
                It will be used to group delta models.
            center_args_pool (:obj:`dict`, *optional*, defaults to :obj:`None`): The argument pool for DeltaCenter.
                Together with center_args, they are used to distinguish between different delta models on the DeltaCenter.
                It will be used to group delta models.
            list_tags (:obj:`List`, *optional*, defaults to :obj:`None`): The tags in the form of a list for the delta model; these are
                optional identifiers that are not expected by `DeltaCenterArguments`. They will not be used to group delta models in the delta center.
            dict_tags (:obj:`Dict`, *optional*, defaults to :obj:`None`): The tags in the form of a dictionary for the delta model; these are
                optional identifiers that are not expected by `DeltaCenterArguments`. They will not be used to group delta models in the delta center.
            delay_push (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether or not to delay the push to the DeltaCenter. When set to True,
                the delta object will be saved locally to save_directory; you can push it later using

                .. code-block:: shell

                    python -m DeltaCenter upload save_directory

        finetuned_delta_path: (optional) path to the directory where the model and its configuration file will be saved.
            If not specified, the model will be saved in the directory ``./delta_checkpoints/``,
            which is a subdirectory of the current working directory.
        save_config: (optional) if ``True``, the configuration file will be saved in the same directory as the
            model file.
        state_dict: (optional) a dictionary containing the model's state_dict. If not specified, the
            state_dict is loaded from the backbone model's trainable parameters.
        save_function: (optional) the function used to save the model. Defaults to ``torch.save``.
        push_to_dc: (optional) if ``True``, the model will prepare things to be pushed to the DeltaCenter.
            This includes:
            - creating a configuration file for the model
            - creating a directory for the model
            - saving the model's trainable parameters
            - pushing the model to the DeltaCenter
        center_args: (optional) the arguments that are used to distinguish between different delta models on the DeltaCenter
        center_args_pool: (optional) a dictionary containing the arguments that are used to distinguish between different delta models on the DeltaCenter
        list_tags: (optional) a list of tags that will be added to the model's configuration file
        dict_tags: (optional) a dictionary of tags that will be added to the model's configuration file
        delay_push: (optional) if ``True``, the model will not be pushed to the DeltaCenter immediately. This is useful if you want to
            push the model later.

        """

        # create the config to save, including model hash, etc.
        if save_config:
            if not hasattr(self, "config"):
                self.create_config_from_model()
            self.add_configs_when_saving()

        final_center_args = self.create_delta_center_args(center_args=center_args,
                                                          center_args_pool=center_args_pool)

        save_directory = finetuned_delta_path
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        os.makedirs(save_directory, exist_ok=True)
        save_directory = os.path.join(save_directory, final_center_args.name)
        os.makedirs(save_directory, exist_ok=True)

        model_to_save = self.backbone_model  # unwrap_model(self)
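
The `run.py` change earlier in this commit shows the intended call shape for this save method (assumed here to be the mixin's `save_finetuned`). A condensed sketch of saving with a delayed push, with illustrative argument values:

```python
# Condensed sketch of the save path described above; values are illustrative.
delta_model.save_finetuned(
    finetuned_delta_path="./delta_checkpoints/",
    push_to_dc=True,
    center_args_pool={**vars(model_args), **vars(data_args),
                      **vars(training_args), **vars(delta_args)},
    list_tags=['NLI'],
    dict_tags={'purpose': 'for testing'},
    delay_push=True,  # save locally now; upload later via `python -m DeltaCenter upload <dir>`
)
```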

@@ -161,27 +151,24 @@ class SaveLoadMixin(PushToHubMixin):

        # Save the config
        if save_config:
            if not hasattr(self, "config"):
                self.create_config_from_model()
            self.add_configs_when_saving()
            self.config.save_finetuned(save_directory)

        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
        save_function(state_dict, output_model_file)

        logger.info(f"Model weights saved in {output_model_file}")

        final_center_args = self.create_delta_center_args(center_args=center_args,
                                                          center_args_pool=center_args_pool)
        logger.info("\n"+"*"*30+f"\nYour delta model has been saved locally to:\n\t{os.path.abspath(save_directory)}"
        )

        if push_to_dc:
            logger.info("Creating yaml file for delta center")
            self.create_yml(save_directory, final_center_args, list_tags, dict_tags)


            if not delay_push:
                OssClient.upload(base_dir=save_directory)
            else:
                logger.info("\n"+"*"*30+f"\nYour delta model has been saved locally to:\n\t\t{os.path.abspath(save_directory)}\
                \nyou can push it to the delta center later using \n\t\tpython -m DeltaCenter upload {os.path.abspath(save_directory)}\n"
                logger.info(f"Delay push: you can push it to the delta center later using \n\tpython -m DeltaCenter upload {os.path.abspath(save_directory)}\n"
                            +"*"*30)

        # get absolute path of saved_directory,

@@ -201,8 +188,11 @@ class SaveLoadMixin(PushToHubMixin):
                        backbone_model: nn.Module,
                        delta_config=None,
                        cache_dir: Optional[Union[str, os.PathLike]] = None,
                        state_dict: Optional[dict] = None,
                        *model_args,
                        force_download: Optional[bool] = False,
                        check_hash: Optional[bool] = True,
                        local_files_only: Optional[bool] = False,
                        **kwargs):
        r"""
        Instantiate a finetuned delta model from a path.
@ -210,250 +200,76 @@ class SaveLoadMixin(PushToHubMixin):
|
|||
To further train the model, you can use the :meth:`freeze_module <opendelta.basemodel.DeltaBase.freeze_module>` method.
|
||||
|
||||
Parameters:
|
||||
|
||||
finetuned_model_path (:obj:`str` or :obj:`os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a
|
||||
user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- A path to a *directory* containing model weights saved using
|
||||
:meth:`SaveLoadMixin.save_finetuned`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a *tensorflow index checkpoint file* (e.g, ``./tf_model/model.ckpt.index``). In
|
||||
this case, ``from_tf`` should be set to ``True`` and a configuration object should be provided as
|
||||
``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a
|
||||
PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
- A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
|
||||
``./flax_model/`` containing ``flax_model.msgpack``). In this case, ``from_flax`` should be set to
|
||||
``True``.
|
||||
- ``None`` if you are both providing the configuration and state dictionary (resp. with keyword
|
||||
arguments ``config`` and ``state_dict``).
|
||||
backbone_model (:obj:`torch.nn.Module`): The backbone model to be modified.
|
||||
model_args (sequence of positional arguments, *optional*):
|
||||
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
|
||||
config (Union[:obj:`BaseDeltaConfig`, :obj:`str`, :obj:`os.PathLike`], *optional*): Can be either:
|
||||
- an instance of a class derived from :class:`~PretrainedConfig`,
|
||||
- a string or path valid as input to :py:meth:`~PretrainedConfig.from_pretrained`.
|
||||
|
||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||
be automatically loaded when:
|
||||
|
||||
- The model is a model provided by the library (loaded with the *model id* string of a pretrained
|
||||
model).
|
||||
- The model was saved using :py:meth:`~PreTrainedModel.save_pretrained` and is reloaded by supplying the
|
||||
save directory.
|
||||
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
|
||||
configuration JSON file named *config.json* is found in the directory.
|
||||
state_dict (Dict[:obj:`str`, :obj:`torch.Tensor`], *optional*):
|
||||
A state dictionary to use instead of a state dictionary loaded from saved weights file.
|
||||
This option can be used if you want to create a model from a pretrained configuration but load your own
|
||||
weights. In this case though, you should check if using :py:meth:`~PreTrainedModel.save_pretrained` and
|
||||
:py:meth:`~PreTrainedModel.from_pretrained` is not a simpler option.
|
||||
cache_dir (:obj:`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
force_download (:obj:`bool`, *optional*, defaults to :obj:`False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
resume_download (:obj:`bool`, *optional*, defaults to :obj:`False`):
|
||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||
file exists.
|
||||
proxies (:obj:`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only(:obj:`bool`, *optional*, defaults to :obj:`False`):
|
||||
Whether or not to only look at local files (i.e., do not try to download the model).
|
||||
use_auth_token (:obj:`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token generated
|
||||
when running ``transformers-cli login`` (stored in ``~/.huggingface``).
|
||||
revision(:obj:`str`, *optional*, defaults to ``"main"``):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||
identifier allowed by git.
|
||||
mirror(:obj:`str`, *optional*):
|
||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||
Please refer to the mirror site for more information.
|
||||
torch_dtype (:obj:`str` or :obj:`torch.dtype`, *optional*):
|
||||
Override the default :obj:`torch.dtype` and load the model under this dtype. If ``"auto"`` is passed the dtype
|
||||
will be automatically derived from the model's weights.
|
||||
|
||||
.. warning::
|
||||
|
||||
This feature is inherited from HuggingFace. We do not guarantee its usefulness currently.
|
||||
One should only disable *_fast_init* to ensure backwards compatibility with `transformers.__version__ <
|
||||
4.6.0` for seeded model initialization. This argument will be removed at the next major version. See
|
||||
`pull request 11471 <https://github.com/huggingface/transformers/pull/11471>`_ for more information.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
``output_attentions=True``). Behaves differently depending on whether a ``config`` is provided or
|
||||
automatically loaded:
|
||||
|
||||
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
|
||||
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
|
||||
already been done)
|
||||
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
|
||||
initialization function (:py:meth:`~PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that
|
||||
corresponds to a configuration attribute will be used to override said attribute with the
|
||||
supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute
|
||||
will be passed to the underlying model's ``__init__`` function.
|
||||
|
||||
.. tip::
|
||||
Passing ``use_auth_token=True`` is required when you want to use a private model.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
t5 = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
||||
from opendelta import AutoDeltaModel
|
||||
delta = AutoDeltaModel.from_finetuned("DeltaHub/lora_t5-base_mrpc", backbone_model=t5)
|
||||
delta.log()
|
||||
|
||||
|
||||
|
||||
finetuned_delta_path: (optional) path to the directory where the model and its configuration file will be saved.
|
||||
If not specified, the model will be loaded from the directory cahce directory. (see ``cache_dir``),
|
||||
backbone_model: the backbone model that will be used to instantiate the finetuned delta model.
|
||||
delta_config: (optional) the configuration file of the finetuned delta model. If not specified, the configuration file
|
||||
is loaded from the directory ``finetuned_delta_path``.
|
||||
cache_dir: (optional) path to the directory where the model and its configuration file will be saved.
|
||||
If not specified, we will first look into current working directory, then the cache directory of your system, e.g., ~/.cache/delta_center/,
|
||||
state_dict: (optional) a dictionary containing the model's state_dict. If not specified, the
|
||||
state_dict is loaded from the ``finetuned_delta_path``.
|
||||
force_download: (optional) if ``True``, the model will be downloaded from the internet even if it is already
|
||||
present in the cache directory.
|
||||
check_hash: (optional) if ``True``, check whether the hash of the model once it's trained differs from what we load now.
|
||||
local_files_only: (optional) if ``True``, the model will be loaded from the local cache directory.
|
||||
"""
|
||||
# config = kwargs.pop("config", None)
|
||||
state_dict = kwargs.pop("state_dict", None)
|
||||
# cache_dir = kwargs.pop("cache_dir", None)
|
||||
|
||||
# ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
# output_loading_info = kwargs.pop("output_loading_info", False)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
mirror = kwargs.pop("mirror", None)
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
# _fast_init = kwargs.pop("_fast_init", True)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
# low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
|
||||
|
||||
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
||||
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
if os.environ.get("DELTACENTER_OFFLINE", '0') == '1':
|
||||
logger.info("Delta Center offline mode!")
|
||||
local_files_only = True

# Load config if we don't provide a configuration
finetuned_delta_path = str(finetuned_delta_path)

if cache_dir is not None:
    cached_finetuned_delta_path = os.path.join(cache_dir, finetuned_delta_path)
else:
    cached_finetuned_delta_path = finetuned_delta_path
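# For example (illustrative values, not from the original code): cache_dir="~/.cache/delta_center"
# and finetuned_delta_path="thunlp/t5-base_lora" resolve to
# cached_finetuned_delta_path="~/.cache/delta_center/thunlp/t5-base_lora".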

download_from_dc = False
if os.path.isfile(cached_finetuned_delta_path):
    raise RuntimeError(
        f"You should pass a directory to load a delta checkpoint instead of a file, "
        f"since we need the delta's configuration file."
    )
elif os.path.isdir(cached_finetuned_delta_path):
    if os.path.isfile(os.path.join(cached_finetuned_delta_path, WEIGHTS_NAME)):
        # Load from a PyTorch checkpoint
        weight_file = os.path.join(cached_finetuned_delta_path, WEIGHTS_NAME)
    else:
        raise EnvironmentError(
            f"Error: no file named {WEIGHTS_NAME} found in "
            f"directory {cached_finetuned_delta_path}."
        )
else:
    # try to download from DeltaCenter
    from .delta_center import download as dcdownload
    cached_finetuned_delta_path = dcdownload(finetuned_delta_path, cache_dir=cache_dir, force_download=force_download)
    download_from_dc = True
    weight_file = os.path.join(cached_finetuned_delta_path, WEIGHTS_NAME)
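# Summary of the three branches above: a plain file is rejected (the delta's
# configuration file is needed too), an existing directory is used as-is, and
# anything else is treated as a DeltaCenter identifier and downloaded into the cache.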

if state_dict is None:
    state_dict = torch.load(weight_file, map_location="cpu")

if not isinstance(delta_config, BaseDeltaConfig):
    # config_path = delta_config if delta_config is not None else finetuned_model_path # Todo check
    delta_config, model_kwargs = cls.config_class.from_finetuned(
        cached_finetuned_delta_path,
        cache_dir=None,
        return_unused_kwargs=True,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        use_auth_token=use_auth_token,
        revision=revision,
        _from_auto=from_auto_class,
        _from_pipeline=from_pipeline,
        local_files_only=True if download_from_dc else local_files_only,  # has been downloaded
        **kwargs,
    )
else:
    model_kwargs = kwargs

logger.debug("delta_config: %s", delta_config)

# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
#    weights entry - we assume all weights are of the same dtype
# we also may have config.torch_dtype available, but we won't rely on it till v5
dtype_orig = None
if torch_dtype is not None:
    if isinstance(torch_dtype, str):
        if torch_dtype == "auto":
            torch_dtype = next(iter(state_dict.values())).dtype
        else:
            raise ValueError(
                f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}"
            )
    dtype_orig = cls._set_default_torch_dtype(torch_dtype)
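# For example (illustrative, not from the original code): with torch_dtype="auto" and a
# checkpoint saved in half precision, next(iter(state_dict.values())).dtype is
# torch.float16, so the delta modules are instantiated in float16 as well.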

# Initialize the model from config and attach the delta model to the backbone_model.
delta_model = cls.from_config(delta_config, backbone_model, *model_args, **model_kwargs)

# load the state_dict into the backbone_model. As the delta model's parameters
# are the same objects as the deltas in the backbone model (just under different
# reference names), loading into the backbone also loads the delta model.
@@ -461,14 +277,21 @@ class SaveLoadMixin(PushToHubMixin):
delta_model._load_state_dict_into_backbone(backbone_model, state_dict)
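# Illustrative consequence (an assumption, not checked in the original code): because
# the tensors are shared, mutating a delta weight through either handle is visible
# through the other, and no duplicate copy of the delta weights is held in memory.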

backbone_hash = gen_model_hash(backbone_model)
if check_hash:
    if hasattr(delta_config, "backbone_hash") and \
            delta_config.backbone_hash is not None and \
            delta_config.backbone_hash != backbone_hash:
        logger.warning("The config carries a hash of the backbone model that "
                       "differs from the hash of the backbone model you loaded. This indicates a "
                       "mismatch between the backbone model that the delta checkpoint was trained on "
                       "and the one you loaded. You probably need to train the delta model instead of "
                       "directly running inference with it.")
    else:
        logger.info("Hash-check passed. You can safely use this checkpoint directly.")
else:
    logger.warning("Parameters' hash has not been checked!")

# Set model in evaluation mode to deactivate DropOut modules by default
backbone_model.eval()
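# Note (an addition, not in the original code): call backbone_model.train() afterwards
# if you plan to continue training the delta parameters rather than run inference.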

@@ -511,7 +334,7 @@ class SaveLoadMixin(PushToHubMixin):

# if the name is eventually still not set, create a default one
if mdict['name'] is None or mdict['name'] == '':
    logger.info("Name is not set, using the default name.")
    mdict['name'] = self.create_default_name(**mdict)

@@ -521,7 +344,6 @@ class SaveLoadMixin(PushToHubMixin):

def create_default_name(self, **kwargs):
    r"""Currently, it's only a simple concatenation of the arguments.
    """
    logger.debug("key args: %s", kwargs)

    reponame = ""
    reponame += kwargs["model_path_public"].split("/")[-1] + "_" if kwargs["model_path_public"] is not None else kwargs["backbone_model"]
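# Illustrative output (an assumption; only the first concatenation step is shown here):
# with model_path_public="t5-base" and a "lora" delta, the generated default name would
# look like "t5-base_lora_<task>_<timestamp>".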