# OpenDeltaMirror/opendelta/utils/saving_loading_utils.py

from typing import Dict, List, Union, Optional, Callable
from opendelta.delta_configs import BaseDeltaConfig
from opendelta.utils.model_md5 import gen_model_hash, gen_parameter_hash
import torch
import os
from opendelta import logging
import torch.nn as nn
from DeltaCenter import OssClient
import yaml
from dataclasses import dataclass, field, fields
import datetime
from .file_utils import WEIGHTS_NAME
logger = logging.get_logger(__name__)
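
# Maps a canonical ``DeltaCenterArguments`` field name to the alternative
# key names that are also accepted when the value is looked up in
# ``center_args_pool`` (see ``create_delta_center_args`` below).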
alternative_names = {
"train_tasks": ["train_tasks", "train_task", "task_name"],
}


@dataclass
class DeltaCenterArguments:
    """
    The arguments that are used to distinguish between different delta models on the DeltaCenter.
    """
    name: str = field(
        default="",
        metadata={"help": "The name of the delta model checkpoint"}
    )
    backbone_model: str = field(
        default="",
        metadata={"help": "The backbone model of the delta model"}
    )
    backbone_model_path_public: Optional[str] = field(
        default=None,
        metadata={"help": "Publicly available path (url) to pretrained model or model identifier from huggingface.co/models"}
    )
    delta_type: Optional[str] = field(
        default=None,
        metadata={"help": "The type of the delta model, e.g., adapter, lora, etc."}
    )
    train_tasks: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={"help": "The task(s) that the delta is trained on"}
    )
    train_datasets: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={"help": "The dataset(s) that the delta is trained on"}
    )
    checkpoint_size: Optional[float] = field(
        default=None,
        metadata={"help": "The size of the checkpoint, in MB"}
    )
    test_tasks: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={"help": "The task(s) that the delta is tested on"}
    )
    test_datasets: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={"help": "The dataset(s) that the delta is tested on"}
    )
    test_performance: Optional[float] = field(
        default=None,
        metadata={"help": "The performance of the model on the test set"}
    )
    test_metrics: Optional[str] = field(
        default=None,
        metadata={"help": "The metrics used to evaluate the model"}
    )
    trainable_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "The ratio of trainable parameters in the model"}
    )
    delta_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "The ratio of delta parameters in the model"}
    )
    usage: Optional[str] = field(
        default="",
        metadata={"help": "The usage code of the model"}
    )
    license: Optional[str] = field(
        default="apache-2.0",
        metadata={"help": "The license of the model"}
    )
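

# A minimal illustration (field values are hypothetical) of the metadata
# record that identifies a checkpoint on the DeltaCenter:
#
#     args = DeltaCenterArguments(
#         name="bert-base-uncased_lora_sst2",
#         backbone_model="bert-base-uncased",
#         delta_type="lora",
#         train_tasks=["sst2"],
#     )
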
class SaveLoadMixin:
    def add_configs_when_saving(self):
        self.config.backbone_class = self.backbone_model.__class__.__name__
        self.config.backbone_checkpoint_name = os.path.split(self.backbone_model.config._name_or_path.strip("/"))[-1]
        self.config.backbone_hash = gen_model_hash(self.backbone_model)

    def save_finetuned(
        self,
        finetuned_delta_path: Optional[Union[str, os.PathLike]] = "./delta_checkpoints/",
        save_config: bool = True,
        state_dict: Optional[dict] = None,
        save_function: Callable = torch.save,
        push_to_dc: bool = True,
        center_args: Optional[Union[DeltaCenterArguments, dict]] = dict(),
        center_args_pool: Optional[dict] = dict(),
        list_tags: Optional[List] = list(),
        dict_tags: Optional[Dict] = dict(),
        delay_push: bool = False,
        test_result=None,
        usage: Optional[str] = "",
    ):
        r"""
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :py:meth:`~DeltaBase.from_finetuned` class method.

        Arguments:
            finetuned_delta_path: (optional) path to the directory where the model and its configuration file
                will be saved. If not specified, the model will be saved in the directory
                ``./delta_checkpoints/``, which is a subdirectory of the current working directory.
            save_config: (optional) if ``True``, the configuration file will be saved in the same directory as
                the model file. If ``False``, only the state dict will be saved.
            state_dict: (optional) a dictionary containing the model's state_dict. If not specified, the
                state_dict is loaded from the backbone model's trainable parameters.
            save_function: (optional) the function used to save the model. Defaults to ``torch.save``.
            push_to_dc: (optional) if ``True``, the model will be prepared to be pushed to the DeltaCenter.
                This includes:

                - creating a configuration file for the model
                - creating a directory for the model
                - saving the model's trainable parameters
                - pushing the model to the DeltaCenter
            center_args: (optional) the arguments that are used to distinguish between different delta models
                on the DeltaCenter.
            center_args_pool: (optional) a dictionary of candidate values for the DeltaCenter arguments;
                it is consulted when a field is not given in ``center_args``.
            list_tags: (optional) a list of tags that will be added to the model's configuration file.
            dict_tags: (optional) a dictionary of tags that will be added to the model's configuration file.
            delay_push: (optional) if ``True``, the model will not be pushed to the DeltaCenter immediately.
                This is useful if you want to push the model later.
            test_result: (optional) the test result of the checkpoint; it is recorded as a tag in the
                ``config.yml`` that is pushed to the DeltaCenter.
        """
        # create the config to save, including model hash, etc.
        if save_config:
            if not hasattr(self, "config"):
                self.create_config_from_model()
            self.add_configs_when_saving()

        if push_to_dc:
            final_center_args = self.create_delta_center_args(center_args=center_args,
                                                              center_args_pool=center_args_pool)

        save_directory = finetuned_delta_path
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return
        os.makedirs(save_directory, exist_ok=True)

        if push_to_dc:
            save_directory = os.path.join(save_directory, final_center_args.name)
            os.makedirs(save_directory, exist_ok=True)

        model_to_save = self.backbone_model  # unwrap_model(self)

        # Save the model
        if state_dict is None:
            state_dict = model_to_save.state_dict()
        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
        save_function(state_dict, output_model_file)

        # Save the config
        if save_config:
            self.config.save_finetuned(save_directory)

        logger.info("\n" + "*" * 30 + f"\nYour delta model has been saved locally to:\n\t{os.path.abspath(save_directory)}")

        state_dict_total_params = sum(p.numel() for p in state_dict.values())
        other_tags = {}
        other_tags.update({'state_dict_total_params(M)': state_dict_total_params / 1024 / 1024})
        other_tags.update({'test_result': test_result})
        if push_to_dc:
            logger.info("Creating yaml file for delta center")
            self.create_yml(save_directory, final_center_args, list_tags, dict_tags, other_tags)
            if not delay_push:
                OssClient.upload(base_dir=save_directory)
            else:
                logger.info("Delay push: you can push it to the delta center later using\n\t"
                            f"python -m DeltaCenter upload {os.path.abspath(save_directory)}\n" + "*" * 30)

    def create_yml(self, save_dir, config, list_tags=list(), dict_tags=dict(), other_tags=None):
        with open("{}/config.yml".format(save_dir), 'w') as f:
            config_dict = vars(config)
            config_dict['dict_tags'] = dict_tags
            config_dict['list_tags'] = list_tags
            if other_tags is not None:
                config_dict.update(other_tags)
            yaml.safe_dump(config_dict, f)
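
    # The resulting ``config.yml`` looks roughly like this (values are
    # illustrative, not taken from a real checkpoint):
    #
    #     name: bert-base-uncased_lora_20220101000000abc
    #     delta_type: lora
    #     list_tags: []
    #     dict_tags: {}
    #     state_dict_total_params(M): 0.28
    #     test_result: null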

    def load_checkpoint(self, path, load_func=torch.load, backbone_model=None):
        r"""Simple method for loading only the checkpoint."""
        if backbone_model is None:
            backbone_model = self.backbone_model
        # Load into the (possibly overridden) backbone; strict=False because the
        # checkpoint contains only the delta / trainable parameters.
        backbone_model.load_state_dict(load_func(f"{path}/{WEIGHTS_NAME}"), strict=False)

    def save_checkpoint(self, path, save_func=torch.save, backbone_model=None):
        r"""Simple method for saving only the checkpoint."""
        if backbone_model is None:
            backbone_model = self.backbone_model
        save_func(backbone_model.state_dict(), f"{path}/{WEIGHTS_NAME}")
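
    # Example round trip (a minimal sketch; ``delta`` is an attached delta
    # model and the directory is hypothetical and must already exist):
    #
    #     delta.save_checkpoint("./ckpt")   # writes ./ckpt/<WEIGHTS_NAME>
    #     delta.load_checkpoint("./ckpt")   # restores it into the backbone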

    @classmethod
    def from_finetuned(cls,
                       finetuned_delta_path: Optional[Union[str, os.PathLike]],
                       backbone_model: nn.Module,
                       delta_config=None,
                       cache_dir: Optional[Union[str, os.PathLike]] = None,
                       state_dict: Optional[dict] = None,
                       *model_args,
                       force_download: Optional[bool] = False,
                       check_hash: Optional[bool] = True,
                       local_files_only: Optional[bool] = False,
                       **kwargs):
        r"""
        Instantiate a finetuned delta model from a path.
        The backbone_model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
        To further train the model, you can use the :meth:`freeze_module <opendelta.basemodel.DeltaBase.freeze_module>` method.

        Parameters:
            finetuned_delta_path: path to the directory where the finetuned delta model and its configuration
                file are stored, or the identifier of a checkpoint on the DeltaCenter (resolved relative to
                ``cache_dir``).
            backbone_model: the backbone model that will be used to instantiate the finetuned delta model.
            delta_config: (optional) the configuration of the finetuned delta model. If not specified, the
                configuration file is loaded from the directory ``finetuned_delta_path``.
            cache_dir: (optional) path to the directory where downloaded checkpoints are cached. If not
                specified, we will first look into the current working directory, then the cache directory of
                your system, e.g., ``~/.cache/delta_center/``.
            state_dict: (optional) a dictionary containing the model's state_dict. If not specified, the
                state_dict is loaded from ``finetuned_delta_path``.
            force_download: (optional) if ``True``, the model will be downloaded from the internet even if it
                is already present in the cache directory.
            check_hash: (optional) if ``True``, check whether the hash of the backbone model the delta was
                trained on differs from the hash of the backbone model you are loading it into now.
            local_files_only: (optional) if ``True``, the model will only be loaded from the local cache
                directory.
        """
        if os.environ.get("DELTACENTER_OFFLINE", '0') == '1':
            logger.info("Delta Center offline mode!")
            local_files_only = True

        # Load config if we don't provide a configuration
        finetuned_delta_path = str(finetuned_delta_path)

        if cache_dir is not None:
            cached_finetuned_delta_path = os.path.join(cache_dir, finetuned_delta_path)
        else:
            cached_finetuned_delta_path = finetuned_delta_path

        download_from_dc = False
        if os.path.isfile(cached_finetuned_delta_path):
            raise RuntimeError(
                f"You should pass a directory to load a delta checkpoint instead of a file, "
                f"since we need the delta's configuration file."
            )
        elif os.path.isdir(cached_finetuned_delta_path):
            if os.path.isfile(os.path.join(cached_finetuned_delta_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                weight_file = os.path.join(cached_finetuned_delta_path, WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    f"Error: no file named {WEIGHTS_NAME} found in "
                    f"directory {cached_finetuned_delta_path}."
                )
        else:
            # try to download from DeltaCenter
            from .delta_center import download as dcdownload
            cached_finetuned_delta_path = dcdownload(finetuned_delta_path, cache_dir=cache_dir, force_download=force_download)
            download_from_dc = True
            weight_file = os.path.join(cached_finetuned_delta_path, WEIGHTS_NAME)

        if state_dict is None:
            state_dict = torch.load(weight_file, map_location="cpu")

        if not isinstance(delta_config, BaseDeltaConfig):
            delta_config, model_kwargs = cls.config_class.from_finetuned(
                cached_finetuned_delta_path,
                cache_dir=None,
                return_unused_kwargs=True,
                local_files_only=True if download_from_dc else local_files_only,  # has been downloaded
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Initialize the model from config and attach the delta model to the backbone_model.
        delta_model = cls.from_config(delta_config, backbone_model, *model_args, **model_kwargs)

        # Load the state_dict into the backbone_model. As the delta model's parameters
        # are the same objects as the deltas in the backbone model (under different
        # reference names), the state_dict is thereby also loaded into the delta model.
        delta_model._load_state_dict_into_backbone(backbone_model, state_dict)

        backbone_hash = gen_model_hash(backbone_model)
        if check_hash:
            if hasattr(delta_config, "backbone_hash") and \
                    delta_config.backbone_hash is not None and \
                    delta_config.backbone_hash != backbone_hash:
                logger.warning("The config has a hash of the backbone model that is "
                               "different from the hash of the loaded model. This indicates a mismatch "
                               "between the backbone model that the delta checkpoint was trained on and "
                               "the one you loaded. You probably need to train the model instead of "
                               "directly running inference.")
            else:
                logger.info("Hash check passed. You can safely use this checkpoint directly.")
        else:
            logger.warning("Parameters' hash has not been checked!")

        # Set model in evaluation mode to deactivate Dropout modules by default
        backbone_model.eval()
        return delta_model

    def create_delta_center_args(self, center_args, center_args_pool):
        """
        Create the DeltaCenter arguments for the center model.
        ``center_args`` has higher priority than ``center_args_pool``.
        """
        mdict = {}
        arg_fields = fields(DeltaCenterArguments)  # avoid shadowing the imported ``field``
        for f in arg_fields:
            exist = False
            # first check center_args, which requires an exact match
            if f.name in center_args:
                mdict[f.name] = center_args[f.name]
                continue
            # second check center_args_pool, which also accepts alternative names
            if f.name in center_args_pool:
                mdict[f.name] = center_args_pool[f.name]
                exist = True
            elif f.name in alternative_names:
                for altername in alternative_names[f.name]:
                    if altername in center_args_pool:
                        mdict[f.name] = center_args_pool[altername]
                        exist = True
                        break
            # if not found, fall back to self.stat or the field's default
            if not exist:
                if f.name in self.stat:
                    mdict[f.name] = self.stat[f.name]
                else:
                    mdict[f.name] = f.default

        # if the name is eventually still not set, create a default one
        if mdict['name'] is None or mdict['name'] == '':
            logger.info("Name is not set, using a default name.")
            mdict['name'] = self.create_default_name(**mdict)

        if len(mdict['usage']) == 0:
            logger.info("Usage is not set, using the default usage.")
            mdict['usage'] = self.create_default_usage(mdict['name'])

        center_args = DeltaCenterArguments(**mdict)
        return center_args

    def create_default_usage(self, name):
        # The ``{name_with_userid}`` placeholder is left in the returned
        # template as-is; it is not formatted here.
        usage_str = """from opendelta import AutoDeltaModel\n""" + \
                    """delta_model = AutoDeltaModel.from_finetuned('{name_with_userid}', backbone_model=model)\n""" + \
                    """delta_model.freeze_module()  # if you are going to further train it\n""" + \
                    """delta_model.log()"""
        return usage_str

    def create_default_name(self, **kwargs):
        r"""Currently, this is only a simple concatenation of the arguments.
        """
reponame = ""
reponame += kwargs["backbone_model_path_public"].split("/")[-1]+"_" if kwargs['backbone_model_path_public'] is not None else kwargs['backbone_model']
reponame += kwargs["delta_type"]+"_" if kwargs["delta_type"] is not None else ""
# tasks
if isinstance(kwargs["train_tasks"], list):
train_tasks = "+".join(kwargs["train_tasks"])
elif kwargs["train_tasks"] is not None:
train_tasks = kwargs["train_tasks"]
else:
logger.warning("train_tasks are not find in all arguments. Do you miss them?")
train_tasks = None
reponame += train_tasks+"_" if train_tasks is not None else ""
# time
reponame += datetime.datetime.now().strftime("%Y%m%d%H%M%S") #+ gen_model_hash(model=self.backbone_model)
# model hash
if hasattr(self.config, "backbone_hash"):
reponame += self.config.backbone_hash[:3]
return reponame