update #31
This commit is contained in:
parent
f6788bfc22
commit
2351259ecd
|
@ -92,6 +92,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
default_exclude_modules = ["lm_head"]
|
||||
config_class = BaseDeltaConfig
|
||||
default_unfrozen_modules = ["deltas"]
|
||||
pass_pseudo_data = True
|
||||
def __init__(self,
|
||||
backbone_model: nn.Module,
|
||||
modified_modules: Optional[List[str]] = None,
|
||||
|
@ -200,6 +201,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
if self.find_key(key, modified_modules): #TODO may have bugs when commonstructure has a virtual node and it's refered
|
||||
logger.debug("find key: {}".format(key))
|
||||
self.update_module(backbone, key)
|
||||
if self.pass_pseudo_data:
|
||||
self._pseudo_data_to_instantiate(backbone)
|
||||
# mark the paratmers that are the delta parameters for easily displaying the delta_paramters.
|
||||
self.mark_as_delta()
|
||||
|
@ -329,7 +331,6 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The backbone model.
|
||||
|
||||
"""
|
||||
return
|
||||
if module is None:
|
||||
module = self.backbone_model
|
||||
device = get_device(module)
|
||||
|
|
|
@ -97,6 +97,7 @@ class LoraModel(DeltaBase):
|
|||
config_class = LoraConfig
|
||||
delta_type = "lora"
|
||||
default_modified_modules = ['attn.q', 'attn.v']
|
||||
pass_pseudo_data = False
|
||||
def __init__(self,
|
||||
backbone_model: nn.Module,
|
||||
lora_r=8,
|
||||
|
|
Loading…
Reference in New Issue