fix bitfit

This commit is contained in:
Achazwl 2022-11-20 02:19:25 +00:00
parent 4315b83c8e
commit 35e51713b6
3 changed files with 13 additions and 41 deletions

View File

@ -56,8 +56,7 @@ def get_model(args):
"WiC" : 2,
}
model = BertModel(args, num_types[args.dataset_name])
od.Visualization(model).structure_graph()
# od.Visualization(model).structure_graph()
if args.delta_type == "lora":
delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'], backend='bmt')
@ -289,7 +288,7 @@ def main():
dataset = prepare_dataset(
args,
tokenizer,
f"{args.base_path}/down_data/superglue/",
f"/yinxr/zwl/ModelCenter/down_data/superglue/",
args.dataset_name,
bmt.rank(), bmt.world_size(),
)

View File

@ -29,12 +29,13 @@ class BitFitConfig(BaseDeltaConfig):
setattr(self, arg_name, locals()[arg_name])
class BiasLayer(nn.Module):
def __init__(self, init_method="zero", dtype=None, device=None):
def __init__(self, init_method="zero", dtype=None, device=None, backend=None):
super().__init__()
self.init_method=init_method
self.instantiated = False
self.dtype = dtype
self.device = device
self.backend = backend
def instantiate(self, hidden_dim):
if self.init_method == "zero":
@ -42,11 +43,9 @@ class BiasLayer(nn.Module):
else:
raise NotImplementedError
self.instantiated = True
try:
if self.backend == 'bmt':
import bmtrain as bmt
self.bias = bmt.BMTrainModelWrapper(self.bias)
except:
pass
def post_forward(self, output):
r"""Presuming the first argument is the tensor to add bias along the last dimension.
@ -114,7 +113,7 @@ class BitFitModel(DeltaBase):
config_class = BitFitConfig
delta_type = "bitfit"
default_modified_modules = ["attn@", "ff@", "layer_norm@","lm_head@.proj@"] # modify all the bias parameter in attention and feed-forward layer.
_supported_backends = ['hf']
_supported_backends = ['hf', 'bmt']
_need_pseudo_data = False
def __init__(self,
backbone_model: nn.Module,
@ -157,36 +156,14 @@ class BitFitModel(DeltaBase):
module: nn.Module,
):
if is_leaf_module(module):
# if it is a leaf module, add bias to it regardless of its type.
# if self.check_linear(module):
# self.add_bias_to_linear(module)
if self.backend_mapping.check_type(module, 'linear') or \
self.backend_mapping.check_type(module, 'layer_norm'):
self.add_bias_to_modules_have_bias_or_known_type(module)
else:
# for example, layer_norms, lm_heads.
self.add_bias_to_others(module)
else:
for n, c in module.named_modules():
self.add_bias_to_modules_have_bias_or_known_type(c)
# if self.check_linear(c):
# self.add_bias_to_linear(c)
# else:
# pass
# def add_bias_to_linear(self, c):
# if c.bias is None:
# bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
# self._reset_bias_parameters(c)
# try:
# import bmtrain as bmt
# bias = bmt.BMTrainModelWrapper(bias)
# except:
# pass
# c.register_parameter('bias', bias)
# self.delta_params.append(bias)
# else:
# self.add_bias_to_modules_have_bias_or_known_type(c)
def add_bias_to_modules_have_bias_or_known_type(self, c):
'''If it has bias, unfreeze it.
@ -200,7 +177,7 @@ class BitFitModel(DeltaBase):
self.backend_mapping.check_type(c, 'layer_norm'):
bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
self._reset_bias_parameters(c)
self._reset_bias_parameters(c, bias)
if self.backend == 'bmt':
import bmtrain as bmt
bias = bmt.BMTrainModelWrapper(bias)
@ -209,19 +186,17 @@ class BitFitModel(DeltaBase):
self.delta_params.append(bias)
def add_bias_to_others(self, c):
new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c)) # TODO: bmtrain?
if self.backend == 'bmt':
import bmtrain as bmt
new_bias = bmt.BMTrainModelWrapper(new_bias)
new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c), backend=self.backend)
self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since the name `bias` is reserved for some module such as roberta's LayerNorm.
self.delta_modules.append(new_bias)
@staticmethod
def _reset_bias_parameters(linear_module):
def _reset_bias_parameters(linear_module, bias):
fan_in, _ = init._calculate_fan_in_and_fan_out(linear_module.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(linear_module.bias, -bound, bound)
init.uniform_(bias, -bound, bound)
# init.uniform_(bias, -bound, bound)
def detach(self, module):
r"""Not implemented for BitFit yet. Please wait for the next version.

View File

@ -161,7 +161,7 @@ class SoftPromptModel(DeltaBase):
config_class = SoftPromptConfig
delta_type = "soft_prompt"
default_modified_modules = ["root"] # not used
_supported_backends = ['hf'] #'bmt']
_supported_backends = ['hf', 'bmt']
_need_pseudo_data = False
def __init__(self,
backbone_model: nn.Module,
@ -223,10 +223,8 @@ class SoftPromptModel(DeltaBase):
init_range = self.init_range,
device = module_device,
)
try:
if self.backend == 'bmt':
import bmtrain as bmt
soft_prompt_layer = bmt.BMTrainModelWrapper(soft_prompt_layer)
except:
pass
self.delta_modules.append(soft_prompt_layer)
return soft_prompt_layer