fix bitfit

Achazwl 2022-11-20 02:19:25 +00:00
parent 4315b83c8e
commit 35e51713b6
3 changed files with 13 additions and 41 deletions

File 1 of 3: the SuperGLUE example training script

@@ -56,8 +56,7 @@ def get_model(args):
         "WiC" : 2,
     }
     model = BertModel(args, num_types[args.dataset_name])
-    od.Visualization(model).structure_graph()
+    # od.Visualization(model).structure_graph()
     if args.delta_type == "lora":
         delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'], backend='bmt')
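The call commented out above is OpenDelta's structure visualizer; it only prints the backbone's module tree (handy for picking modified_modules such as project_q / project_k) and has no effect on training. A minimal sketch of running it on its own, assuming opendelta is installed and imported as od as in this script; the small Sequential model is just a placeholder:

import torch.nn as nn
import opendelta as od

# Any torch module works; the visualizer prints its module tree so
# target module names can be read off for delta models.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
od.Visualization(model).structure_graph()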
@@ -289,7 +288,7 @@ def main():
     dataset = prepare_dataset(
         args,
         tokenizer,
-        f"{args.base_path}/down_data/superglue/",
+        f"/yinxr/zwl/ModelCenter/down_data/superglue/",
         args.dataset_name,
         bmt.rank(), bmt.world_size(),
     )

File 2 of 3: the BitFit delta model (BitFitConfig, BiasLayer, BitFitModel)

@@ -29,12 +29,13 @@ class BitFitConfig(BaseDeltaConfig):
                 setattr(self, arg_name, locals()[arg_name])
 
 class BiasLayer(nn.Module):
-    def __init__(self, init_method="zero", dtype=None, device=None):
+    def __init__(self, init_method="zero", dtype=None, device=None, backend=None):
         super().__init__()
         self.init_method=init_method
         self.instantiated = False
         self.dtype = dtype
         self.device = device
+        self.backend = backend
 
     def instantiate(self, hidden_dim):
if self.init_method == "zero": if self.init_method == "zero":
@@ -42,11 +43,9 @@ class BiasLayer(nn.Module):
         else:
             raise NotImplementedError
         self.instantiated = True
-        try:
-            import bmtrain as bmt
-            self.bias = bmt.BMTrainModelWrapper(self.bias)
-        except:
-            pass
+        if self.backend == 'bmt':
+            import bmtrain as bmt
+            self.bias = bmt.BMTrainModelWrapper(self.bias)
 
     def post_forward(self, output):
         r"""Presuming the first argument is the tensor to add bias along the last dimension.
@@ -114,7 +113,7 @@ class BitFitModel(DeltaBase):
     config_class = BitFitConfig
     delta_type = "bitfit"
     default_modified_modules = ["attn@", "ff@", "layer_norm@","lm_head@.proj@"] # modify all the bias parameter in attention and feed-forward layer.
-    _supported_backends = ['hf']
+    _supported_backends = ['hf', 'bmt']
     _need_pseudo_data = False
     def __init__(self,
                  backbone_model: nn.Module,
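With 'bmt' added to _supported_backends, BitFit can be attached to a bmtrain-wrapped backbone the same way the LoRA example in the first file does. A hypothetical usage sketch, assuming BitFitModel is importable from the opendelta package and follows the usual backbone_model/backend constructor convention:

from opendelta import BitFitModel

# model: a bmtrain-backed backbone such as the BertModel built in the example script.
delta_model = BitFitModel(backbone_model=model, backend='bmt')
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
delta_model.log()  # show which parameters are now trainable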
@@ -157,36 +156,14 @@ class BitFitModel(DeltaBase):
                       module: nn.Module,
                      ):
         if is_leaf_module(module):
-            # if it is a leaf module, add bias to it regardless of its type.
-            # if self.check_linear(module):
-            #     self.add_bias_to_linear(module)
             if self.backend_mapping.check_type(module, 'linear') or \
                 self.backend_mapping.check_type(module, 'layer_norm'):
                 self.add_bias_to_modules_have_bias_or_known_type(module)
             else:
-                # for example, layer_norms, lm_heads.
                 self.add_bias_to_others(module)
         else:
             for n, c in module.named_modules():
                 self.add_bias_to_modules_have_bias_or_known_type(c)
-                # if self.check_linear(c):
-                #     self.add_bias_to_linear(c)
-                # else:
-                #     pass
-
-    # def add_bias_to_linear(self, c):
-    #     if c.bias is None:
-    #         bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
-    #         self._reset_bias_parameters(c)
-    #         try:
-    #             import bmtrain as bmt
-    #             bias = bmt.BMTrainModelWrapper(bias)
-    #         except:
-    #             pass
-    #         c.register_parameter('bias', bias)
-    #         self.delta_params.append(bias)
-    #     else:
-    #         self.add_bias_to_modules_have_bias_or_known_type(c)
 
     def add_bias_to_modules_have_bias_or_known_type(self, c):
         '''If it has bias, unfreeze it.
@@ -200,7 +177,7 @@ class BitFitModel(DeltaBase):
             if self.backend_mapping.check_type(c, 'linear') or \
                 self.backend_mapping.check_type(c, 'layer_norm'):
                 bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
-                self._reset_bias_parameters(c)
+                self._reset_bias_parameters(c, bias)
                 if self.backend == 'bmt':
                     import bmtrain as bmt
                     bias = bmt.BMTrainModelWrapper(bias)
@@ -209,19 +186,17 @@ class BitFitModel(DeltaBase):
                 self.delta_params.append(bias)
 
     def add_bias_to_others(self, c):
-        new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c)) # TODO: bmtrain?
-        if self.backend == 'bmt':
-            import bmtrain as bmt
-            new_bias = bmt.BMTrainModelWrapper(new_bias)
+        new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c), backend=self.backend)
         self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since the name `bias` is reserved for some module such as roberta's LayerNorm.
         self.delta_modules.append(new_bias)
 
     @staticmethod
-    def _reset_bias_parameters(linear_module):
+    def _reset_bias_parameters(linear_module, bias):
         fan_in, _ = init._calculate_fan_in_and_fan_out(linear_module.weight)
         bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-        init.uniform_(linear_module.bias, -bound, bound)
-        # init.uniform_(bias, -bound, bound)
+        init.uniform_(bias, -bound, bound)
 
     def detach(self, module):
         r"""Not implemented for BitFit yet. Please wait for the next version.

File 3 of 3: the SoftPrompt delta model (SoftPromptModel)

@@ -161,7 +161,7 @@ class SoftPromptModel(DeltaBase):
     config_class = SoftPromptConfig
     delta_type = "soft_prompt"
     default_modified_modules = ["root"] # not used
-    _supported_backends = ['hf'] #'bmt']
+    _supported_backends = ['hf', 'bmt']
     _need_pseudo_data = False
     def __init__(self,
                  backbone_model: nn.Module,
@@ -223,10 +223,8 @@
             init_range = self.init_range,
             device = module_device,
             )
-        try:
-            import bmtrain as bmt
-            soft_prompt_layer = bmt.BMTrainModelWrapper(soft_prompt_layer)
-        except:
-            pass
+        if self.backend == 'bmt':
+            import bmtrain as bmt
+            soft_prompt_layer = bmt.BMTrainModelWrapper(soft_prompt_layer)
         self.delta_modules.append(soft_prompt_layer)
         return soft_prompt_layer
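BitFit and SoftPrompt now gate the bmtrain wrapping on self.backend in the same way. The helper below is purely a hypothetical sketch of that shared shape, not something this commit adds:

def wrap_for_backend(module, backend):
    # Wrap a freshly built delta module for the BMTrain backend,
    # and hand it back unchanged for the plain 'hf' backend.
    if backend == 'bmt':
        import bmtrain as bmt
        return bmt.BMTrainModelWrapper(module)
    return module

# e.g. soft_prompt_layer = wrap_for_backend(soft_prompt_layer, self.backend)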