fix basemodel.py

fix freeze error when applying adapter on input embedding
This commit is contained in:
William 2023-05-09 10:17:19 +08:00 committed by GitHub
parent 1d92bb2d2a
commit 77a9ad3c8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 6 deletions

View File

@ -306,15 +306,16 @@ class DeltaBase(nn.Module, SaveLoadMixin):
p.requires_grad = False p.requires_grad = False
return return
else: else:
# firstly freeze the non module params, then go deeper.
params = non_module_param(module)
for n, p in params:
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
p.requires_grad = False
for n, c in module.named_children(): for n, c in module.named_children():
next_prefix = n if prefix == "" else ".".join([prefix,n]) next_prefix = n if prefix == "" else ".".join([prefix,n])
if self.find_key(next_prefix, exclude): # if found, untouch the parameters if self.find_key(next_prefix, exclude): # if found, untouch the parameters
continue continue
else: # firstly freeze the non module params, then go deeper. else:
params = non_module_param(module)
for n, p in params:
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
p.requires_grad = False
self._freeze_module_recursive(c, exclude=exclude, prefix=next_prefix) self._freeze_module_recursive(c, exclude=exclude, prefix=next_prefix)
@ -642,7 +643,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
""" """
def _caller(_org_func, includes, *args, **kwargs): def _caller(_org_func, includes, *args, **kwargs):
state_dict = _org_func(*args, **kwargs) state_dict = _org_func(*args, **kwargs)
keys = list(state_dict.keys()) keys = list(state_dict.keys())3
for n in keys: for n in keys:
if n not in includes: if n not in includes:
state_dict.pop(n) state_dict.pop(n)