fix exclude prefix start with '.'

This commit is contained in:
Achazwl 2022-10-10 01:27:06 +00:00
parent e4a0acff32
commit c39a624ce8
1 changed files with 5 additions and 3 deletions

View File

@ -277,21 +277,23 @@ class DeltaBase(nn.Module, SaveLoadMixin):
if is_leaf_module(module):
for n, p in module.named_parameters():
if self.find_key(".".join([prefix,n]), exclude):
next_prefix = n if prefix == "" else ".".join([prefix,n])
if self.find_key(next_prefix, exclude):
continue
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
p.requires_grad = False
return
else:
for n, c in module.named_children():
if self.find_key(".".join([prefix,n]), exclude): # if found, untouch the parameters
next_prefix = n if prefix == "" else ".".join([prefix,n])
if self.find_key(next_prefix, exclude): # if found, untouch the parameters
continue
else: # firstly freeze the non module params, then go deeper.
params = non_module_param(module)
for n, p in params:
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
p.requires_grad = False
self._freeze_module_recursive(c, exclude=exclude, prefix=".".join([prefix,n]) )
self._freeze_module_recursive(c, exclude=exclude, prefix=next_prefix)