This commit is contained in:
HirasawaaYui 2022-10-10 03:23:14 +00:00
parent f6788bfc22
commit 2351259ecd
2 changed files with 4 additions and 2 deletions

View File

@ -92,6 +92,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
default_exclude_modules = ["lm_head"]
config_class = BaseDeltaConfig
default_unfrozen_modules = ["deltas"]
pass_pseudo_data = True
def __init__(self,
backbone_model: nn.Module,
modified_modules: Optional[List[str]] = None,
@ -200,7 +201,8 @@ class DeltaBase(nn.Module, SaveLoadMixin):
if self.find_key(key, modified_modules): #TODO may have bugs when commonstructure has a virtual node and it's refered
logger.debug("find key: {}".format(key))
self.update_module(backbone, key)
self._pseudo_data_to_instantiate(backbone)
if self.pass_pseudo_data:
self._pseudo_data_to_instantiate(backbone)
# mark the paratmers that are the delta parameters for easily displaying the delta_paramters.
self.mark_as_delta()
return backbone
@ -329,7 +331,6 @@ class DeltaBase(nn.Module, SaveLoadMixin):
module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The backbone model.
"""
return
if module is None:
module = self.backbone_model
device = get_device(module)

View File

@ -97,6 +97,7 @@ class LoraModel(DeltaBase):
config_class = LoraConfig
delta_type = "lora"
default_modified_modules = ['attn.q', 'attn.v']
pass_pseudo_data = False
def __init__(self,
backbone_model: nn.Module,
lora_r=8,