fix basemodel.py
fix freeze error when applying adapter on input embedding
This commit is contained in:
parent
1d92bb2d2a
commit
77a9ad3c8b
|
@ -306,15 +306,16 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
p.requires_grad = False
|
||||
return
|
||||
else:
|
||||
for n, c in module.named_children():
|
||||
next_prefix = n if prefix == "" else ".".join([prefix,n])
|
||||
if self.find_key(next_prefix, exclude): # if found, untouch the parameters
|
||||
continue
|
||||
else: # firstly freeze the non module params, then go deeper.
|
||||
# firstly freeze the non module params, then go deeper.
|
||||
params = non_module_param(module)
|
||||
for n, p in params:
|
||||
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
|
||||
p.requires_grad = False
|
||||
for n, c in module.named_children():
|
||||
next_prefix = n if prefix == "" else ".".join([prefix,n])
|
||||
if self.find_key(next_prefix, exclude): # if found, untouch the parameters
|
||||
continue
|
||||
else:
|
||||
self._freeze_module_recursive(c, exclude=exclude, prefix=next_prefix)
|
||||
|
||||
|
||||
|
@ -642,7 +643,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
"""
|
||||
def _caller(_org_func, includes, *args, **kwargs):
|
||||
state_dict = _org_func(*args, **kwargs)
|
||||
keys = list(state_dict.keys())
|
||||
keys = list(state_dict.keys())3
|
||||
for n in keys:
|
||||
if n not in includes:
|
||||
state_dict.pop(n)
|
||||
|
|
Loading…
Reference in New Issue