fix bitfit
parent 4315b83c8e
commit 35e51713b6
@@ -56,8 +56,7 @@ def get_model(args):
         "WiC" : 2,
     }
     model = BertModel(args, num_types[args.dataset_name])
-    od.Visualization(model).structure_graph()
-
+    # od.Visualization(model).structure_graph()
 
     if args.delta_type == "lora":
         delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'], backend='bmt')
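For context on what usually follows the `LoraModel(...)` call above, here is a minimal, hedged sketch of attaching a delta model and freezing the backbone. It assumes the OpenDelta-style API (`LoraModel`, `freeze_module`, `log`); the helper `attach_lora` and its argument names are illustrative and not part of this commit.

    # Hedged sketch only: assumes OpenDelta's LoraModel API; attach_lora is hypothetical.
    from opendelta import LoraModel

    def attach_lora(model, backend='bmt'):
        # Inject LoRA into the attention projections, as in the hunk above.
        delta_model = LoraModel(backbone_model=model,
                                modified_modules=['project_q', 'project_k'],
                                backend=backend)
        # Train only the injected delta parameters; keep the backbone frozen.
        delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
        delta_model.log()  # show which modules were modified
        return model, delta_model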
@@ -289,7 +288,7 @@ def main():
     dataset = prepare_dataset(
         args,
         tokenizer,
-        f"{args.base_path}/down_data/superglue/",
+        f"/yinxr/zwl/ModelCenter/down_data/superglue/",
         args.dataset_name,
         bmt.rank(), bmt.world_size(),
     )
@@ -29,12 +29,13 @@ class BitFitConfig(BaseDeltaConfig):
                 setattr(self, arg_name, locals()[arg_name])
 
 class BiasLayer(nn.Module):
-    def __init__(self, init_method="zero", dtype=None, device=None):
+    def __init__(self, init_method="zero", dtype=None, device=None, backend=None):
         super().__init__()
         self.init_method=init_method
         self.instantiated = False
         self.dtype = dtype
         self.device = device
+        self.backend = backend
 
     def instantiate(self, hidden_dim):
         if self.init_method == "zero":
@@ -42,11 +43,9 @@ class BiasLayer(nn.Module):
         else:
             raise NotImplementedError
         self.instantiated = True
-        try:
+        if self.backend == 'bmt':
             import bmtrain as bmt
             self.bias = bmt.BMTrainModelWrapper(self.bias)
-        except:
-            pass
 
     def post_forward(self, output):
         r"""Presuming the first argument is the tensor to add bias along the last dimension.
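The change above (repeated below for `_reset_bias_parameters`, `add_bias_to_others`, and `SoftPromptModel`) swaps a silent `try/except` around the bmtrain wrapping for an explicit `backend` check, so a missing or broken `bmtrain` install fails loudly instead of quietly falling back. A self-contained sketch of the pattern, using plain PyTorch and a hypothetical `ToyBiasLayer` rather than the library's `BiasLayer`:

    import torch
    import torch.nn as nn

    class ToyBiasLayer(nn.Module):
        """Hypothetical stand-in for BiasLayer: lazily creates a bias of the hidden size."""
        def __init__(self, backend=None):
            super().__init__()
            self.backend = backend
            self.instantiated = False

        def instantiate(self, hidden_dim):
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
            self.instantiated = True
            if self.backend == 'bmt':
                # Only touch bmtrain when explicitly requested; no blanket try/except.
                import bmtrain as bmt
                self.bias = bmt.BMTrainModelWrapper(self.bias)

        def post_forward(self, output):
            # Add the bias along the last dimension, creating it on first use.
            if not self.instantiated:
                self.instantiate(output.shape[-1])
            return output + self.bias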
@@ -114,7 +113,7 @@ class BitFitModel(DeltaBase):
     config_class = BitFitConfig
     delta_type = "bitfit"
     default_modified_modules = ["attn@", "ff@", "layer_norm@","lm_head@.proj@"] # modify all the bias parameter in attention and feed-forward layer.
-    _supported_backends = ['hf']
+    _supported_backends = ['hf', 'bmt']
     _need_pseudo_data = False
     def __init__(self,
                  backbone_model: nn.Module,
@@ -157,36 +156,14 @@ class BitFitModel(DeltaBase):
                       module: nn.Module,
                      ):
         if is_leaf_module(module):
-            # if it is a leaf module, add bias to it regardless of its type.
-            # if self.check_linear(module):
-            #     self.add_bias_to_linear(module)
             if self.backend_mapping.check_type(module, 'linear') or \
                 self.backend_mapping.check_type(module, 'layer_norm'):
                 self.add_bias_to_modules_have_bias_or_known_type(module)
             else:
-                # for example, layer_norms, lm_heads.
                 self.add_bias_to_others(module)
         else:
             for n, c in module.named_modules():
                 self.add_bias_to_modules_have_bias_or_known_type(c)
-                # if self.check_linear(c):
-                #     self.add_bias_to_linear(c)
-                # else:
-                #     pass
-
-    # def add_bias_to_linear(self, c):
-    #     if c.bias is None:
-    #         bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
-    #         self._reset_bias_parameters(c)
-    #         try:
-    #             import bmtrain as bmt
-    #             bias = bmt.BMTrainModelWrapper(bias)
-    #         except:
-    #             pass
-    #         c.register_parameter('bias', bias)
-    #         self.delta_params.append(bias)
-    #     else:
-    #         self.add_bias_to_modules_have_bias_or_known_type(c)
 
     def add_bias_to_modules_have_bias_or_known_type(self, c):
         '''If it has bias, unfreeze it.
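The retained control flow is: a leaf module of a known type (linear or layer norm) gets its bias added or unfrozen directly, any other leaf gets a `BiasLayer` inserted after it, and a non-leaf module is walked with `named_modules()`. A toy walk illustrating that shape (the `is_leaf` and `visit` helpers are illustrative, not the library's `is_leaf_module`/`modify_module`):

    import torch.nn as nn

    def is_leaf(m: nn.Module) -> bool:
        # A module with no children, e.g. nn.Linear or nn.LayerNorm.
        return next(m.children(), None) is None

    def visit(module: nn.Module):
        if is_leaf(module):
            if isinstance(module, (nn.Linear, nn.LayerNorm)):
                print("unfreeze/add bias on", type(module).__name__)
            else:
                print("insert a BiasLayer after", type(module).__name__)
        else:
            for name, child in module.named_modules():
                if isinstance(child, (nn.Linear, nn.LayerNorm)):
                    print(name or "<root>", "-> unfreeze/add bias")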
@@ -200,7 +177,7 @@ class BitFitModel(DeltaBase):
             self.backend_mapping.check_type(c, 'layer_norm'):
             bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
 
-            self._reset_bias_parameters(c)
+            self._reset_bias_parameters(c, bias)
             if self.backend == 'bmt':
                 import bmtrain as bmt
                 bias = bmt.BMTrainModelWrapper(bias)
@@ -209,19 +186,17 @@ class BitFitModel(DeltaBase):
             self.delta_params.append(bias)
 
     def add_bias_to_others(self, c):
-        new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c)) # TODO: bmtrain?
-        if self.backend == 'bmt':
-            import bmtrain as bmt
-            new_bias = bmt.BMTrainModelWrapper(new_bias)
+        new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c), backend=self.backend)
 
         self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since the name `bias` is reserved for some module such as roberta's LayerNorm.
         self.delta_modules.append(new_bias)
 
     @staticmethod
-    def _reset_bias_parameters(linear_module):
+    def _reset_bias_parameters(linear_module, bias):
         fan_in, _ = init._calculate_fan_in_and_fan_out(linear_module.weight)
         bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-        init.uniform_(linear_module.bias, -bound, bound)
+        init.uniform_(bias, -bound, bound)
+        # init.uniform_(bias, -bound, bound)
 
     def detach(self, module):
         r"""Not implemented for BitFit yet. Please wait for the next version.
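For reference, the bound used in `_reset_bias_parameters` is the same as PyTorch's default `nn.Linear` bias initialization, uniform in ±1/sqrt(fan_in), now applied to the freshly created bias (the module's own `.bias` is None here, which is why the old call failed). A standalone sketch with hypothetical sizes:

    import math
    import torch
    from torch import nn
    from torch.nn import init

    linear = nn.Linear(768, 3072, bias=False)   # hypothetical module that lacks a bias
    bias = nn.Parameter(torch.empty(linear.out_features), requires_grad=True)

    fan_in, _ = init._calculate_fan_in_and_fan_out(linear.weight)  # fan_in = 768 here
    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    init.uniform_(bias, -bound, bound)          # initialize the new bias, not linear.bias (which is None)
    linear.register_parameter('bias', bias)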
@@ -161,7 +161,7 @@ class SoftPromptModel(DeltaBase):
     config_class = SoftPromptConfig
     delta_type = "soft_prompt"
     default_modified_modules = ["root"] # not used
-    _supported_backends = ['hf'] #'bmt']
+    _supported_backends = ['hf', 'bmt']
     _need_pseudo_data = False
     def __init__(self,
                  backbone_model: nn.Module,
@@ -223,10 +223,8 @@ class SoftPromptModel(DeltaBase):
             init_range = self.init_range,
             device = module_device,
         )
-        try:
+        if self.backend == 'bmt':
             import bmtrain as bmt
             soft_prompt_layer = bmt.BMTrainModelWrapper(soft_prompt_layer)
-        except:
-            pass
         self.delta_modules.append(soft_prompt_layer)
         return soft_prompt_layer
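With `_supported_backends` extended to `['hf', 'bmt']` for both BitFit and SoftPrompt, callers can request the bmtrain backend the same way the LoRA example at the top of this diff does. A hedged usage sketch (OpenDelta-style API assumed; `model` stands for whatever bmtrain-built backbone the script constructs, and the prompt length is hypothetical):

    from opendelta import BitFitModel, SoftPromptModel  # assumed import path

    # BitFit under bmtrain: unfreeze/insert biases in the default modified modules.
    delta_model = BitFitModel(backbone_model=model, backend='bmt')
    delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)

    # Soft prompt under bmtrain, e.g. 64 prompt tokens (illustrative value).
    # delta_model = SoftPromptModel(backbone_model=model, soft_token_num=64, backend='bmt')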