regex docs
This commit is contained in:
parent
b64cb5f145
commit
0c590e7965
|
@ -97,7 +97,15 @@ Handcrafting the full names of submodules can be frustrating. We made some simpl
|
||||||
|
|
||||||
|
|
||||||
2. Regular Expression.
|
2. Regular Expression.
|
||||||
<img src="../imgs/todo-icon.jpeg" height="30px"> Unit test and Doc later.
|
|
||||||
|
We also support regex end-matching rules.
|
||||||
|
We use a beginning `[r]` followed by a regular expression to represent this rule, where `[r]` is used to distinguish it from normal end-matching rules and has no other meanings.
|
||||||
|
|
||||||
|
Taking RoBERTa with an classifier on top as an example: It has two modules named `roberta.encoder.layer.0.attention.output.dense` and `roberta.encoder.layer.0.output.dense`, which both end up with `output.dense`. To distinguish them:
|
||||||
|
|
||||||
|
- set `'[r](\d)+\.output.dense'` using regex rules, where `(\d)+` match any layer numbers. This rule only match `roberta.encoder.layer.0.output.dense`.
|
||||||
|
|
||||||
|
- set `'attention.output.dense'` using ordinary rules, which only match `roberta.encoder.layer.0.attention.output.dense`.
|
||||||
|
|
||||||
3. Interactive Selection.
|
3. Interactive Selection.
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ Visualization(model).structure_graph()
|
||||||
|
|
||||||
from opendelta import LoraModel
|
from opendelta import LoraModel
|
||||||
import re
|
import re
|
||||||
delta_model = LoraModel(backbone_model=model, modified_modules=[re.compile('(\d)+\.output.dense'), 'attention.output.dense'])
|
delta_model = LoraModel(backbone_model=model, modified_modules=['[r](\d)+\.output.dense', 'attention.output.dense'])
|
||||||
print("after modify")
|
print("after modify")
|
||||||
delta_model.log()
|
delta_model.log()
|
||||||
# This will visualize the backbone after modification and other information.
|
# This will visualize the backbone after modification and other information.
|
|
@ -262,14 +262,14 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
||||||
|
|
||||||
if is_leaf_module(module):
|
if is_leaf_module(module):
|
||||||
for n, p in module.named_parameters():
|
for n, p in module.named_parameters():
|
||||||
if self.find_key(".".join([prefix,n]), exclude, only_tail=True):
|
if self.find_key(".".join([prefix,n]), exclude):
|
||||||
continue
|
continue
|
||||||
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
|
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
|
||||||
p.requires_grad = False
|
p.requires_grad = False
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
for n, c in module.named_children():
|
for n, c in module.named_children():
|
||||||
if self.find_key(".".join([prefix,n]), exclude, only_tail=True): # if found, untouch the parameters
|
if self.find_key(".".join([prefix,n]), exclude): # if found, untouch the parameters
|
||||||
continue
|
continue
|
||||||
else: # firstly freeze the non module params, then go deeper.
|
else: # firstly freeze the non module params, then go deeper.
|
||||||
params = non_module_param(module)
|
params = non_module_param(module)
|
||||||
|
@ -282,14 +282,13 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def find_key(self, key: str, target_list: List[Union[str, re.Pattern]], only_tail=True):
|
def find_key(self, key: str, target_list: List[str]):
|
||||||
r"""Check whether any target string is in the key or in the tail of the key, i.e.,
|
r"""Check whether any target string is in the key or in the tail of the key, i.e.,
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key (:obj:`str`): The key (name) of a submodule in a ancestor module.
|
key (:obj:`str`): The key (name) of a submodule in a ancestor module.
|
||||||
E.g., model.encoder.layer.0.attention
|
E.g., model.encoder.layer.0.attention
|
||||||
target_list (List[Union[:obj:`str`, :obj:`re.Pattern`]]): The target list that we try to match ``key`` with. E.g., ["attention"]
|
target_list (List[Union[:obj:`str`, :obj:`re.Pattern`]]): The target list that we try to match ``key`` with. E.g., ["attention"]
|
||||||
only_tail (:obj:`bool`): the element in the target_list should be in the tail of key
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:obj:`bool` True if the key matchs the target list.
|
:obj:`bool` True if the key matchs the target list.
|
||||||
|
@ -299,10 +298,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
||||||
if not key:
|
if not key:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
if only_tail:
|
return endswith_in(key, target_list)
|
||||||
return endswith_in(key, target_list)
|
|
||||||
else:
|
|
||||||
return substring_in(key, target_list)
|
|
||||||
except:
|
except:
|
||||||
from IPython import embed
|
from IPython import embed
|
||||||
embed(header = "find_key exception")
|
embed(header = "find_key exception")
|
||||||
|
|
|
@ -16,13 +16,9 @@ def is_child_key(str_a: str , list_b: List[str]):
|
||||||
"""
|
"""
|
||||||
return any(str_b in str_a and (str_b==str_a or str_a[len(str_b)]==".") for str_b in list_b)
|
return any(str_b in str_a and (str_b==str_a or str_a[len(str_b)]==".") for str_b in list_b)
|
||||||
|
|
||||||
def endswith_in(str_a: str, list_b: List[Union[str, re.Pattern]]):
|
def endswith_in(str_a: str, list_b: List[str]):
|
||||||
return endswith_in_normal(str_a, [b for b in list_b if isinstance(b, str)]) or \
|
return endswith_in_regex(str_a, [b[3:] for b in list_b if b.startswith("[r]")]) or \
|
||||||
endswith_in_regex(str_a, [b for b in list_b if isinstance(b, re.Pattern)])
|
endswith_in_normal(str_a, [b for b in list_b if not b.startswith("[r]")])
|
||||||
|
|
||||||
def substring_in(str_a: str, list_b: List[Union[str, re.Pattern]]):
|
|
||||||
return substring_in_normal(str_a, [b for b in list_b if isinstance(b, str)]) or \
|
|
||||||
substring_in_regex(str_a, [b for b in list_b if isinstance(b, re.Pattern)])
|
|
||||||
|
|
||||||
def endswith_in_normal(str_a: str , list_b: List[str]):
|
def endswith_in_normal(str_a: str , list_b: List[str]):
|
||||||
r"""check whether ``str_a`` has a substring that is in list_b.
|
r"""check whether ``str_a`` has a substring that is in list_b.
|
||||||
|
@ -32,20 +28,6 @@ def endswith_in_normal(str_a: str , list_b: List[str]):
|
||||||
"""
|
"""
|
||||||
return any(str_a.endswith(str_b) and (str_a==str_b or str_a[-len(str_b)-1] == ".") for str_b in list_b)
|
return any(str_a.endswith(str_b) and (str_a==str_b or str_a[-len(str_b)-1] == ".") for str_b in list_b)
|
||||||
|
|
||||||
def substring_in_normal(str_a: str , list_b: List[str]):
|
|
||||||
r"""check whether ``str_a`` has a substring that is in list_b.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
Returns:
|
|
||||||
"""
|
|
||||||
token_a = str_a.split(".")
|
|
||||||
for str_b in list_b:
|
|
||||||
token_b = str_b.split(".")
|
|
||||||
for i in range(len(token_a)-len(token_b)+1):
|
|
||||||
if "".join(token_a[i:i+len(token_b)]) == "".join(token_b):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def endswith_in_regex(str_a: str , list_b: List[str]):
|
def endswith_in_regex(str_a: str , list_b: List[str]):
|
||||||
r"""check whether ``str_a`` has a substring that is in list_b.
|
r"""check whether ``str_a`` has a substring that is in list_b.
|
||||||
|
|
||||||
|
@ -53,7 +35,7 @@ def endswith_in_regex(str_a: str , list_b: List[str]):
|
||||||
Returns:
|
Returns:
|
||||||
"""
|
"""
|
||||||
for str_b in list_b:
|
for str_b in list_b:
|
||||||
ret = re.search(str_b, str_a)
|
ret = re.search(re.compile(str_b), str_a)
|
||||||
if ret is not None:
|
if ret is not None:
|
||||||
b = ret.group()
|
b = ret.group()
|
||||||
if ret.span()[1] == len(str_a) and (b == str_a or (str_a==b or str_a[-len(b)-1] == ".")):
|
if ret.span()[1] == len(str_a) and (b == str_a or (str_a==b or str_a[-len(b)-1] == ".")):
|
||||||
|
@ -61,19 +43,4 @@ def endswith_in_regex(str_a: str , list_b: List[str]):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def substring_in_regex(str_a: str , list_b: List[str]):
|
|
||||||
r"""check whether ``str_a`` has a substring that is in list_b.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
Returns:
|
|
||||||
"""
|
|
||||||
for str_b in list_b:
|
|
||||||
ret = re.search(str_b, str_a)
|
|
||||||
if ret is not None:
|
|
||||||
b = ret.group()
|
|
||||||
if (ret.span()[0] == 0 or str_a[ret.span()[0]-1] == ".") and \
|
|
||||||
(ret.span()[1] == len(str_a) or str_a[ret.span()[1]] == "."): #and b == str_a and (str_a==b or str_a[-len(b)-1] == "."):
|
|
||||||
# the latter is to judge whether it is a full sub key in the str_a, e.g. str_a=`attn.c_attn` and list_b=[`attn`] will given False
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
Loading…
Reference in New Issue