diff --git a/examples/tutorial/2_with_bmtrain.py b/examples/tutorial/2_with_bmtrain.py
index 6cacc9c..072a0cb 100644
--- a/examples/tutorial/2_with_bmtrain.py
+++ b/examples/tutorial/2_with_bmtrain.py
@@ -56,8 +56,7 @@ def get_model(args):
         "WiC" : 2,
     }
     model = BertModel(args, num_types[args.dataset_name])
-    od.Visualization(model).structure_graph()
-
+    # od.Visualization(model).structure_graph()
     if args.delta_type == "lora":
         delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'], backend='bmt')
 
@@ -289,7 +288,7 @@ def main():
     dataset = prepare_dataset(
         args,
         tokenizer,
-        f"{args.base_path}/down_data/superglue/",
+        f"/yinxr/zwl/ModelCenter/down_data/superglue/",
         args.dataset_name,
         bmt.rank(), bmt.world_size(),
     )
diff --git a/opendelta/delta_models/bitfit.py b/opendelta/delta_models/bitfit.py
index 9d89548..221d84d 100644
--- a/opendelta/delta_models/bitfit.py
+++ b/opendelta/delta_models/bitfit.py
@@ -29,12 +29,13 @@ class BitFitConfig(BaseDeltaConfig):
                 setattr(self, arg_name, locals()[arg_name])
 
 class BiasLayer(nn.Module):
-    def __init__(self, init_method="zero", dtype=None, device=None):
+    def __init__(self, init_method="zero", dtype=None, device=None, backend=None):
         super().__init__()
         self.init_method=init_method
         self.instantiated = False
         self.dtype = dtype
         self.device = device
+        self.backend = backend
 
     def instantiate(self, hidden_dim):
         if self.init_method == "zero":
@@ -42,11 +43,9 @@ class BiasLayer(nn.Module):
         else:
             raise NotImplementedError
         self.instantiated = True
-        try:
+        if self.backend == 'bmt':
             import bmtrain as bmt
             self.bias = bmt.BMTrainModelWrapper(self.bias)
-        except:
-            pass
 
     def post_forward(self, output):
         r"""Presuming the first argument is the tensor to add bias along the last dimension.
@@ -114,7 +113,7 @@ class BitFitModel(DeltaBase):
     config_class = BitFitConfig
     delta_type = "bitfit"
     default_modified_modules = ["attn@", "ff@", "layer_norm@","lm_head@.proj@"] # modify all the bias parameter in attention and feed-forward layer.
-    _supported_backends = ['hf']
+    _supported_backends = ['hf', 'bmt']
     _need_pseudo_data = False
     def __init__(self,
                  backbone_model: nn.Module,
@@ -157,36 +156,14 @@ class BitFitModel(DeltaBase):
                      module: nn.Module,
                     ):
         if is_leaf_module(module):
-            # if it is a leaf module, add bias to it regardless of its type.
-            # if self.check_linear(module):
-            #     self.add_bias_to_linear(module)
             if self.backend_mapping.check_type(module, 'linear') or \
                 self.backend_mapping.check_type(module, 'layer_norm'):
                 self.add_bias_to_modules_have_bias_or_known_type(module)
             else:
-                # for example, layer_norms, lm_heads.
                 self.add_bias_to_others(module)
         else:
             for n, c in module.named_modules():
                 self.add_bias_to_modules_have_bias_or_known_type(c)
-                # if self.check_linear(c):
-                #     self.add_bias_to_linear(c)
-                # else:
-                #     pass
-
-    # def add_bias_to_linear(self, c):
-    #     if c.bias is None:
-    #         bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
-    #         self._reset_bias_parameters(c)
-    #         try:
-    #             import bmtrain as bmt
-    #             bias = bmt.BMTrainModelWrapper(bias)
-    #         except:
-    #             pass
-    #         c.register_parameter('bias', bias)
-    #         self.delta_params.append(bias)
-    #     else:
-    #         self.add_bias_to_modules_have_bias_or_known_type(c)
 
     def add_bias_to_modules_have_bias_or_known_type(self, c):
         '''If it has bias, unfreeze it. 
@@ -200,7 +177,7 @@ class BitFitModel(DeltaBase):
                 self.backend_mapping.check_type(c, 'layer_norm'):
                 bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
 
-                self._reset_bias_parameters(c)
+                self._reset_bias_parameters(c, bias)
                 if self.backend == 'bmt':
                     import bmtrain as bmt
                     bias = bmt.BMTrainModelWrapper(bias)
@@ -209,19 +186,17 @@ class BitFitModel(DeltaBase):
                 self.delta_params.append(bias)
 
     def add_bias_to_others(self, c):
-        new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c)) # TODO: bmtrain?
-        if self.backend == 'bmt':
-            import bmtrain as bmt
-            new_bias = bmt.BMTrainModelWrapper(new_bias)
+        new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c), backend=self.backend)
         self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since the name `bias` is reserved for some module such as roberta's LayerNorm.
         self.delta_modules.append(new_bias)
 
     @staticmethod
-    def _reset_bias_parameters(linear_module):
+    def _reset_bias_parameters(linear_module, bias):
         fan_in, _ = init._calculate_fan_in_and_fan_out(linear_module.weight)
         bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-        init.uniform_(linear_module.bias, -bound, bound)
+        init.uniform_(bias, -bound, bound)
+        # init.uniform_(bias, -bound, bound)
 
     def detach(self, module):
         r"""Not implemented for BitFit yet. Please wait for the next version.
diff --git a/opendelta/delta_models/soft_prompt.py b/opendelta/delta_models/soft_prompt.py
index 6453368..b7d7692 100644
--- a/opendelta/delta_models/soft_prompt.py
+++ b/opendelta/delta_models/soft_prompt.py
@@ -161,7 +161,7 @@ class SoftPromptModel(DeltaBase):
     config_class = SoftPromptConfig
     delta_type = "soft_prompt"
     default_modified_modules = ["root"] # not used
-    _supported_backends = ['hf'] #'bmt']
+    _supported_backends = ['hf', 'bmt']
     _need_pseudo_data = False
     def __init__(self,
                  backbone_model: nn.Module,
@@ -223,10 +223,8 @@ class SoftPromptModel(DeltaBase):
                             init_range = self.init_range,
                             device = module_device,
                            )
-        try:
+        if self.backend == 'bmt':
             import bmtrain as bmt
             soft_prompt_layer = bmt.BMTrainModelWrapper(soft_prompt_layer)
-        except:
-            pass
         self.delta_modules.append(soft_prompt_layer)
         return soft_prompt_layer
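# ---------------------------------------------------------------------------
# Not part of the patch above: a minimal usage sketch of the 'bmt' backend that
# this patch enables for BitFit (and SoftPrompt). Assumptions are marked:
# `get_backbone()` is a hypothetical stand-in for the tutorial's
# BertModel(args, num_types[args.dataset_name]), and the script is launched
# with a distributed launcher (e.g. torchrun), which bmtrain requires.
import bmtrain as bmt
from opendelta import BitFitModel

bmt.init_distributed(seed=0)      # set up the bmtrain process group

model = get_backbone()            # hypothetical bmtrain/ModelCenter backbone
delta_model = BitFitModel(backbone_model=model, backend='bmt')
# Keep only the injected bias parameters trainable; freeze the backbone.
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
delta_model.log()                 # show modified modules and trainable ratio

# With backend='bmt', the new bias parameters are wrapped explicitly via
# bmt.BMTrainModelWrapper (see BiasLayer.instantiate above) instead of the
# old silent try/except around the bmtrain import.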