fix basemodel.py
fix freeze error when applying adapter on input embedding
This commit is contained in:
parent
1d92bb2d2a
commit
77a9ad3c8b
|
@ -306,15 +306,16 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
||||||
p.requires_grad = False
|
p.requires_grad = False
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
|
# firstly freeze the non module params, then go deeper.
|
||||||
|
params = non_module_param(module)
|
||||||
|
for n, p in params:
|
||||||
|
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
|
||||||
|
p.requires_grad = False
|
||||||
for n, c in module.named_children():
|
for n, c in module.named_children():
|
||||||
next_prefix = n if prefix == "" else ".".join([prefix,n])
|
next_prefix = n if prefix == "" else ".".join([prefix,n])
|
||||||
if self.find_key(next_prefix, exclude): # if found, untouch the parameters
|
if self.find_key(next_prefix, exclude): # if found, untouch the parameters
|
||||||
continue
|
continue
|
||||||
else: # firstly freeze the non module params, then go deeper.
|
else:
|
||||||
params = non_module_param(module)
|
|
||||||
for n, p in params:
|
|
||||||
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
|
|
||||||
p.requires_grad = False
|
|
||||||
self._freeze_module_recursive(c, exclude=exclude, prefix=next_prefix)
|
self._freeze_module_recursive(c, exclude=exclude, prefix=next_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
@ -642,7 +643,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
||||||
"""
|
"""
|
||||||
def _caller(_org_func, includes, *args, **kwargs):
|
def _caller(_org_func, includes, *args, **kwargs):
|
||||||
state_dict = _org_func(*args, **kwargs)
|
state_dict = _org_func(*args, **kwargs)
|
||||||
keys = list(state_dict.keys())
|
keys = list(state_dict.keys())3
|
||||||
for n in keys:
|
for n in keys:
|
||||||
if n not in includes:
|
if n not in includes:
|
||||||
state_dict.pop(n)
|
state_dict.pop(n)
|
||||||
|
|
Loading…
Reference in New Issue