From 77a9ad3c8bb21f6bc73031505942f1a425d64a15 Mon Sep 17 00:00:00 2001 From: William Date: Tue, 9 May 2023 10:17:19 +0800 Subject: [PATCH] fix basemodel.py fix freeze error when applying adapter on input embedding --- opendelta/basemodel.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/opendelta/basemodel.py b/opendelta/basemodel.py index d5e335e..f060d88 100644 --- a/opendelta/basemodel.py +++ b/opendelta/basemodel.py @@ -306,15 +306,16 @@ class DeltaBase(nn.Module, SaveLoadMixin): p.requires_grad = False return else: + # firstly freeze the non module params, then go deeper. + params = non_module_param(module) + for n, p in params: + if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))): + p.requires_grad = False for n, c in module.named_children(): next_prefix = n if prefix == "" else ".".join([prefix,n]) if self.find_key(next_prefix, exclude): # if found, untouch the parameters continue - else: # firstly freeze the non module params, then go deeper. - params = non_module_param(module) - for n, p in params: - if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))): - p.requires_grad = False + else: self._freeze_module_recursive(c, exclude=exclude, prefix=next_prefix) @@ -642,7 +643,7 @@ class DeltaBase(nn.Module, SaveLoadMixin): """ def _caller(_org_func, includes, *args, **kwargs): state_dict = _org_func(*args, **kwargs) - keys = list(state_dict.keys()) + keys = list(state_dict.keys())3 for n in keys: if n not in includes: state_dict.pop(n)