fix exclude prefix start with '.'
This commit is contained in:
parent
e4a0acff32
commit
c39a624ce8
|
@ -277,21 +277,23 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
|
||||
if is_leaf_module(module):
|
||||
for n, p in module.named_parameters():
|
||||
if self.find_key(".".join([prefix,n]), exclude):
|
||||
next_prefix = n if prefix == "" else ".".join([prefix,n])
|
||||
if self.find_key(next_prefix, exclude):
|
||||
continue
|
||||
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
|
||||
p.requires_grad = False
|
||||
return
|
||||
else:
|
||||
for n, c in module.named_children():
|
||||
if self.find_key(".".join([prefix,n]), exclude): # if found, untouch the parameters
|
||||
next_prefix = n if prefix == "" else ".".join([prefix,n])
|
||||
if self.find_key(next_prefix, exclude): # if found, untouch the parameters
|
||||
continue
|
||||
else: # firstly freeze the non module params, then go deeper.
|
||||
params = non_module_param(module)
|
||||
for n, p in params:
|
||||
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
|
||||
p.requires_grad = False
|
||||
self._freeze_module_recursive(c, exclude=exclude, prefix=".".join([prefix,n]) )
|
||||
self._freeze_module_recursive(c, exclude=exclude, prefix=next_prefix)
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue