diff --git a/.gitignore b/.gitignore
index 7d6dd27..411800c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -54,12 +54,18 @@ t.sh
**/delta_checkpoints/
**/outputs/
+dist/
+dist/*
**/unittest/**
!unittest/**.py
!unittest/**.sh
!unittest/**.md
-dist/
-dist/*
+**/tutorial/**
+!tutorial/**.py
+!tutorial/**.sh
+!tutorial/**.md
+
+
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 1be8e51..8a94518 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -31,8 +31,8 @@ copyright = '{}, {}, Licenced under the Apache License, Version 2.0'.format(date
# The full version, including alpha/beta/rc tags
-release = '0.3.1'
-version = "0.3.1"
+release = '0.3.2'
+version = "0.3.2"
html_theme = 'sphinx_rtd_theme'
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
diff --git a/docs/source/notes/acceleration.md b/docs/source/notes/acceleration.md
index 6b088b4..4bb258f 100644
--- a/docs/source/notes/acceleration.md
+++ b/docs/source/notes/acceleration.md
@@ -1,6 +1,14 @@
(acceleration)=
# OpenDelta+
- We are working on testing and improving the functionality with work with other acceleration packages for model training and inference. For example, [deepspeed](https://github.com/microsoft/DeepSpeed), [BMInf](https://github.com/OpenBMB/BMInf).
-Feel free to contact us via email (shengdinghu@gmail.com) if you have any suggestion.
+## BMTrain
+
+- [BMTrain](https://github.com/OpenBMB/BMTrain) is an efficient large model training toolkit that can be used to train large models with tens of billions of parameters. It can train models in a distributed manner while keeping the code as simple as stand-alone training.
+- [ModelCenter](https://github.com/OpenBMB/ModelCenter) implements pre-trained language models (PLMs) on top of OpenBMB/BMTrain. ModelCenter supports efficient, low-resource, and extendable model usage as well as distributed training.
+
+LoraModel, AdapterModel, CompacterModel, ParallelAdapterModel, and LowRankAdapterModel now fully support distributed training with BMTrain and ModelCenter. Please try it out in [examples/tutorial/2_with_bmtrain.py](https://github.com/thunlp/OpenDelta/tree/main/examples/tutorial/2_with_bmtrain.py).
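+
+Below is a minimal sketch of attaching a delta model to a ModelCenter backbone with the `bmt` backend. It assumes BMTrain and ModelCenter are installed and the `bert-large-cased` checkpoint is available; the module names `project_q`/`project_k` follow the tutorial script and may differ for other backbones.
+
+```python
+import bmtrain as bmt
+from model_center.model import Bert
+from opendelta import LoraModel
+
+bmt.init_distributed()  # bmtrain requires distributed initialization before building models
+model = Bert.from_pretrained("bert-large-cased")
+
+# attach LoRA to the attention projections; backend='bmt' lets bmtrain wrap the new modules
+delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'], backend='bmt')
+delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
+delta_model.log()
+```
+
+For a complete distributed training loop, see the tutorial script linked above.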
+
+
+## Hugging Face Accelerate
+
\ No newline at end of file
diff --git a/docs/source/notes/update.md b/docs/source/notes/update.md
index 26626d2..e121178 100644
--- a/docs/source/notes/update.md
+++ b/docs/source/notes/update.md
@@ -1,5 +1,10 @@
# Update Logs and Known Issues
+## Version 0.3.2
+- We support BMTrain to accelerate training and to parallelize models that are too large to fit on a single GPU. Check [tutorial/2_with_bmtrain.py](https://github.com/thunlp/OpenDelta/tree/main/examples/tutorial/2_with_bmtrain.py)
+- We add a utility to [inspect the optimizer](https://github.com/thunlp/OpenDelta/tree/main/opendelta/utils/inspect.py), so users can see the number of trainable parameters in the optimizer and verify that OpenDelta is being used correctly.
+- We move the functions for inspecting delta models into [inspect.py](https://github.com/thunlp/OpenDelta/tree/main/opendelta/utils/inspect.py)
+
## Version 0.3.1
- We update [must_try.py](https://github.com/thunlp/OpenDelta/tree/main/examples/unittest/must_try.py) for a simple introduction of the core functionality of OpenDelta.
- Thanks to [Weilin Zhao](https://github.com/Achazwl) We merge a long-developed branch parallel_adapter into the main branch.
diff --git a/examples/unittest/test_bmtrain.py b/examples/tutorial/2_with_bmtrain.py
similarity index 85%
rename from examples/unittest/test_bmtrain.py
rename to examples/tutorial/2_with_bmtrain.py
index 9096fe4..ba202c9 100644
--- a/examples/unittest/test_bmtrain.py
+++ b/examples/tutorial/2_with_bmtrain.py
@@ -1,3 +1,5 @@
+# adapted from https://github.com/OpenBMB/ModelCenter/blob/main/examples/bert/finetune_bert.py
+
import time
import os
@@ -15,6 +17,11 @@ from model_center.utils import print_inspect
from model_center.layer import Linear
from model_center.dataset import DistributedDataLoader
+import opendelta as od
+from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel, ParallelAdapterModel
+from opendelta.utils.inspect import inspect_optimizer_statistics
+print("before modify")
+
class BertModel(torch.nn.Module):
def __init__(self, args, num_types):
super().__init__()
@@ -41,6 +48,31 @@ def get_model(args):
"WiC" : 2,
}
model = BertModel(args, num_types[args.dataset_name])
+ od.Visualization(model).structure_graph()
+
+
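+    # backend='bmt' wraps the newly created delta modules with bmtrain so that they
+    # join BMTrain's distributed training together with the ModelCenter backbone.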
+ if args.delta_type == "lora":
+ delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'], backend='bmt')
+ elif args.delta_type == "bitfit":
+ delta_model = BitFitModel(backbone_model=model, modified_modules=['self_att', 'ffn', 'layernorm'], backend='bmt') #TODO: fix bug
+ elif args.delta_type == "adapter":
+ delta_model = AdapterModel(backbone_model=model, modified_modules=['self_att', 'ffn'], backend='bmt')
+ elif args.delta_type == "compacter":
+ delta_model = CompacterModel(backbone_model=model, modified_modules=['self_att', 'ffn'], backend='bmt')
+ elif args.delta_type == "low_rank_adapter":
+ delta_model = LowRankAdapterModel(backbone_model=model, modified_modules=['self_att', 'ffn'], backend='bmt')
+ elif args.delta_type == "parallel_adapter":
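+        # parallel_adapter takes paired modified_modules: in each consecutive pair, the adapter
+        # starts at the first module's input and ends at the second module's output (see ParallelAdapterModel's docstring).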
+ delta_model = ParallelAdapterModel(backbone_model=model, modified_modules=['self_att', 'self_att', 'ffn.ffn', 'ffn.ffn'], backend='bmt')
+
+
+
+ print("after modify")
+ delta_model.log()
+ # This will visualize the backbone after modification and other information.
+
+ delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
+ print("after freeze")
+ delta_model.log()
return model
def get_optimizer(args, model):
@@ -93,6 +125,8 @@ def setup_model_and_optimizer(args):
bmt.synchronize()
# get the optimizer and lr_scheduler
optimizer = get_optimizer(args, model)
+
+ inspect_optimizer_statistics(optimizer)
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
bmt.synchronize()
# get the memory usage
@@ -124,7 +158,7 @@ def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset):
optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale)
optim_manager.add_optimizer(optimizer, lr_scheduler)
- print_inspect(model, '*')
+ # print_inspect(model, '*') # too much output
for epoch in range(12):
dataloader = {
@@ -172,6 +206,8 @@ def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset):
torch.cuda.synchronize()
elapsed_time = time.time() - st_time
+
bmt.print_rank(
"train | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.4e}, scale: {:10.4f} | grad_norm: {:.4f} | time: {:.3f}".format(
epoch,
diff --git a/examples/tutorial/2_with_bmtrain.sh b/examples/tutorial/2_with_bmtrain.sh
new file mode 100644
index 0000000..f25013d
--- /dev/null
+++ b/examples/tutorial/2_with_bmtrain.sh
@@ -0,0 +1,37 @@
+#! /bin/bash
+
+MASTER_ADDR=localhost
+MASTER_PORT=12345
+NNODES=1
+NODE_RANK=0
+GPUS_PER_NODE=4
+
+DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
+ --nnodes $NNODES \
+ --node_rank $NODE_RANK \
+ --master_addr $MASTER_ADDR \
+ --master_port $MASTER_PORT"
+
+BASE_PATH="./"
+VERSION="bert-large-cased"
+DATASET="BoolQ" # You can try other dataset listed in https://github.com/OpenBMB/ModelCenter/tree/main/examples/bert
+
+OPTS=""
+OPTS+=" --model-config ${VERSION}"
+OPTS+=" --base-path ${BASE_PATH}"
+OPTS+=" --dataset_name ${DATASET}"
+OPTS+=" --batch-size 64"
+OPTS+=" --lr 0.001" # You can use different learning rate to find optimal performance
+OPTS+=" --max-encoder-length 512"
+OPTS+=" --train-iters 1400"
+OPTS+=" --lr-decay-style constant"
+OPTS+=" --weight-decay 1e-2"
+OPTS+=" --clip-grad 10.0"
+OPTS+=" --loss-scale 128"
+OPTS+=" --delta_type low_rank_adapter" # You can use different delta type, listed in https://opendelta.readthedocs.io/en/latest/notes/acceleration.html#BMTrain
+
+CMD="python3 -m torch.distributed.launch ${DISTRIBUTED_ARGS} ${BASE_PATH}2_with_bmtrain.py ${OPTS}"
+echo ${CMD}
+
+${CMD} 2>&1 | tee ${BASE_PATH}/tmp/logs/bmt_bert_boolq_finetune-${VERSION}-${DATASET}.log
+
diff --git a/examples/tutorial/README.md b/examples/tutorial/README.md
index 815a114..e0f2a6e 100644
--- a/examples/tutorial/README.md
+++ b/examples/tutorial/README.md
@@ -13,4 +13,23 @@ requirement:
```
pip install openprompt
-```
\ No newline at end of file
+```
+
+## 2_with_bmtrain.py
+1. Install the necessary packages:
+```
+pip install git+https://github.com/OpenBMB/BMTrain.git
+pip install git+https://github.com/OpenBMB/ModelCenter.git
+```
+2. Download a dataset from https://super.gluebenchmark.com/tasks, e.g.,
+```
+mkdir down_data
+cd down_data
+wget https://dl.fbaipublicfiles.com/glue/superglue/data/v2/BoolQ.zip
+unzip BoolQ.zip
+```
+3. Run the shell script, changing `NNODES` and `GPUS_PER_NODE` according to your computational resources.
+```
+bash 2_with_bmtrain.sh
+```
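+
+Note: the script pipes its output to a log under `./tmp/logs/` via `tee` (see `2_with_bmtrain.sh`); if that directory does not exist yet, create it first:
+```
+mkdir -p tmp/logs
+```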
+
diff --git a/examples/unittest/must_try.py b/examples/unittest/must_try.py
index 30a7b5c..8448e41 100644
--- a/examples/unittest/must_try.py
+++ b/examples/unittest/must_try.py
@@ -67,4 +67,16 @@ delta2.detach()
# say we add lora to the last four layer of the decoder of t5, with lora rank=5
delta_config3 = AutoDeltaConfig.from_dict({"delta_type":"lora", "modified_modules":["[r]decoder.*((20)|(21)|(22)|(23)).*DenseReluDense\.wi"], "lora_r":5})
delta3 = AutoDeltaModel.from_config(delta_config3, backbone_model=wrapped_model)
-delta3.log()
\ No newline at end of file
+delta3.freeze_module()
+delta3.log()
+
+
+# add optimizer as normal
+from transformers import AdamW
+optimizer = AdamW(wrapped_model.parameters(), lr=3e-3)
+
+# inspect_optimizer
+from opendelta.utils.inspect import inspect_optimizer_statistics
+inspect_optimizer_statistics(optimizer)
+
+
diff --git a/examples/unittest/test_parallel_adapter.py b/examples/unittest/test_parallel_adapter.py
new file mode 100644
index 0000000..7773946
--- /dev/null
+++ b/examples/unittest/test_parallel_adapter.py
@@ -0,0 +1,45 @@
+# use transformers as usual.
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+t5 = AutoModelForSeq2SeqLM.from_pretrained("t5-large")
+t5_tokenizer = AutoTokenizer.from_pretrained("t5-large")
+# A running example
+inputs_ids = t5_tokenizer.encode("Is Harry Poter wrtten by JKrowling", return_tensors="pt")
+t5_tokenizer.decode(t5.generate(inputs_ids)[0])
+# >>> '? Is it Harry Potter?'
+
+
+# use existing delta models
+from opendelta import AutoDeltaModel, AutoDeltaConfig
+
+# use the default configuration for customized wrapped models which have PLMs inside. This is a common need for users.
+import torch.nn as nn
+class WrappedModel(nn.Module):
+ def __init__(self, inner_model):
+ super().__init__()
+ self.inner = inner_model
+ def forward(self, *args, **kwargs):
+ return self.inner(*args, **kwargs)
+
+wrapped_model = WrappedModel(WrappedModel(t5))
+
+# say we use a parallel adapter
+delta_config = AutoDeltaConfig.from_dict({"delta_type":"parallel_adapter"})
+delta2 = AutoDeltaModel.from_config(delta_config, backbone_model=wrapped_model)
+delta2.log()
+# >>> root
+# -- inner
+# -- inner
+# -- encoder
+# -- block
+# -- 0
+# -- layer
+# ...
+# -- parallel_adapter
+# ...
+# -- 1
+# -- DenseRuleDense
+# -- wi
+# -- parallel_adapter
+# ...
+delta2.detach()
+delta2.log()
diff --git a/opendelta/__init__.py b/opendelta/__init__.py
index 431cfa1..1127d70 100644
--- a/opendelta/__init__.py
+++ b/opendelta/__init__.py
@@ -21,6 +21,7 @@ from .delta_models.adapter import AdapterModel
from .delta_models.prefix import PrefixModel
from .delta_models.soft_prompt import SoftPromptModel
from .delta_models.low_rank_adapter import LowRankAdapterModel
+from .delta_models.parallel_adapter import ParallelAdapterModel
from .utils.visualization import Visualization
diff --git a/opendelta/basemodel.py b/opendelta/basemodel.py
index b7bf648..cd1006b 100644
--- a/opendelta/basemodel.py
+++ b/opendelta/basemodel.py
@@ -5,6 +5,7 @@ from multiprocessing.sharedctypes import Value
import os
from turtle import back
from opendelta.delta_configs import BaseDeltaConfig
+from opendelta.utils.inspect import inspect_module_statistics
from opendelta.utils.model_md5 import gen_model_hash
from opendelta.utils.signature import get_arg_names, signature
from typing import Optional, Union
@@ -27,6 +28,7 @@ from opendelta.utils.cuda import move_dict_to_cuda
import sys
from opendelta.utils.data_parallel import caller_map
+from opendelta.utils.backend import BackendMapping
logger = logging.get_logger(__name__)
def is_leaf_module(module):
@@ -94,6 +96,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
config_class = BaseDeltaConfig
default_unfrozen_modules = ["deltas"]
_need_pseudo_data = True
+ _supported_backends = ['hf']
def __init__(self,
backbone_model: nn.Module,
modified_modules: Optional[List[str]] = None,
@@ -101,6 +104,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
unfrozen_modules: Optional[List[str]] = None,
interactive_modify: Optional[Union[bool, int]] = False,
common_structure: Optional[bool] = False,
+ backend: Optional[str]= "hf", # select from ["hf", "bmt"]
):
nn.Module.__init__(self)
# register the backbone model after init using self.__dict__ method to avoid adding backbone_model
@@ -138,6 +142,10 @@ class DeltaBase(nn.Module, SaveLoadMixin):
self.unfrozen_modules = self.default_unfrozen_modules
if self.common_structure and self.structure_mapping is None:
raise RuntimeError("Using common structure but the structure mapping is None")
+ if backend not in self._supported_backends:
+ raise RuntimeError("Currently, backend `{}` is not supported for `{}`".format(backend, self.__class__.__name__))
+ self.backend = backend
+ self.backend_mapping = BackendMapping(backend)
def forward(self, *args, **kwargs) -> RuntimeError:
r"""
@@ -369,10 +377,11 @@ class DeltaBase(nn.Module, SaveLoadMixin):
_auto_dummy_fail = False
try:
module(**dummy_inputs)
- except:
+ except Exception as e:
_auto_dummy_fail = True
- if _auto_dummy_fail:
- raise AttributeError(f"\n\tThe {self.__class__.__name__} requires a dummy_inputs to be passed through the model to understand the dimensionality of each tensor in the computation graph. \n\t The {module.__class__.__name__} Class has no dummy_inputs, and automatically created dummy_inputs failed.\n\t Refer to `https://opendelta.readthedocs.io/en/latest/notes/faq.html` for detail.")
+            _auto_dummy_error = str(e)  # capture the message here: `e` is cleared once the except block exits
+
+        if _auto_dummy_fail and _auto_dummy:
+            raise AttributeError(f"{_auto_dummy_error}\n\tThe {self.__class__.__name__} requires a dummy_inputs to be passed through the model to understand the dimensionality of each tensor in the computation graph. \n\t The {module.__class__.__name__} Class has no dummy_inputs, and automatically created dummy_inputs failed.\n\t Refer to `https://opendelta.readthedocs.io/en/latest/notes/faq.html` for detail.")
@@ -682,65 +691,16 @@ class DeltaBase(nn.Module, SaveLoadMixin):
from opendelta import Visualization
Visualization(module).structure_graph()
- self.get_statistics(module)
+ self.stat = inspect_module_statistics(module, verbose=False)
if trainable_ratio:
- logger.info("Trainable Ratio: {:2f}%".format(self.stat['trainable_ratio']*100))
+ logger.info("Trainable Ratio: {}/{}={:.6f}%".format(self.stat['trainable_parameters'], self.stat['total_parameters'], self.stat['trainable_ratio']*100))
if delta_ratio:
- logger.info("Delta Parameter Ratio: {:2f}%".format(self.stat['delta_ratio']*100))
+ logger.info("Delta Parameter Ratio: {}/{}={:.6f}%".format(self.stat['delta_parameters'], self.stat['total_parameters'],self.stat['delta_ratio']*100))
if cuda_memory:
logger.info("Static Memory {:.2f} GB, Max Memory {:.2f} GB".format(self.stat['cudamem'], self.stat['maxcudamem']))
- def get_statistics(self, module=None):
- r"""Get the statistics of the parameters in the delta modules.
- Args:
- module (:obj:`nn.Module`, *optional*): The module to compute the statistics.
-
- Returns:
- :obj:`dict`: The statistics of the parameters in the delta modules.
-
- """
- if module is None:
- module = self.backbone_model
-
- self.stat = {}
- n_trainable = self.num_trainable_parameters(module)
- n_total = self.num_total_parameters(module)
-
- self.stat['trainable_ratio'] = n_trainable/n_total
-
- n_delta = self.num_delta_parameters(module)
- n_total = self.num_total_parameters(module)
- self.stat['delta_ratio'] = n_delta/n_total
-
- cudamem = 0
- maxcudamem = 0
- for device_id in range(torch.cuda.device_count()):
- cudamem += torch.cuda.memory_allocated(f"cuda:{device_id}")/1024**3
- maxcudamem += torch.cuda.max_memory_allocated(f"cuda:{device_id}")/1024**3
- self.stat['cudamem'] = cudamem
- self.stat['maxcudamem'] = maxcudamem
-
-
-
- def num_delta_parameters(self, module: Optional[nn.Module]=None):
- r"""[NODOC] A small sugar function to get the number of trainable parameter in the backbone model. Often used to
- compute the trainable rate.
-
- Args:
- module (:obj:`nn.Module`): of which module we want to know the number of trainable paramemters.
-
- Returns:
- :obj:`List[nn.Parameter]`
- """
- if module is None:
- module = self.backbone_model
- pnum_tot = 0
- for param in module.parameters():
- if hasattr(param, "_is_delta"):
- pnum_tot += param.numel()
- return pnum_tot
# Two functions for plug and remove the delta model.
def attach(self, module: Optional[nn.Module]=None, reset_state_dict=True):
diff --git a/opendelta/delta_models/adapter.py b/opendelta/delta_models/adapter.py
index ff12314..4f191b4 100644
--- a/opendelta/delta_models/adapter.py
+++ b/opendelta/delta_models/adapter.py
@@ -60,13 +60,14 @@ class AdapterLayer(nn.Module, InterFaceMixin):
def get_layer_count(cls):
return cls.layer_count
- def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', device=None):
+ def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', device=None, backend="hf"):
super().__init__()
InterFaceMixin.__init__(self)
self.bottleneck_dim = bottleneck_dim
self.init_device = device
self.instantiated = False
self.non_linearity = non_linearity
+ self.backend=backend
self.layer_id = AdapterLayer.get_layer_count()
AdapterLayer.count_layer()
@@ -79,14 +80,16 @@ class AdapterLayer(nn.Module, InterFaceMixin):
else:
return self.init_device
- def instantiate(self, hidden_dim):
+ def instantiate(self, hiddens):
+ self.hidden_dim = hiddens.shape[-1]
+ self.hidden_dtype = hiddens.dtype
self.modulelist = nn.Sequential()
- self.modulelist.add_module("down_proj",nn.Linear(hidden_dim, self.bottleneck_dim, device=self.init_device))
+ self.modulelist.add_module("down_proj",nn.Linear(self.hidden_dim, self.bottleneck_dim, device=self.init_device, dtype=self.hidden_dtype))
# select non-linearity
self.modulelist.add_module("non_linear", Activations(self.non_linearity.lower()))
- self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.init_device))
+ self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.init_device, dtype=self.hidden_dtype))
# TODO:
# If we want to have a layer norm on output, we apply it later after a separate residual connection
@@ -97,6 +100,9 @@ class AdapterLayer(nn.Module, InterFaceMixin):
self.instantiated = True
# initialize the weight, which is important for fast convergence and better performance.
self.apply(self._init_weight)
+ if self.backend == 'bmt':
+ import bmtrain as bmt
+ self.modulelist = bmt.BMTrainModelWrapper(self.modulelist)
def _init_weight(self, module):
if isinstance(module, nn.Linear):
@@ -118,19 +124,29 @@ class AdapterLayer(nn.Module, InterFaceMixin):
raise TypeError
hiddens = self._transpose(hiddens)
- hiddens = self._convert_data_type(hiddens)
+        # dtype conversion is handled at instantiation time (the adapter is built with the
+        # hidden states' dtype); for the bmt backend, casting is left to bmtrain.
if not self.instantiated:
- self.hidden_dim = hiddens.shape[-1]
- logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
- self.instantiate(hidden_dim=self.hidden_dim)
-
+            self.instantiate(hiddens=hiddens)
adapter_output = self.modulelist(hiddens)
modified_output = adapter_output + hiddens # TODO option: disable residual_connection
modified_output = self._reverse_transpose(modified_output)
- modified_output = self._reverse_data_type(modified_output)
+
if isinstance(output, tuple):
output = (modified_output,) + output[1:]
@@ -184,20 +200,24 @@ class AdapterModel(DeltaBase):
modified_modules (:obj:`List[str]`): modules to add adapter after them.
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the adapter parameters.
common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping.
+        backend (:obj:`str`): the backend of the PLM, 'hf' for Hugging Face Transformers, 'bmt' for BMTrain.
"""
config_class = AdapterConfig
delta_type = "adapter"
default_modified_modules = ["attn@.proj@", "ff@.w2@"]
+ _supported_backends = ['hf', 'bmt']
_need_pseudo_data = True
def __init__(self,
backbone_model: nn.Module,
bottleneck_dim: Optional[int]=24,
non_linearity: Optional[str]='gelu_new',
- modified_modules: Optional[bool] = None,
+ modified_modules: Optional[List[str]] = None,
+ exclude_modules: Optional[List[str]] = None,
unfrozen_modules: Optional[bool] = None,
common_structure: Optional[bool] = None,
interactive_modify: Optional[Union[bool, int]] = False,
+ backend: Optional[str] = 'hf',
):
DeltaBase.__init__(self,
backbone_model,
@@ -206,6 +226,7 @@ class AdapterModel(DeltaBase):
unfrozen_modules=unfrozen_modules,
common_structure=common_structure,
interactive_modify=interactive_modify,
+ backend=backend,
)
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
@@ -226,6 +247,6 @@ class AdapterModel(DeltaBase):
def new_module_like(self, module):
module_device = get_device(module)
- adapterlayer = AdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device)
+ adapterlayer = AdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device, backend=self.backend)
self.delta_modules.append(adapterlayer)
return adapterlayer
diff --git a/opendelta/delta_models/bitfit.py b/opendelta/delta_models/bitfit.py
index c66fc5e..9d89548 100644
--- a/opendelta/delta_models/bitfit.py
+++ b/opendelta/delta_models/bitfit.py
@@ -2,6 +2,7 @@ from typing import Optional, Union
from opendelta.utils.signature import get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.basemodel import DeltaBase, is_leaf_module
+from opendelta.utils.cuda import get_device, get_dtype
import torch.nn as nn
import torch
@@ -28,17 +29,24 @@ class BitFitConfig(BaseDeltaConfig):
setattr(self, arg_name, locals()[arg_name])
class BiasLayer(nn.Module):
- def __init__(self, init_method="zero"):
+ def __init__(self, init_method="zero", dtype=None, device=None):
super().__init__()
self.init_method=init_method
self.instantiated = False
+ self.dtype = dtype
+ self.device = device
def instantiate(self, hidden_dim):
if self.init_method == "zero":
- self.bias = nn.Parameter(torch.zeros(hidden_dim))
+ self.bias = nn.Parameter(torch.zeros(hidden_dim, dtype=self.dtype, device=self.device))
else:
raise NotImplementedError
self.instantiated = True
+ try:
+ import bmtrain as bmt
+ self.bias = bmt.BMTrainModelWrapper(self.bias)
+ except:
+ pass
def post_forward(self, output):
r"""Presuming the first argument is the tensor to add bias along the last dimension.
@@ -106,6 +114,7 @@ class BitFitModel(DeltaBase):
config_class = BitFitConfig
delta_type = "bitfit"
default_modified_modules = ["attn@", "ff@", "layer_norm@","lm_head@.proj@"] # modify all the bias parameter in attention and feed-forward layer.
+ _supported_backends = ['hf']
_need_pseudo_data = False
def __init__(self,
backbone_model: nn.Module,
@@ -114,6 +123,7 @@ class BitFitModel(DeltaBase):
unfrozen_modules: Optional[List[str]] = None,
common_structure: Optional[bool] = None,
interactive_modify: Optional[Union[bool, int]] = False,
+ backend: Optional[str] = "hf",
):
DeltaBase.__init__(self,
backbone_model,
@@ -122,6 +132,7 @@ class BitFitModel(DeltaBase):
unfrozen_modules=unfrozen_modules,
common_structure=common_structure,
interactive_modify=interactive_modify,
+ backend=backend,
)
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
@@ -132,8 +143,9 @@ class BitFitModel(DeltaBase):
self.delta_modules = nn.ModuleList()
self.add_all_delta_to_backbone(self.backbone_model,
- self.modified_modules,
- )
+ self.modified_modules)
+
+
def update_module(self, module: nn.Module, key: str):
@@ -146,7 +158,10 @@ class BitFitModel(DeltaBase):
):
if is_leaf_module(module):
# if it is a leaf module, add bias to it regardless of its type.
- if isinstance(module, nn.Linear) or isinstance(module, nn.LayerNorm):
+ # if self.check_linear(module):
+ # self.add_bias_to_linear(module)
+ if self.backend_mapping.check_type(module, 'linear') or \
+ self.backend_mapping.check_type(module, 'layer_norm'):
self.add_bias_to_modules_have_bias_or_known_type(module)
else:
# for example, layer_norms, lm_heads.
@@ -154,6 +169,24 @@ class BitFitModel(DeltaBase):
else:
for n, c in module.named_modules():
self.add_bias_to_modules_have_bias_or_known_type(c)
+ # if self.check_linear(c):
+ # self.add_bias_to_linear(c)
+ # else:
+ # pass
+
+ # def add_bias_to_linear(self, c):
+ # if c.bias is None:
+ # bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
+ # self._reset_bias_parameters(c)
+ # try:
+ # import bmtrain as bmt
+ # bias = bmt.BMTrainModelWrapper(bias)
+ # except:
+ # pass
+ # c.register_parameter('bias', bias)
+ # self.delta_params.append(bias)
+ # else:
+ # self.add_bias_to_modules_have_bias_or_known_type(c)
def add_bias_to_modules_have_bias_or_known_type(self, c):
'''If it has bias, unfreeze it.
@@ -163,20 +196,27 @@ class BitFitModel(DeltaBase):
c.bias.requires_grad = True
self.delta_params.append(c.bias)
else:
- if isinstance(c, nn.Linear) or isinstance(c, nn.LayerNorm): # add bias
+ if self.backend_mapping.check_type(c, 'linear') or \
+ self.backend_mapping.check_type(c, 'layer_norm'):
bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
+
+ self._reset_bias_parameters(c)
+ if self.backend == 'bmt':
+ import bmtrain as bmt
+ bias = bmt.BMTrainModelWrapper(bias)
+
c.register_parameter('bias', bias)
- self._reset_bias_parameters(c)
self.delta_params.append(bias)
- def add_bias_to_others(self, c):
- new_bias = BiasLayer()
- self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since
- # the name `bias` is reserved for some module such as roberta's LayerNorm.
+ def add_bias_to_others(self, c):
+ new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c)) # TODO: bmtrain?
+ if self.backend == 'bmt':
+ import bmtrain as bmt
+ new_bias = bmt.BMTrainModelWrapper(new_bias)
+
+ self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since the name `bias` is reserved for some module such as roberta's LayerNorm.
self.delta_modules.append(new_bias)
-
-
@staticmethod
def _reset_bias_parameters(linear_module):
fan_in, _ = init._calculate_fan_in_and_fan_out(linear_module.weight)
diff --git a/opendelta/delta_models/compacter.py b/opendelta/delta_models/compacter.py
index bec3d66..9743aca 100644
--- a/opendelta/delta_models/compacter.py
+++ b/opendelta/delta_models/compacter.py
@@ -36,6 +36,7 @@ class HyperComplexAdapterLayer(nn.Module):
device=None,
use_bias_up_sampler=True,
use_bias_down_sampler=True,
+ backend = 'hf',
):
super().__init__()
self.reduction_factor = reduction_factor
@@ -55,14 +56,17 @@ class HyperComplexAdapterLayer(nn.Module):
self.use_bias_up_sampler=use_bias_up_sampler
self.use_bias_down_sampler=use_bias_down_sampler
self.device = device
+ self.backend = backend
self.instantiated = False
- def instantiate(self, hidden_dim):
- self.down_sample_size = hidden_dim // self.reduction_factor
+ def instantiate(self, hiddens):
+ self.hidden_dim = hiddens.shape[-1]
+ self.hidden_dtype = hiddens.dtype
+ self.down_sample_size = self.hidden_dim // self.reduction_factor
self.activation = Activations(self.non_linearity.lower()).to(self.device)
- self.down_sampler = PHMLinear(in_features=hidden_dim,
+ self.down_sampler = PHMLinear(in_features=self.hidden_dim,
out_features=self.down_sample_size,
bias=self.use_bias_down_sampler,
c_init=self.phm_c_init,
@@ -76,9 +80,10 @@ class HyperComplexAdapterLayer(nn.Module):
factorized_phm_rule=self.factorized_phm_rule,
phm_rank=self.phm_rank,
phm_init_range=self.phm_init_range,
- kronecker_prod=self.kronecker_prod).to(self.device)
+ kronecker_prod=self.kronecker_prod,
+ dtype = self.hidden_dtype).to(self.device)
self.up_sampler = PHMLinear(in_features=self.down_sample_size,
- out_features=hidden_dim,
+ out_features=self.hidden_dim,
bias=self.use_bias_up_sampler,
c_init=self.phm_c_init,
phm_dim=self.hypercomplex_division,
@@ -91,8 +96,14 @@ class HyperComplexAdapterLayer(nn.Module):
factorized_phm_rule=self.factorized_phm_rule,
phm_rank=self.phm_rank,
phm_init_range=self.phm_init_range,
- kronecker_prod=self.kronecker_prod).to(self.device)
+ kronecker_prod=self.kronecker_prod,
+ dtype = self.hidden_dtype).to(self.device)
self.instantiated = True
+ if self.backend == "bmt":
+ import bmtrain as bmt
+ self.activation = bmt.BMTrainModelWrapper(self.activation)
+ self.down_sampler = bmt.BMTrainModelWrapper(self.down_sampler)
+ self.up_sampler = bmt.BMTrainModelWrapper(self.up_sampler)
def post_forward(self, output):
@@ -109,9 +120,7 @@ class HyperComplexAdapterLayer(nn.Module):
raise TypeError
if not self.instantiated:
- self.hidden_dim = hiddens.shape[-1]
- logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
- self.instantiate(hidden_dim=self.hidden_dim)
+ self.instantiate(hiddens=hiddens)
z = self.down_sampler(hiddens)
@@ -186,6 +195,7 @@ class CompacterModel(DeltaBase):
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
together with the prefix parameters.
common_structure (:obj:`bool`, *optional*, default to :obj:`None`): whether using name-based addressing with a common structure mapping.
+        backend (:obj:`str`): the backend of the PLM, 'hf' for Hugging Face Transformers, 'bmt' for BMTrain.
reduction_factor (:obj:`int`, *optional*, default to ``16``): bottleneck_dim = hidden_dim//reduction_factor
non_linearity (:obj:`str`, *optional*, default to ``"gelu_new"``): The non linearity activation used in between the down
projecter and the up projecter.
@@ -211,6 +221,7 @@ class CompacterModel(DeltaBase):
config_class = CompacterConfig
delta_type = "compacter"
default_modified_modules = ["attn@.proj@", "ff@.w2@"]
+ _supported_backends = ['hf', 'bmt']
_need_pseudo_data = True
def __init__(self,
backbone_model,
@@ -219,6 +230,7 @@ class CompacterModel(DeltaBase):
unfrozen_modules: Optional[List[str]] = None,
common_structure: Optional[bool] = None,
interactive_modify: Optional[Union[bool, int]] = False,
+ backend: Optional[str] = 'hf',
reduction_factor=16,
non_linearity="gelu_new",
phm_c_init="normal",
@@ -281,22 +293,6 @@ class CompacterModel(DeltaBase):
def new_module_like(self, module):
module_device = get_device(module)
- adapterlayer = HyperComplexAdapterLayer(reduction_factor=self.reduction_factor,
- non_linearity=self.non_linearity,
- phm_c_init=self.phm_c_init,
- hypercomplex_division=self.hypercomplex_division,
- learn_phm=self.learn_phm,
- hypercomplex_nonlinearity=self.hypercomplex_nonlinearity,
- shared_phm_rule=self.shared_phm_rule,
- factorized_phm=self.factorized_phm,
- shared_W_phm=self.shared_W_phm,
- factorized_phm_rule=self.factorized_phm_rule,
- phm_rank=self.phm_rank,
- phm_init_range=self.phm_init_range,
- kronecker_prod=self.kronecker_prod,
- use_bias_up_sampler=self.use_bias_up_sampler,
- use_bias_down_sampler=self.use_bias_down_sampler,
- device=module_device
- )
+ adapterlayer = HyperComplexAdapterLayer(reduction_factor=self.reduction_factor, non_linearity=self.non_linearity, phm_c_init=self.phm_c_init, hypercomplex_division=self.hypercomplex_division, learn_phm=self.learn_phm, hypercomplex_nonlinearity=self.hypercomplex_nonlinearity, shared_phm_rule=self.shared_phm_rule, factorized_phm=self.factorized_phm, shared_W_phm=self.shared_W_phm, factorized_phm_rule=self.factorized_phm_rule, phm_rank=self.phm_rank, phm_init_range=self.phm_init_range, kronecker_prod=self.kronecker_prod, use_bias_up_sampler=self.use_bias_up_sampler, use_bias_down_sampler=self.use_bias_down_sampler, device=module_device, backend=self.backend)
self.delta_modules.append(adapterlayer)
return adapterlayer
diff --git a/opendelta/delta_models/layers/hypercomplex_linear.py b/opendelta/delta_models/layers/hypercomplex_linear.py
index e0ed589..e4f94c2 100644
--- a/opendelta/delta_models/layers/hypercomplex_linear.py
+++ b/opendelta/delta_models/layers/hypercomplex_linear.py
@@ -62,7 +62,7 @@ def matvec_product(W: torch.Tensor, x: torch.Tensor,
else:
H = kronecker_product_einsum_batched(phm_rule, W).sum(0)
- y = torch.matmul(input=x, other=H)
+ y = torch.matmul(input=x.to(H.dtype), other=H).to(x.dtype)
if bias is not None:
y += bias
return y
@@ -84,7 +84,8 @@ class PHMLinear(torch.nn.Module):
factorized_phm_rule=False,
phm_rank = 1,
phm_init_range=0.0001,
- kronecker_prod=False) -> None:
+ kronecker_prod=False,
+ dtype=torch.float) -> None:
super(PHMLinear, self).__init__()
assert w_init in ["phm", "glorot-normal", "glorot-uniform", "normal"]
assert c_init in ["normal", "uniform"]
@@ -104,12 +105,12 @@ class PHMLinear(torch.nn.Module):
self.factorized_phm_rule = factorized_phm_rule
if not self.shared_phm_rule:
if self.factorized_phm_rule:
- self.phm_rule_left = nn.Parameter(torch.FloatTensor(phm_dim, phm_dim, 1),
+ self.phm_rule_left = nn.Parameter(torch.empty((phm_dim, phm_dim, 1), dtype=dtype),
requires_grad=learn_phm)
- self.phm_rule_right = nn.Parameter(torch.FloatTensor(phm_dim, 1, phm_dim),
+ self.phm_rule_right = nn.Parameter(torch.empty((phm_dim, 1, phm_dim), dtype=dtype),
requires_grad=learn_phm)
else:
- self.phm_rule = nn.Parameter(torch.FloatTensor(phm_dim, phm_dim, phm_dim),
+ self.phm_rule = nn.Parameter(torch.empty((phm_dim, phm_dim, phm_dim), dtype=dtype),
requires_grad=learn_phm)
self.bias_flag = bias
self.w_init = w_init
@@ -118,15 +119,15 @@ class PHMLinear(torch.nn.Module):
self.factorized_phm = factorized_phm
if not self.shared_W_phm:
if self.factorized_phm:
- self.W_left = nn.Parameter(torch.Tensor(size=(phm_dim, self._in_feats_per_axis, self.phm_rank)),
+ self.W_left = nn.Parameter(torch.empty((phm_dim, self._in_feats_per_axis, self.phm_rank), dtype=dtype),
requires_grad=True)
- self.W_right = nn.Parameter(torch.Tensor(size=(phm_dim, self.phm_rank, self._out_feats_per_axis)),
+ self.W_right = nn.Parameter(torch.empty((phm_dim, self.phm_rank, self._out_feats_per_axis), dtype=dtype),
requires_grad=True)
else:
- self.W = nn.Parameter(torch.Tensor(size=(phm_dim, self._in_feats_per_axis, self._out_feats_per_axis)),
+ self.W = nn.Parameter(torch.empty((phm_dim, self._in_feats_per_axis, self._out_feats_per_axis), dtype=dtype),
requires_grad=True)
if self.bias_flag:
- self.b = nn.Parameter(torch.Tensor(out_features))
+ self.b = nn.Parameter(torch.empty(out_features, dtype=dtype), requires_grad=True)
else:
self.register_parameter("b", None)
self.reset_parameters()
diff --git a/opendelta/delta_models/layers/low_rank_linear.py b/opendelta/delta_models/layers/low_rank_linear.py
index 61ab92d..95c1466 100644
--- a/opendelta/delta_models/layers/low_rank_linear.py
+++ b/opendelta/delta_models/layers/low_rank_linear.py
@@ -6,17 +6,17 @@ from opendelta.delta_models.layers.init import glorot_uniform, glorot_normal
class LowRankLinear(torch.nn.Module):
def __init__(self, input_dim: int, output_dim: int, rank: int = 1,
- bias: bool = True, w_init: str = "glorot-uniform"):
+ bias: bool = True, w_init: str = "glorot-uniform", dtype=torch.float):
super(LowRankLinear, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.rank = rank
self.bias = bias
self.w_init = w_init
- self.W_left = nn.Parameter(torch.Tensor(size=(input_dim, rank)), requires_grad=True)
- self.W_right = nn.Parameter(torch.Tensor(size=(rank, output_dim)), requires_grad=True)
+ self.W_left = nn.Parameter(torch.empty((input_dim, rank), dtype=dtype),requires_grad=True)
+ self.W_right = nn.Parameter(torch.empty((rank, output_dim), dtype=dtype), requires_grad=True)
if bias:
- self.b = nn.Parameter(torch.Tensor(output_dim))
+ self.b = nn.Parameter(torch.empty(output_dim, dtype=dtype))
self.reset_parameters()
def reset_parameters(self):
@@ -33,7 +33,7 @@ class LowRankLinear(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
W = self.W_left*self.W_right
- output = torch.matmul(input=x, other=W)
+ output = torch.matmul(input=x.to(W.dtype), other=W).to(x.dtype)
if self.bias:
output += self.b
return output
diff --git a/opendelta/delta_models/lora.py b/opendelta/delta_models/lora.py
index 5806f6c..0446842 100644
--- a/opendelta/delta_models/lora.py
+++ b/opendelta/delta_models/lora.py
@@ -97,12 +97,14 @@ class LoraModel(DeltaBase):
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
together with the prefix parameters.
common_structure (:obj:`bool`): whether using name-based addressing with a common structure mapping.
+        backend (:obj:`str`): the backend of the PLM, 'hf' for Hugging Face Transformers, 'bmt' for BMTrain.
"""
config_class = LoraConfig
delta_type = "lora"
default_modified_modules = ['attn@.q@', 'attn@.v@']
+ _supported_backends = ['hf', 'bmt']
_need_pseudo_data = False
def __init__(self,
backbone_model: nn.Module,
@@ -114,6 +116,7 @@ class LoraModel(DeltaBase):
exclude_modules: Optional[List[str]] = None,
common_structure: Optional[bool] = None,
interactive_modify: Optional[Union[bool, int]] = False,
+ backend: Optional[str] = "hf",
):
DeltaBase.__init__(self,
backbone_model,
@@ -121,6 +124,7 @@ class LoraModel(DeltaBase):
unfrozen_modules=unfrozen_modules,
common_structure=common_structure,
interactive_modify=interactive_modify,
+ backend=backend,
)
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
@@ -144,15 +148,16 @@ class LoraModel(DeltaBase):
pass
def new_module_like(self, child_module):
- if isinstance(child_module, nn.Linear):
- in_features, out_features = child_module.in_features, child_module.out_features
- new_module = LowRankLinear(in_features = in_features,
- out_features = out_features,
- weight = child_module.weight,
- r=self.lora_r,
- lora_alpha=self.lora_alpha,
- lora_dropout=self.lora_dropout)
- self.delta_modules.append(new_module)
- else:
- raise NotImplementedError
+ in_features, out_features = child_module.in_features, child_module.out_features
+ new_module = LowRankLinear(in_features = in_features,
+ out_features = out_features,
+ weight = child_module.weight,
+ r=self.lora_r,
+ lora_alpha=self.lora_alpha,
+ lora_dropout=self.lora_dropout)
+ if self.backend == "bmt":
+ import bmtrain as bmt
+ new_module = bmt.BMTrainModelWrapper(new_module)
+
+ self.delta_modules.append(new_module)
return new_module
diff --git a/opendelta/delta_models/lora_old.py b/opendelta/delta_models/lora_old.py
deleted file mode 100644
index e548fd4..0000000
--- a/opendelta/delta_models/lora_old.py
+++ /dev/null
@@ -1,127 +0,0 @@
-from typing import Optional, Union
-
-from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
-from opendelta.utils.name_based_addressing import *
-from opendelta.basemodel import DeltaBase
-from transformers.models.t5 import T5ForConditionalGeneration
-import loralib as lora
-import torch.nn as nn
-from opendelta import BaseDeltaConfig
-
-class LoraConfig(BaseDeltaConfig):
- r"""
- This is the configuration class to store the configuration of a :py:class:`~LoraModel`
-
- """
- def __init__(
- self,
- lora_r=8,
- lora_alpha=16,
- lora_dropout=0.0,
- **kwargs
- ):
- super().__init__(**kwargs)
- arg_names = get_arg_names_inside_func(self.__init__)
- for arg_name in arg_names:
- if not hasattr(self, arg_name): # the arg has not been registered in parent config
- setattr(self, arg_name, locals()[arg_name])
-
-
-class LoraModel(DeltaBase):
- r""" The implementation of `LoRA: Low-Rank Adaptation of Large Language Models `_ .
- Thanks for their `loralib `_, we use loralib.linear
- to replace the linear layer of the backbone model.
-
- class attributes:
- - default_modified_modules = ['attn.q', 'attn.v'] According to the paper, they modify q and v matrix in the
- attention layer. However, other linears can also be modified, and may lead to better performance.
-
- .. note::
- modified_modules should point to linear layer. We currently don't support broadcast to all linears in
- a module's child modules.
-
- - delta_type = "lora"
-
-
- Args:
- backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
- lora_r (:obj:`int`, *optional*): the rank of the lora parameters. The smaller lora_r is , the fewer parameters lora has.
- lora_alpha (:obj:`int`, *optional*): A hyper-parameter to control the init scale of loralib.linear .
- lora_dropout (:obj:`float`, *optional*): The dropout rate in lora.linear.
- modified_modules (:obj:`List[str]`): For prefix tuning, the it must refer to an attention layer (Currently, only
- the implemented ones)
- unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
- together with the prefix parameters.
- common_structure (:obj:`bool`): whether using name-based addressing with a common structure mapping.
-
- """
-
- config_class = LoraConfig
- delta_type = "lora"
- default_modified_modules = ['attn.q', 'attn.v']
- def __init__(self,
- backbone_model: nn.Module,
- lora_r=8,
- lora_alpha=16,
- lora_dropout=0.0,
- modified_modules: Optional[List[str]] = None,
- exclude_modules: Optional[List[str]] = None,
- unfrozen_modules: Optional[List[str]] = None,
- common_structure: Optional[bool] = None,
- interactive_modify: Optional[Union[bool, int]] = False,
- ):
- DeltaBase.__init__(self,
- backbone_model,
- modified_modules=modified_modules,
- exclude_modules=exclude_modules,
- unfrozen_modules=unfrozen_modules,
- common_structure=common_structure,
- interactive_modify=interactive_modify,
- )
- arg_names = get_arg_names_inside_func(self.__init__)
- for arg_name in arg_names:
- if not hasattr(self, arg_name): # not registered in parent class
- setattr(self, arg_name, locals()[arg_name])
-
- self.delta_modules = nn.ModuleList()
-
- self.add_all_delta_to_backbone(self.backbone_model,
- self.modified_modules,
- )
-
-
-
- def update_module(self, module: nn.Module, key: str):
- parent_ref, child_name, child_ref = self.find_module(module, key)
- new_module = self.new_module_like(child_module=child_ref)
- self.replace_module(parent_ref, child_name, child_ref, new_module, delta_name="lora")
-
- def _pseudo_data_to_instantiate(self, module):
- # no need to pass pseudo input, so overwrite it
- pass
-
- def new_module_like(self, child_module):
- if isinstance(child_module, nn.Linear):
- in_features, out_features = child_module.in_features, child_module.out_features
- new_module = lora.Linear(in_features=in_features,
- out_features=out_features,
- r=self.lora_r,
- lora_alpha=self.lora_alpha,
- lora_dropout=self.lora_dropout)
- new_module.weight = child_module.weight
- new_module.bias = child_module.bias # if bias is None, also copy
- else:
- raise NotImplementedError
- return new_module
-
-
-
- def mark_as_delta(self, module: nn.Module = None):
- if module is None:
- module=self
- for n, p in module.named_parameters():
- param_name = n.split(".")[-1]
- if "lora_A" in param_name or "lora_B" in param_name: # only lora_A, lora_B is the delta parameter.
- setattr(p, "_is_delta", True)
-
-
diff --git a/opendelta/delta_models/low_rank_adapter.py b/opendelta/delta_models/low_rank_adapter.py
index f23b00c..eaef90a 100644
--- a/opendelta/delta_models/low_rank_adapter.py
+++ b/opendelta/delta_models/low_rank_adapter.py
@@ -47,7 +47,8 @@ class LowRankAdapter(nn.Module):
non_linearity="gelu_new",
low_rank_w_init="glorot-uniform",
low_rank_rank=1,
- device=None):
+ device=None,
+ backend='hf'):
super().__init__()
self.reduction_factor = reduction_factor
self.non_linearity = non_linearity
@@ -55,20 +56,31 @@ class LowRankAdapter(nn.Module):
self.low_rank_rank = low_rank_rank
self.device = device
self.instantiated = False
+ self.backend=backend
- def instantiate(self, hidden_dim):
+ def instantiate(self, hiddens):
+ self.hidden_dim = hiddens.shape[-1]
+ self.hidden_dtype = hiddens.dtype
- self.down_sample_size = hidden_dim // self.reduction_factor
+ self.down_sample_size = self.hidden_dim // self.reduction_factor
self.activation = Activations(self.non_linearity.lower()).to(self.device)
- self.down_sampler = LowRankLinear(hidden_dim, self.down_sample_size,
+ self.down_sampler = LowRankLinear(self.hidden_dim, self.down_sample_size,
w_init=self.low_rank_w_init,
- rank=self.low_rank_rank).to(self.device)
- self.up_sampler = LowRankLinear(self.down_sample_size, hidden_dim,
+ rank=self.low_rank_rank,
+ dtype=self.hidden_dtype).to(self.device)
+ self.up_sampler = LowRankLinear(self.down_sample_size, self.hidden_dim,
w_init=self.low_rank_w_init,
- rank=self.low_rank_rank).to(self.device)
+ rank=self.low_rank_rank,
+ dtype=self.hidden_dtype).to(self.device)
self.instantiated = True
+ if self.backend == 'bmt':
+ import bmtrain as bmt
+ self.activation = bmt.BMTrainModelWrapper(self.activation)
+ self.down_sampler = bmt.BMTrainModelWrapper(self.down_sampler)
+ self.up_sampler = bmt.BMTrainModelWrapper(self.up_sampler)
+
def post_forward(self, output):
r""" Get the hidden_states from the PLM's layer output, pass it into the low-rank adapter,
@@ -84,10 +96,7 @@ class LowRankAdapter(nn.Module):
raise TypeError
if not self.instantiated:
- self.hidden_dim = hiddens.shape[-1]
- logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
- self.instantiate(hidden_dim=self.hidden_dim)
-
+ self.instantiate(hiddens = hiddens)
z = self.down_sampler(hiddens)
z = self.activation(z)
@@ -148,6 +157,7 @@ class LowRankAdapterModel(DeltaBase):
config_class = LowRankAdapterConfig
delta_type = "low_rank_adapter"
default_modified_modules = ["attn@.proj@", "ff@.w2@"]
+ _supported_backends = ['hf', 'bmt']
_need_pseudo_data = True
def __init__(self,
backbone_model: nn.Module,
@@ -160,6 +170,7 @@ class LowRankAdapterModel(DeltaBase):
unfrozen_modules: Optional[List[str]] = None,
common_structure: Optional[bool] = None,
interactive_modify: Optional[Union[bool, int]] = False,
+ backend: Optional[str] = 'hf',
):
DeltaBase.__init__(self,
backbone_model,
@@ -168,6 +179,7 @@ class LowRankAdapterModel(DeltaBase):
unfrozen_modules=unfrozen_modules,
common_structure=common_structure,
interactive_modify=interactive_modify,
+ backend=backend,
)
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
@@ -203,6 +215,6 @@ class LowRankAdapterModel(DeltaBase):
non_linearity = self.non_linearity,
low_rank_w_init = self.low_rank_w_init,
low_rank_rank = self.low_rank_rank,
- device=module_device)
+ device=module_device, backend=self.backend)
self.delta_modules.append(adapterlayer)
return adapterlayer
diff --git a/opendelta/delta_models/parallel_adapter.py b/opendelta/delta_models/parallel_adapter.py
index 1024394..d354587 100644
--- a/opendelta/delta_models/parallel_adapter.py
+++ b/opendelta/delta_models/parallel_adapter.py
@@ -25,30 +25,36 @@ class ParallelAdapterLayer(nn.Module):
def get_layer_count(cls):
return cls.layer_count
- def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', scaled=1, device=None):
+ def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', scaled=1, device=None, backend='hf'):
super().__init__()
self.bottleneck_dim = bottleneck_dim
self.device = device
self.instantiated = False
self.non_linearity = non_linearity
self.scaled = scaled
+ self.backend = backend
self.layer_id = ParallelAdapterLayer.get_layer_count()
ParallelAdapterLayer.count_layer()
- def instantiate(self, hidden_dim):
+ def instantiate(self, hiddens):
+ self.hidden_dim = hiddens.shape[-1]
+ self.hidden_dtype = hiddens.dtype
self.modulelist = nn.Sequential()
- self.modulelist.add_module("down_proj",nn.Linear(hidden_dim, self.bottleneck_dim, device=self.device))
+ self.modulelist.add_module("down_proj",nn.Linear(self.hidden_dim, self.bottleneck_dim, device=self.device, dtype=self.hidden_dtype))
# select non-linearity
self.modulelist.add_module("non_linear", Activations(self.non_linearity.lower()))
- self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device))
+ self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device, dtype=self.hidden_dtype))
self.instantiated = True
# initialize the weight, which is important for fast convergence and better performance.
self.apply(self._init_weight)
+ if self.backend == 'bmt':
+ import bmtrain as bmt
+ self.modulelist = bmt.BMTrainModelWrapper(self.modulelist)
def _init_weight(self, module):
if isinstance(module, nn.Linear):
@@ -71,9 +77,8 @@ class ParallelAdapterLayer(nn.Module):
if not self.instantiated:
- self.hidden_dim = hiddens.shape[-1]
- logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
- self.instantiate(hidden_dim=self.hidden_dim)
+ # logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
+ self.instantiate(hiddens = hiddens)
self.adapter_output = self.modulelist(hiddens) * self.scaled
@@ -141,12 +146,14 @@ class ParallelAdapterModel(DeltaBase):
modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired and have the save order in layer. For examples, ["attn", "attn", "ff.w1", "ff.w2"] add one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output.
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the parallel adapter parameters.
common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping.
+        backend (:obj:`str`): the backend of the PLM, 'hf' for Hugging Face Transformers, 'bmt' for BMTrain.
"""
config_class = ParallelAdapterConfig
delta_type = "parallel_adapter"
default_modified_modules = ["attn@", "attn@", "ff@.w1@", "ff@.w2@"]
# default_modified_modules = ["attn", "attn", "ff.w1", "ff.w2"]
+ _supported_backends = ['hf', 'bmt']
_need_pseudo_data = True
def __init__(self,
backbone_model: nn.Module,
@@ -156,7 +163,8 @@ class ParallelAdapterModel(DeltaBase):
exclude_modules: Optional[List[str]] = None,
unfrozen_modules: Optional[bool] = None,
common_structure: Optional[bool] = None,
- interactive_modify: Optional[Union[bool, int]] = False,
+ interactive_modify: Optional[Union[bool, int]] = False,
+ backend: Optional[str] = "hf",
):
DeltaBase.__init__(self,
backbone_model,
@@ -165,6 +173,7 @@ class ParallelAdapterModel(DeltaBase):
unfrozen_modules=unfrozen_modules,
common_structure=common_structure,
interactive_modify=interactive_modify,
+ backend=backend,
)
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
@@ -193,7 +202,7 @@ class ParallelAdapterModel(DeltaBase):
def new_module_like(self, module):
module_device = get_device(module)
- adapterlayer = ParallelAdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device)
+ adapterlayer = ParallelAdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device, backend=self.backend)
self.delta_modules.append(adapterlayer)
return adapterlayer
\ No newline at end of file
diff --git a/opendelta/delta_models/prefix.py b/opendelta/delta_models/prefix.py
index f64df2c..777a8e1 100644
--- a/opendelta/delta_models/prefix.py
+++ b/opendelta/delta_models/prefix.py
@@ -516,6 +516,7 @@ class PrefixModel(DeltaBase):
config_class = PrefixConfig
delta_type = "prefix"
default_modified_modules = ['attn@']
+ _supported_backends = ['hf']
_need_pseudo_data = True
def __init__(self,
backbone_model: nn.Module,
diff --git a/opendelta/delta_models/soft_prompt.py b/opendelta/delta_models/soft_prompt.py
index 95854dd..6453368 100644
--- a/opendelta/delta_models/soft_prompt.py
+++ b/opendelta/delta_models/soft_prompt.py
@@ -161,6 +161,7 @@ class SoftPromptModel(DeltaBase):
config_class = SoftPromptConfig
delta_type = "soft_prompt"
default_modified_modules = ["root"] # not used
+ _supported_backends = ['hf'] #'bmt']
_need_pseudo_data = False
def __init__(self,
backbone_model: nn.Module,
@@ -222,5 +223,10 @@ class SoftPromptModel(DeltaBase):
init_range = self.init_range,
device = module_device,
)
+ try:
+ import bmtrain as bmt
+ soft_prompt_layer = bmt.BMTrainModelWrapper(soft_prompt_layer)
+ except:
+ pass
self.delta_modules.append(soft_prompt_layer)
return soft_prompt_layer
diff --git a/opendelta/utils/backend.py b/opendelta/utils/backend.py
new file mode 100644
index 0000000..0b5b124
--- /dev/null
+++ b/opendelta/utils/backend.py
@@ -0,0 +1,110 @@
+
+
+import importlib
+
+
+class BackendMapping:
+ """
+ " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
+
+ Args:
+
+ - config_mapping: The map model type to config class
+ - model_mapping: The map model type to model (or tokenizer) class
+ """
+
+ def __init__(self, backend):
+ self.backend = backend
+ assert backend in ['hf', 'bmt'], "Backend should be one of 'hf', 'bmt'. "
+ if backend == 'hf':
+ self.backend_mapping = {
+ "linear": "torch.nn.Linear",
+ "layer_norm": "torch.nn.LayerNorm",
+ "module": "torch.nn.Module",
+ "parameter": "torch.nn.Parameter"
+ }
+ elif backend == 'bmt':
+ self.backend_mapping = {
+ "linear": "model_center.layer.Linear",
+ "layer_norm": "model_center.layer.LayerNorm",
+ "module": "bmtrain.layer.DistributedModule",
+ "parameter": "bmtrain.nn.DistributedParameter"
+ }
+ self.registered = {}
+
+ def load(self, model_type):
+ if model_type not in self.registered:
+ splited = self.backend_mapping[model_type].split(".")
+ module_name, class_name = ".".join(splited[:-1]), splited[-1]
+ module = importlib.import_module(module_name)
+ the_class = getattr(module, class_name)
+ self.registered[model_type] = the_class
+ return self.registered[model_type]
+
+ def check_type(self, module, expect_type):
+ the_class = self.load(expect_type)
+ if isinstance(module, the_class):
+ return True
+ else:
+ return False
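+
+    # Usage sketch:
+    #   mapping = BackendMapping('hf')
+    #   mapping.check_type(torch.nn.Linear(4, 4), 'linear')  # -> True
+    #   mapping.load('layer_norm')                            # -> torch.nn.LayerNorm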
+
diff --git a/opendelta/utils/cuda.py b/opendelta/utils/cuda.py
index 5f237a7..fcfbc10 100644
--- a/opendelta/utils/cuda.py
+++ b/opendelta/utils/cuda.py
@@ -17,6 +17,20 @@ def get_device(module : Union[nn.Module, nn.Parameter]):
else:
raise RuntimeError("The module is paralleled acrossed device, please get device in a inner module")
+def get_dtype(module : Union[nn.Module, nn.Parameter]):
+ if not (isinstance(module, nn.Module) \
+ or isinstance(module, nn.Parameter)):
+        raise RuntimeError("module is not an instance of torch.nn.Module or torch.nn.Parameter")
+ if hasattr(module, 'dtype'):
+ return module.dtype
+ else:
+ params_dtypes = [p.dtype for p in module.parameters()]
+ if len(params_dtypes) == 0:
+ return None
+ elif len(set(params_dtypes))==1:
+ return params_dtypes[0]
+ else:
+            raise RuntimeError("The module has multiple dtypes, please get the dtype of an inner module")
def move_dict_to_cuda(dict_of_tensor, device):
for key in dict_of_tensor:
diff --git a/opendelta/utils/inspect.py b/opendelta/utils/inspect.py
new file mode 100644
index 0000000..830298e
--- /dev/null
+++ b/opendelta/utils/inspect.py
@@ -0,0 +1,112 @@
+
+import torch
+import torch.nn as nn
+from typing import Optional
+import opendelta.utils.logging as logging
+
+logger = logging.get_logger(__name__)
+
+
+def inspect_module_statistics(module: Optional[nn.Module]=None, verbose=True):
+ r"""Get the statistics of the parameters in the delta modules.
+
+ Args:
+ module (:obj:`nn.Module`, *optional*): The module to compute the statistics.
+
+ Returns:
+ :obj:`dict`: The statistics of the parameters in the delta modules.
+
+ """
+
+ stat = {}
+ n_trainable = num_trainable_parameters(module)
+ n_total = num_total_parameters(module)
+
+ stat['total_parameters'] = n_total
+ stat['trainable_parameters'] = n_trainable
+
+ stat['trainable_ratio'] = n_trainable/n_total
+
+ n_delta = num_delta_parameters(module)
+ n_total = num_total_parameters(module)
+ stat['delta_parameters'] = n_delta
+ stat['delta_ratio'] = n_delta/n_total
+
+ cudamem = 0
+ maxcudamem = 0
+ for device_id in range(torch.cuda.device_count()):
+ cudamem += torch.cuda.memory_allocated(f"cuda:{device_id}")/1024**3
+ maxcudamem += torch.cuda.max_memory_allocated(f"cuda:{device_id}")/1024**3
+ stat['cudamem'] = cudamem
+ stat['maxcudamem'] = maxcudamem
+
+ if verbose:
+ logger.info(stat)
+
+ return stat
+
+def num_trainable_parameters(module: Optional[nn.Module]=None):
+ r"""[NODOC] A small sugar function to get the number of trainable parameter in the backbone model. Often used to
+ compute the trainable rate.
+
+ Args:
+ module (:obj:`nn.Module`): of which module we want to know the number of trainable paramemters.
+
+ Returns:
+ :obj:`List[nn.Parameter]`
+ """
+ pnum_tot = 0
+ for param in module.parameters():
+ if param.requires_grad:
+ pnum_tot += param.numel()
+ return pnum_tot
+
+
+def num_total_parameters(module: Optional[nn.Module]=None):
+ r"""[NODOC] A small sugar function to get the number of trainable parameter in the backbone model. Often used to
+ compute the trainable rate.
+
+ Args:
+ module (:obj:`nn.Module`): of which module we want to know the number of trainable paramemters.
+
+ Returns:
+ :obj:`List[nn.Parameter]`
+ """
+ pnum_tot = 0
+ for param in module.parameters():
+ pnum_tot += param.numel()
+ return pnum_tot
+
+def num_delta_parameters(module: Optional[nn.Module]=None):
+ r"""[NODOC] A small sugar function to get the number of trainable parameter in the backbone model. Often used to
+ compute the trainable rate.
+
+ Args:
+ module (:obj:`nn.Module`): of which module we want to know the number of trainable paramemters.
+
+ Returns:
+ :obj:`List[nn.Parameter]`
+ """
+ pnum_tot = 0
+ for param in module.parameters():
+ if hasattr(param, "_is_delta"):
+ pnum_tot += param.numel()
+ return pnum_tot
+
+def inspect_optimizer_statistics(optimizer, verbose=True):
+    stats = {}
+    for group_id, param_group in enumerate(optimizer.param_groups):
+        stat = {}
+        fine_grain_info = [(p.numel(), p.requires_grad) for p in param_group['params']]
+        stat['total_parameters'] = sum(n for n, r in fine_grain_info)
+        stat['trainable_parameters'] = sum(n for n, r in fine_grain_info if r)
+        stat['trainable_ratio'] = "{:.6f}%".format(stat['trainable_parameters']/stat['total_parameters']*100)
+        for key in param_group:
+            if key != 'params':
+                stat[key] = param_group[key]
+        stats[f'param_group_{group_id}'] = stat
+
+    if verbose:
+        logger.info(f"optimizer info: {stats}")
+
+    return stats
diff --git a/setup.py b/setup.py
index 27c0313..711f5a4 100644
--- a/setup.py
+++ b/setup.py
@@ -31,7 +31,7 @@ def get_requirements():
with open('README.md', 'r') as f:
setuptools.setup(
name = 'opendelta',
- version = "0.3.1",
+ version = "0.3.2",
description = "An open source framework for delta learning (parameter efficient learning).",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",