From 739bebbb8cfa98b09035e1b436f31186a3399f8f Mon Sep 17 00:00:00 2001
From: Achazwl
Date: Sat, 3 Sep 2022 10:12:12 +0000
Subject: [PATCH] Add BMTrain support to the delta models

Wrap newly created delta modules with bmt.BMTrainModelWrapper when
bmtrain is importable, allocate BitFit biases with the backbone's
dtype/device via a new get_dtype helper, cast inputs to the weight
dtype in the hypercomplex and low-rank linear layers, add a BMTrain
tutorial based on model_center's Bert, and remove the obsolete
lora_old.py.

---
 examples/tutorial/2_with_bmtrain.py          |  50 +++++++
 examples/tutorial/2_with_bmtrain.sh          |   1 +
 opendelta/delta_models/adapter.py            |   5 +
 opendelta/delta_models/bitfit.py             |  45 +++++--
 opendelta/delta_models/compacter.py          |   7 +
 .../layers/hypercomplex_linear.py            |   2 +-
 .../delta_models/layers/low_rank_linear.py   |   2 +-
 opendelta/delta_models/lora.py               |  24 ++--
 opendelta/delta_models/lora_old.py           | 127 ------------------
 opendelta/delta_models/low_rank_adapter.py   |   8 +-
 opendelta/delta_models/soft_prompt.py        |   5 +
 opendelta/utils/cuda.py                      |  14 ++
 12 files changed, 135 insertions(+), 155 deletions(-)
 create mode 100644 examples/tutorial/2_with_bmtrain.py
 create mode 100644 examples/tutorial/2_with_bmtrain.sh
 delete mode 100644 opendelta/delta_models/lora_old.py

diff --git a/examples/tutorial/2_with_bmtrain.py b/examples/tutorial/2_with_bmtrain.py
new file mode 100644
index 0000000..7e35189
--- /dev/null
+++ b/examples/tutorial/2_with_bmtrain.py
@@ -0,0 +1,50 @@
+import bmtrain as bmt
+import opendelta as od
+from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel
+import torch
+import numpy
+import random
+
+def manual_seed(seed):
+    torch.manual_seed(seed)
+    numpy.random.seed(seed)
+    random.seed(seed)
+
+from model_center.model import Bert, BertConfig
+bmt.init_distributed()
+config = BertConfig.from_pretrained("/yinxr/zwl/.cache/model_center/bert-base-uncased")
+config.dropout_p = 0
+model = Bert.from_pretrained("/yinxr/zwl/.cache/model_center/bert-base-uncased", config)
+
+print("before modification")
+od.Visualization(model).structure_graph()
+
+manual_seed(233)
+delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'])
+# delta_model = AdapterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn'])
+# delta_model = CompacterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn'])
+# delta_model = LowRankAdapterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn'])
+# delta_model = BitFitModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn', '[r](.*)layernorm(.*)'])
+
+print(delta_model.delta_modules)
+
+print("after modification")
+delta_model.log()
+# This will visualize the backbone after modification, along with other information.
+
+delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
+print("after freeze")
+delta_model.log()
+# set_state_dict=True tells the method to change the state_dict of the backbone_model so that it keeps only the trainable parts.
+ +manual_seed(233) +inp = torch.randint(0, 30000, (32, 128)).cuda() +length = torch.randint(0, 128, (32,)).cuda() +attention_mask = (torch.arange(inp.shape[1], device=inp.device)[None, :].repeat(inp.shape[0], 1) < length[:, None]) +out = model(inp, attention_mask=attention_mask, output_logits=True).logits +print(out) + +if bmt.rank() == 0: + torch.save(model.state_dict(), "test.pt") + ckpt = torch.load("test.pt") + print(ckpt.keys()) \ No newline at end of file diff --git a/examples/tutorial/2_with_bmtrain.sh b/examples/tutorial/2_with_bmtrain.sh new file mode 100644 index 0000000..0a5b3bc --- /dev/null +++ b/examples/tutorial/2_with_bmtrain.sh @@ -0,0 +1 @@ +python3 -m torch.distributed.launch --master_addr localhost --master_port 34123 --nproc_per_node $1 --nnodes 1 --node_rank 0 2_with_bmtrain.py diff --git a/opendelta/delta_models/adapter.py b/opendelta/delta_models/adapter.py index a1b821f..4202c5a 100644 --- a/opendelta/delta_models/adapter.py +++ b/opendelta/delta_models/adapter.py @@ -85,6 +85,11 @@ class AdapterLayer(nn.Module, InterFaceMixin): self.instantiated = True # initialize the weight, which is important for fast convergence and better performance. self.apply(self._init_weight) + try: + import bmtrain as bmt + self.modulelist = bmt.BMTrainModelWrapper(self.modulelist) + except: + pass def _init_weight(self, module): if isinstance(module, nn.Linear): diff --git a/opendelta/delta_models/bitfit.py b/opendelta/delta_models/bitfit.py index bfac44c..29c9194 100644 --- a/opendelta/delta_models/bitfit.py +++ b/opendelta/delta_models/bitfit.py @@ -2,6 +2,7 @@ from typing import Optional, Union from opendelta.utils.signature import get_arg_names_inside_func from opendelta.utils.name_based_addressing import * from opendelta.basemodel import DeltaBase, is_leaf_module +from opendelta.utils.cuda import get_device, get_dtype import torch.nn as nn import torch @@ -28,17 +29,24 @@ class BitFitConfig(BaseDeltaConfig): setattr(self, arg_name, locals()[arg_name]) class BiasLayer(nn.Module): - def __init__(self, init_method="zero"): + def __init__(self, init_method="zero", dtype=None, device=None): super().__init__() self.init_method=init_method self.instantiated = False + self.dtype = dtype + self.device = device def instantiate(self, hidden_dim): if self.init_method == "zero": - self.bias = nn.Parameter(torch.zeros(hidden_dim)) + self.bias = nn.Parameter(torch.zeros(hidden_dim, dtype=self.dtype, device=self.device)) else: raise NotImplementedError self.instantiated = True + try: + import bmtrain as bmt + self.bias = bmt.BMTrainModelWrapper(self.bias) + except: + pass def post_forward(self, output): r"""Presuming the first argument is the tensor to add bias along the last dimension. @@ -145,7 +153,7 @@ class BitFitModel(DeltaBase): ): if is_leaf_module(module): # if it is a leaf module, add bias to it regardless of its type. - if isinstance(module, nn.Linear): + if self.check_linear(module): self.add_bias_to_linear(module) else: # for example, layer_norms, lm_heads. @@ -153,34 +161,43 @@ class BitFitModel(DeltaBase): else: # for the non-leaf modules, by default it will add bias only to the linear submodules. 
for n, c in module.named_modules(): - if isinstance(c, nn.Linear): - if c.bias is None: - bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True) - c.register_parameter('bias', bias) - self._reset_bias_parameters(c) - self.delta_params.append(bias) - else: - c.bias.requires_grad = True - self.delta_params.append(c.bias) + if self.check_linear(c): + self.add_bias_to_linear(c) else: pass def add_bias_to_linear(self, c): if c.bias is None: bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True) - c.register_parameter('bias', bias) self._reset_bias_parameters(c) + try: + import bmtrain as bmt + bias = bmt.BMTrainModelWrapper(bias) + except: + pass + c.register_parameter('bias', bias) self.delta_params.append(bias) else: c.bias.requires_grad = True self.delta_params.append(c.bias) def add_bias_to_others(self, c): - new_bias = BiasLayer() + new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c)) self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since # the name `bias` is reserved for some module such as roberta's LayerNorm. self.delta_modules.append(new_bias) + def check_linear(self, m): + if isinstance(m, nn.Linear): + return True + else: + try: + from model_center.layer import Linear + if isinstance(m, Linear): + return True + except: + pass + return False @staticmethod diff --git a/opendelta/delta_models/compacter.py b/opendelta/delta_models/compacter.py index 81c5484..86e2799 100644 --- a/opendelta/delta_models/compacter.py +++ b/opendelta/delta_models/compacter.py @@ -93,6 +93,13 @@ class HyperComplexAdapterLayer(nn.Module): phm_init_range=self.phm_init_range, kronecker_prod=self.kronecker_prod).to(self.device) self.instantiated = True + try: + import bmtrain as bmt + self.activation = bmt.BMTrainModelWrapper(self.activation) + self.down_sampler = bmt.BMTrainModelWrapper(self.down_sampler) + self.up_sampler = bmt.BMTrainModelWrapper(self.up_sampler) + except: + pass def post_forward(self, output): diff --git a/opendelta/delta_models/layers/hypercomplex_linear.py b/opendelta/delta_models/layers/hypercomplex_linear.py index e0ed589..acee0e8 100644 --- a/opendelta/delta_models/layers/hypercomplex_linear.py +++ b/opendelta/delta_models/layers/hypercomplex_linear.py @@ -62,7 +62,7 @@ def matvec_product(W: torch.Tensor, x: torch.Tensor, else: H = kronecker_product_einsum_batched(phm_rule, W).sum(0) - y = torch.matmul(input=x, other=H) + y = torch.matmul(input=x.to(H.dtype), other=H).to(x.dtype) if bias is not None: y += bias return y diff --git a/opendelta/delta_models/layers/low_rank_linear.py b/opendelta/delta_models/layers/low_rank_linear.py index 61ab92d..bb2b25d 100644 --- a/opendelta/delta_models/layers/low_rank_linear.py +++ b/opendelta/delta_models/layers/low_rank_linear.py @@ -33,7 +33,7 @@ class LowRankLinear(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: W = self.W_left*self.W_right - output = torch.matmul(input=x, other=W) + output = torch.matmul(input=x.to(W.dtype), other=W).to(x.dtype) if self.bias: output += self.b return output diff --git a/opendelta/delta_models/lora.py b/opendelta/delta_models/lora.py index 2c172d6..3fe9504 100644 --- a/opendelta/delta_models/lora.py +++ b/opendelta/delta_models/lora.py @@ -137,15 +137,17 @@ class LoraModel(DeltaBase): pass def new_module_like(self, child_module): - if isinstance(child_module, nn.Linear): - in_features, out_features = child_module.in_features, child_module.out_features - new_module = LowRankLinear(in_features = 
in_features, - out_features = out_features, - weight = child_module.weight, - r=self.lora_r, - lora_alpha=self.lora_alpha, - lora_dropout=self.lora_dropout) - self.delta_modules.append(new_module) - else: - raise NotImplementedError + in_features, out_features = child_module.in_features, child_module.out_features + new_module = LowRankLinear(in_features = in_features, + out_features = out_features, + weight = child_module.weight, + r=self.lora_r, + lora_alpha=self.lora_alpha, + lora_dropout=self.lora_dropout) + try: + import bmtrain as bmt + new_module = bmt.BMTrainModelWrapper(new_module) + except: + pass + self.delta_modules.append(new_module) return new_module diff --git a/opendelta/delta_models/lora_old.py b/opendelta/delta_models/lora_old.py deleted file mode 100644 index e548fd4..0000000 --- a/opendelta/delta_models/lora_old.py +++ /dev/null @@ -1,127 +0,0 @@ -from typing import Optional, Union - -from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func -from opendelta.utils.name_based_addressing import * -from opendelta.basemodel import DeltaBase -from transformers.models.t5 import T5ForConditionalGeneration -import loralib as lora -import torch.nn as nn -from opendelta import BaseDeltaConfig - -class LoraConfig(BaseDeltaConfig): - r""" - This is the configuration class to store the configuration of a :py:class:`~LoraModel` - - """ - def __init__( - self, - lora_r=8, - lora_alpha=16, - lora_dropout=0.0, - **kwargs - ): - super().__init__(**kwargs) - arg_names = get_arg_names_inside_func(self.__init__) - for arg_name in arg_names: - if not hasattr(self, arg_name): # the arg has not been registered in parent config - setattr(self, arg_name, locals()[arg_name]) - - -class LoraModel(DeltaBase): - r""" The implementation of `LoRA: Low-Rank Adaptation of Large Language Models `_ . - Thanks for their `loralib `_, we use loralib.linear - to replace the linear layer of the backbone model. - - class attributes: - - default_modified_modules = ['attn.q', 'attn.v'] According to the paper, they modify q and v matrix in the - attention layer. However, other linears can also be modified, and may lead to better performance. - - .. note:: - modified_modules should point to linear layer. We currently don't support broadcast to all linears in - a module's child modules. - - - delta_type = "lora" - - - Args: - backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified. - lora_r (:obj:`int`, *optional*): the rank of the lora parameters. The smaller lora_r is , the fewer parameters lora has. - lora_alpha (:obj:`int`, *optional*): A hyper-parameter to control the init scale of loralib.linear . - lora_dropout (:obj:`float`, *optional*): The dropout rate in lora.linear. - modified_modules (:obj:`List[str]`): For prefix tuning, the it must refer to an attention layer (Currently, only - the implemented ones) - unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen - together with the prefix parameters. - common_structure (:obj:`bool`): whether using name-based addressing with a common structure mapping. 
- - """ - - config_class = LoraConfig - delta_type = "lora" - default_modified_modules = ['attn.q', 'attn.v'] - def __init__(self, - backbone_model: nn.Module, - lora_r=8, - lora_alpha=16, - lora_dropout=0.0, - modified_modules: Optional[List[str]] = None, - exclude_modules: Optional[List[str]] = None, - unfrozen_modules: Optional[List[str]] = None, - common_structure: Optional[bool] = None, - interactive_modify: Optional[Union[bool, int]] = False, - ): - DeltaBase.__init__(self, - backbone_model, - modified_modules=modified_modules, - exclude_modules=exclude_modules, - unfrozen_modules=unfrozen_modules, - common_structure=common_structure, - interactive_modify=interactive_modify, - ) - arg_names = get_arg_names_inside_func(self.__init__) - for arg_name in arg_names: - if not hasattr(self, arg_name): # not registered in parent class - setattr(self, arg_name, locals()[arg_name]) - - self.delta_modules = nn.ModuleList() - - self.add_all_delta_to_backbone(self.backbone_model, - self.modified_modules, - ) - - - - def update_module(self, module: nn.Module, key: str): - parent_ref, child_name, child_ref = self.find_module(module, key) - new_module = self.new_module_like(child_module=child_ref) - self.replace_module(parent_ref, child_name, child_ref, new_module, delta_name="lora") - - def _pseudo_data_to_instantiate(self, module): - # no need to pass pseudo input, so overwrite it - pass - - def new_module_like(self, child_module): - if isinstance(child_module, nn.Linear): - in_features, out_features = child_module.in_features, child_module.out_features - new_module = lora.Linear(in_features=in_features, - out_features=out_features, - r=self.lora_r, - lora_alpha=self.lora_alpha, - lora_dropout=self.lora_dropout) - new_module.weight = child_module.weight - new_module.bias = child_module.bias # if bias is None, also copy - else: - raise NotImplementedError - return new_module - - - - def mark_as_delta(self, module: nn.Module = None): - if module is None: - module=self - for n, p in module.named_parameters(): - param_name = n.split(".")[-1] - if "lora_A" in param_name or "lora_B" in param_name: # only lora_A, lora_B is the delta parameter. 
-                setattr(p, "_is_delta", True)
-
-
diff --git a/opendelta/delta_models/low_rank_adapter.py b/opendelta/delta_models/low_rank_adapter.py
index 2e4571e..210ade2 100644
--- a/opendelta/delta_models/low_rank_adapter.py
+++ b/opendelta/delta_models/low_rank_adapter.py
@@ -69,6 +69,13 @@ class LowRankAdapter(nn.Module):
                                         rank=self.low_rank_rank).to(self.device)
         self.instantiated = True
+        try:
+            import bmtrain as bmt
+            self.activation = bmt.BMTrainModelWrapper(self.activation)
+            self.down_sampler = bmt.BMTrainModelWrapper(self.down_sampler)
+            self.up_sampler = bmt.BMTrainModelWrapper(self.up_sampler)
+        except:
+            pass
 
 
     def post_forward(self, output):
         r""" Get the hidden_states from the PLM's layer output, pass it into the low-rank adapter,
@@ -88,7 +95,6 @@ class LowRankAdapter(nn.Module):
             logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
             self.instantiate(hidden_dim=self.hidden_dim)
 
-
         z = self.down_sampler(hiddens)
         z = self.activation(z)
         adapter_output = self.up_sampler(z)
diff --git a/opendelta/delta_models/soft_prompt.py b/opendelta/delta_models/soft_prompt.py
index c682132..f62b46d 100644
--- a/opendelta/delta_models/soft_prompt.py
+++ b/opendelta/delta_models/soft_prompt.py
@@ -225,5 +225,10 @@ class SoftPromptModel(DeltaBase):
                             init_range = self.init_range,
                             device = module_device,
                            )
+        try:
+            import bmtrain as bmt
+            soft_prompt_layer = bmt.BMTrainModelWrapper(soft_prompt_layer)
+        except:
+            pass
         self.delta_modules.append(soft_prompt_layer)
         return soft_prompt_layer
diff --git a/opendelta/utils/cuda.py b/opendelta/utils/cuda.py
index 5f237a7..fcfbc10 100644
--- a/opendelta/utils/cuda.py
+++ b/opendelta/utils/cuda.py
@@ -17,6 +17,20 @@ def get_device(module : Union[nn.Module, nn.Parameter]):
     else:
         raise RuntimeError("The module is paralleled acrossed device, please get device in a inner module")
 
+def get_dtype(module : Union[nn.Module, nn.Parameter]):
+    if not (isinstance(module, nn.Module) \
+            or isinstance(module, nn.Parameter)):
+        raise RuntimeError("module is not an instance of torch.nn.Module")
+    if hasattr(module, 'dtype'):
+        return module.dtype
+    else:
+        params_dtypes = [p.dtype for p in module.parameters()]
+        if len(params_dtypes) == 0:
+            return None
+        elif len(set(params_dtypes))==1:
+            return params_dtypes[0]
+        else:
+            raise RuntimeError("The module has multiple dtypes, please get the dtype from an inner module")
 
 def move_dict_to_cuda(dict_of_tensor, device):
     for key in dict_of_tensor:
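
The hunks above all apply one pattern: when bmtrain can be imported, the freshly created delta module (or parameter) is handed to bmt.BMTrainModelWrapper; otherwise it is used unchanged, and lazily instantiated parameters are allocated with the backbone's dtype/device via get_dtype/get_device. The sketch below is illustrative only: the helper name wrap_for_bmtrain and the trimmed-down BiasLayer are hypothetical, while bmt.BMTrainModelWrapper, get_device, and get_dtype are the ones this patch uses; it also assumes the layer output is a single tensor.

import torch
import torch.nn as nn


def wrap_for_bmtrain(obj):
    """Hand a module or parameter to BMTrain if bmtrain is importable; otherwise return it unchanged."""
    try:
        import bmtrain as bmt
        return bmt.BMTrainModelWrapper(obj)
    except ImportError:
        return obj


class BiasLayer(nn.Module):
    """Lazily instantiated bias, allocated with the backbone's dtype and device."""
    def __init__(self, dtype=None, device=None):
        super().__init__()
        self.dtype, self.device = dtype, device
        self.instantiated = False

    def instantiate(self, hidden_dim):
        # Allocate the bias to match the backbone, then let BMTrain manage it when available.
        bias = nn.Parameter(torch.zeros(hidden_dim, dtype=self.dtype, device=self.device))
        self.bias = wrap_for_bmtrain(bias)
        self.instantiated = True

    def post_forward(self, output):
        # Assumes `output` is a single tensor; add the bias along the last dimension.
        if not self.instantiated:
            self.instantiate(hidden_dim=output.shape[-1])
        return output + self.bias


# Usage, mirroring BitFitModel.add_bias_to_others above:
# from opendelta.utils.cuda import get_device, get_dtype
# new_bias = BiasLayer(dtype=get_dtype(some_module), device=get_device(some_module))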