regex docs

This commit is contained in:
Achazwl 2022-02-15 22:43:28 +08:00
parent b64cb5f145
commit 0c590e7965
4 changed files with 18 additions and 47 deletions

View File

@ -97,7 +97,15 @@ Handcrafting the full names of submodules can be frustrating. We made some simpl
2. Regular Expression. 2. Regular Expression.
<img src="../imgs/todo-icon.jpeg" height="30px"> Unit test and Doc later.
We also support regex end-matching rules.
We use a leading `[r]` followed by a regular expression to represent this rule, where `[r]` only serves to distinguish it from normal end-matching rules and has no other meaning.
Taking RoBERTa with a classifier on top as an example: It has two modules named `roberta.encoder.layer.0.attention.output.dense` and `roberta.encoder.layer.0.output.dense`, which both end with `output.dense`. To distinguish them:
- set `'[r](\d)+\.output.dense'` using regex rules, where `(\d)+` matches any layer number. This rule only matches `roberta.encoder.layer.0.output.dense`.
- set `'attention.output.dense'` using ordinary rules, which only matches `roberta.encoder.layer.0.attention.output.dense`.
3. Interactive Selection. 3. Interactive Selection.

View File

@ -8,7 +8,7 @@ Visualization(model).structure_graph()
from opendelta import LoraModel from opendelta import LoraModel
import re import re
delta_model = LoraModel(backbone_model=model, modified_modules=[re.compile('(\d)+\.output.dense'), 'attention.output.dense']) delta_model = LoraModel(backbone_model=model, modified_modules=['[r](\d)+\.output.dense', 'attention.output.dense'])
print("after modify") print("after modify")
delta_model.log() delta_model.log()
# This will visualize the backbone after modification and other information. # This will visualize the backbone after modification and other information.

View File

@ -262,14 +262,14 @@ class DeltaBase(nn.Module, SaveLoadMixin):
if is_leaf_module(module): if is_leaf_module(module):
for n, p in module.named_parameters(): for n, p in module.named_parameters():
if self.find_key(".".join([prefix,n]), exclude, only_tail=True): if self.find_key(".".join([prefix,n]), exclude):
continue continue
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))): if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
p.requires_grad = False p.requires_grad = False
return return
else: else:
for n, c in module.named_children(): for n, c in module.named_children():
if self.find_key(".".join([prefix,n]), exclude, only_tail=True): # if found, untouch the parameters if self.find_key(".".join([prefix,n]), exclude): # if found, untouch the parameters
continue continue
else: # firstly freeze the non module params, then go deeper. else: # firstly freeze the non module params, then go deeper.
params = non_module_param(module) params = non_module_param(module)
@ -282,14 +282,13 @@ class DeltaBase(nn.Module, SaveLoadMixin):
def find_key(self, key: str, target_list: List[Union[str, re.Pattern]], only_tail=True): def find_key(self, key: str, target_list: List[str]):
r"""Check whether any target string is in the key or in the tail of the key, i.e., r"""Check whether any target string is in the key or in the tail of the key, i.e.,
Args: Args:
key (:obj:`str`): The key (name) of a submodule in a ancestor module. key (:obj:`str`): The key (name) of a submodule in a ancestor module.
E.g., model.encoder.layer.0.attention E.g., model.encoder.layer.0.attention
target_list (List[Union[:obj:`str`, :obj:`re.Pattern`]]): The target list that we try to match ``key`` with. E.g., ["attention"] target_list (List[Union[:obj:`str`, :obj:`re.Pattern`]]): The target list that we try to match ``key`` with. E.g., ["attention"]
only_tail (:obj:`bool`): the element in the target_list should be in the tail of key
Returns: Returns:
:obj:`bool` True if the key matchs the target list. :obj:`bool` True if the key matchs the target list.
@ -299,10 +298,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
if not key: if not key:
return False return False
try: try:
if only_tail: return endswith_in(key, target_list)
return endswith_in(key, target_list)
else:
return substring_in(key, target_list)
except: except:
from IPython import embed from IPython import embed
embed(header = "find_key exception") embed(header = "find_key exception")

View File

@ -16,13 +16,9 @@ def is_child_key(str_a: str , list_b: List[str]):
""" """
return any(str_b in str_a and (str_b==str_a or str_a[len(str_b)]==".") for str_b in list_b) return any(str_b in str_a and (str_b==str_a or str_a[len(str_b)]==".") for str_b in list_b)
def endswith_in(str_a: str, list_b: List[Union[str, re.Pattern]]): def endswith_in(str_a: str, list_b: List[str]):
return endswith_in_normal(str_a, [b for b in list_b if isinstance(b, str)]) or \ return endswith_in_regex(str_a, [b[3:] for b in list_b if b.startswith("[r]")]) or \
endswith_in_regex(str_a, [b for b in list_b if isinstance(b, re.Pattern)]) endswith_in_normal(str_a, [b for b in list_b if not b.startswith("[r]")])
def substring_in(str_a: str, list_b: List[Union[str, re.Pattern]]):
    r"""Check whether any target in ``list_b`` matches a sub-key of ``str_a``.

    Plain-string targets are delegated to ``substring_in_normal`` and compiled
    ``re.Pattern`` targets to ``substring_in_regex``. Entries of any other type
    are ignored.

    Args:
        str_a (:obj:`str`): The key to be inspected.
        list_b (List[Union[:obj:`str`, :obj:`re.Pattern`]]): The targets to look for.

    Returns:
        The result of the plain-string check when it is truthy, otherwise the
        result of the regex check.
    """
    plain_targets = [target for target in list_b if isinstance(target, str)]
    if substring_in_normal(str_a, plain_targets):
        return True
    # Only build and consult the regex targets when no plain target matched.
    regex_targets = [target for target in list_b if isinstance(target, re.Pattern)]
    return substring_in_regex(str_a, regex_targets)
def endswith_in_normal(str_a: str , list_b: List[str]): def endswith_in_normal(str_a: str , list_b: List[str]):
r"""check whether ``str_a`` has a substring that is in list_b. r"""check whether ``str_a`` has a substring that is in list_b.
@ -32,20 +28,6 @@ def endswith_in_normal(str_a: str , list_b: List[str]):
""" """
return any(str_a.endswith(str_b) and (str_a==str_b or str_a[-len(str_b)-1] == ".") for str_b in list_b) return any(str_a.endswith(str_b) and (str_a==str_b or str_a[-len(str_b)-1] == ".") for str_b in list_b)
def substring_in_normal(str_a: str , list_b: List[str]):
    r"""Check whether any entry of ``list_b`` occurs in ``str_a`` as a
    contiguous run of whole dot-separated tokens.

    Both ``str_a`` and each candidate are split on ``"."`` and the resulting
    token lists are compared window by window, so a candidate only matches on
    token boundaries (``"c_"`` does not match inside ``"attn.c_attn"``).

    Args:
        str_a (:obj:`str`): The dotted key to be inspected,
            e.g. ``encoder.layer.0.attention``.
        list_b (List[:obj:`str`]): The dotted candidates, e.g. ``["layer.0"]``.

    Returns:
        :obj:`bool` True if at least one candidate matches, False otherwise
        (including when ``list_b`` is empty).
    """
    token_a = str_a.split(".")
    for str_b in list_b:
        token_b = str_b.split(".")
        width = len(token_b)
        # Compare the token lists themselves. Joining the tokens into one
        # string would erase the boundaries and let "a.bc" match "ab.c".
        # The range is empty when the candidate has more tokens than the key.
        if any(token_a[i:i + width] == token_b
               for i in range(len(token_a) - width + 1)):
            return True
    return False
def endswith_in_regex(str_a: str , list_b: List[str]): def endswith_in_regex(str_a: str , list_b: List[str]):
r"""check whether ``str_a`` has a substring that is in list_b. r"""check whether ``str_a`` has a substring that is in list_b.
@ -53,7 +35,7 @@ def endswith_in_regex(str_a: str , list_b: List[str]):
Returns: Returns:
""" """
for str_b in list_b: for str_b in list_b:
ret = re.search(str_b, str_a) ret = re.search(re.compile(str_b), str_a)
if ret is not None: if ret is not None:
b = ret.group() b = ret.group()
if ret.span()[1] == len(str_a) and (b == str_a or (str_a==b or str_a[-len(b)-1] == ".")): if ret.span()[1] == len(str_a) and (b == str_a or (str_a==b or str_a[-len(b)-1] == ".")):
@ -61,19 +43,4 @@ def endswith_in_regex(str_a: str , list_b: List[str]):
return True return True
return False return False
def substring_in_regex(str_a: str , list_b: List[str]):
    r"""Check whether any regular expression in ``list_b`` matches a full
    dot-delimited sub-key of ``str_a``.

    A match only counts when it starts at the beginning of ``str_a`` or right
    after a ``"."``, and ends at the end of ``str_a`` or right before a
    ``"."``. E.g. with ``str_a="attn.c_attn"``, the pattern ``attn`` matches
    (the leading ``attn`` token) while ``c_`` does not.

    Args:
        str_a (:obj:`str`): The dotted key to be inspected.
        list_b: The regular expressions to try; each entry is passed to
            ``re.finditer`` (pattern strings and compiled patterns both work).

    Returns:
        :obj:`bool` True if at least one pattern has a dot-bounded match,
        False otherwise (including when ``list_b`` is empty).
    """
    key_len = len(str_a)
    for str_b in list_b:
        # Inspect every occurrence, not only the first one: an early match in
        # the middle of a token must not hide a later, properly bounded match.
        for ret in re.finditer(str_b, str_a):
            start, end = ret.span()
            starts_on_boundary = start == 0 or str_a[start - 1] == "."
            ends_on_boundary = end == key_len or str_a[end] == "."
            if starts_on_boundary and ends_on_boundary:
                return True
    return False