diff --git a/opendelta/basemodel.py b/opendelta/basemodel.py index 71f8d47..0f6cbc6 100644 --- a/opendelta/basemodel.py +++ b/opendelta/basemodel.py @@ -92,6 +92,7 @@ class DeltaBase(nn.Module, SaveLoadMixin): default_exclude_modules = ["lm_head"] config_class = BaseDeltaConfig default_unfrozen_modules = ["deltas"] + pass_pseudo_data = True def __init__(self, backbone_model: nn.Module, modified_modules: Optional[List[str]] = None, @@ -200,7 +201,8 @@ class DeltaBase(nn.Module, SaveLoadMixin): if self.find_key(key, modified_modules): #TODO may have bugs when commonstructure has a virtual node and it's refered logger.debug("find key: {}".format(key)) self.update_module(backbone, key) - self._pseudo_data_to_instantiate(backbone) + if self.pass_pseudo_data: + self._pseudo_data_to_instantiate(backbone) # mark the paratmers that are the delta parameters for easily displaying the delta_paramters. self.mark_as_delta() return backbone @@ -329,7 +331,6 @@ class DeltaBase(nn.Module, SaveLoadMixin): module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The backbone model. """ - return if module is None: module = self.backbone_model device = get_device(module) diff --git a/opendelta/delta_models/lora.py b/opendelta/delta_models/lora.py index 2c172d6..38c629a 100644 --- a/opendelta/delta_models/lora.py +++ b/opendelta/delta_models/lora.py @@ -97,6 +97,7 @@ class LoraModel(DeltaBase): config_class = LoraConfig delta_type = "lora" default_modified_modules = ['attn.q', 'attn.v'] + pass_pseudo_data = False def __init__(self, backbone_model: nn.Module, lora_r=8,