Merge pull request #32 from thunlp/fix-exclude

fix exclude
This commit is contained in:
DingDing 2022-10-16 21:42:18 +08:00 committed by GitHub
commit 085d388102
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 3 deletions

View File

@ -288,21 +288,23 @@ class DeltaBase(nn.Module, SaveLoadMixin):
if is_leaf_module(module):
for n, p in module.named_parameters():
if self.find_key(".".join([prefix,n]), exclude):
next_prefix = n if prefix == "" else ".".join([prefix,n])
if self.find_key(next_prefix, exclude):
continue
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
p.requires_grad = False
return
else:
for n, c in module.named_children():
if self.find_key(".".join([prefix,n]), exclude): # if found, untouch the parameters
next_prefix = n if prefix == "" else ".".join([prefix,n])
if self.find_key(next_prefix, exclude): # if found, untouch the parameters
continue
else: # firstly freeze the non module params, then go deeper.
params = non_module_param(module)
for n, p in params:
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
p.requires_grad = False
self._freeze_module_recursive(c, exclude=exclude, prefix=".".join([prefix,n]) )
self._freeze_module_recursive(c, exclude=exclude, prefix=next_prefix)