regex docs

This commit is contained in:
Achazwl 2022-02-15 22:43:28 +08:00
parent b64cb5f145
commit 0c590e7965
4 changed files with 18 additions and 47 deletions

View File

@ -97,7 +97,15 @@ Handcrafting the full names of submodules can be frustrating. We made some simpl
2. Regular Expression.
<img src="../imgs/todo-icon.jpeg" height="30px"> Unit test and Doc later.
We also support regex end-matching rules.
We use a beginning `[r]` followed by a regular expression to represent this rule, where `[r]` is used to distinguish it from normal end-matching rules and has no other meanings.
Taking RoBERTa with an classifier on top as an example: It has two modules named `roberta.encoder.layer.0.attention.output.dense` and `roberta.encoder.layer.0.output.dense`, which both end up with `output.dense`. To distinguish them:
- set `'[r](\d)+\.output.dense'` using regex rules, where `(\d)+` match any layer numbers. This rule only match `roberta.encoder.layer.0.output.dense`.
- set `'attention.output.dense'` using ordinary rules, which only match `roberta.encoder.layer.0.attention.output.dense`.
3. Interactive Selection.

View File

@ -8,7 +8,7 @@ Visualization(model).structure_graph()
from opendelta import LoraModel
import re
delta_model = LoraModel(backbone_model=model, modified_modules=[re.compile('(\d)+\.output.dense'), 'attention.output.dense'])
delta_model = LoraModel(backbone_model=model, modified_modules=['[r](\d)+\.output.dense', 'attention.output.dense'])
print("after modify")
delta_model.log()
# This will visualize the backbone after modification and other information.

View File

@ -262,14 +262,14 @@ class DeltaBase(nn.Module, SaveLoadMixin):
if is_leaf_module(module):
for n, p in module.named_parameters():
if self.find_key(".".join([prefix,n]), exclude, only_tail=True):
if self.find_key(".".join([prefix,n]), exclude):
continue
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
p.requires_grad = False
return
else:
for n, c in module.named_children():
if self.find_key(".".join([prefix,n]), exclude, only_tail=True): # if found, untouch the parameters
if self.find_key(".".join([prefix,n]), exclude): # if found, untouch the parameters
continue
else: # firstly freeze the non module params, then go deeper.
params = non_module_param(module)
@ -282,14 +282,13 @@ class DeltaBase(nn.Module, SaveLoadMixin):
def find_key(self, key: str, target_list: List[Union[str, re.Pattern]], only_tail=True):
def find_key(self, key: str, target_list: List[str]):
r"""Check whether any target string is in the key or in the tail of the key, i.e.,
Args:
key (:obj:`str`): The key (name) of a submodule in a ancestor module.
E.g., model.encoder.layer.0.attention
target_list (List[Union[:obj:`str`, :obj:`re.Pattern`]]): The target list that we try to match ``key`` with. E.g., ["attention"]
only_tail (:obj:`bool`): the element in the target_list should be in the tail of key
Returns:
:obj:`bool` True if the key matchs the target list.
@ -299,10 +298,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
if not key:
return False
try:
if only_tail:
return endswith_in(key, target_list)
else:
return substring_in(key, target_list)
except:
from IPython import embed
embed(header = "find_key exception")

View File

@ -16,13 +16,9 @@ def is_child_key(str_a: str , list_b: List[str]):
"""
return any(str_b in str_a and (str_b==str_a or str_a[len(str_b)]==".") for str_b in list_b)
def endswith_in(str_a: str, list_b: List[Union[str, re.Pattern]]):
return endswith_in_normal(str_a, [b for b in list_b if isinstance(b, str)]) or \
endswith_in_regex(str_a, [b for b in list_b if isinstance(b, re.Pattern)])
def substring_in(str_a: str, list_b: List[Union[str, re.Pattern]]):
return substring_in_normal(str_a, [b for b in list_b if isinstance(b, str)]) or \
substring_in_regex(str_a, [b for b in list_b if isinstance(b, re.Pattern)])
def endswith_in(str_a: str, list_b: List[str]):
return endswith_in_regex(str_a, [b[3:] for b in list_b if b.startswith("[r]")]) or \
endswith_in_normal(str_a, [b for b in list_b if not b.startswith("[r]")])
def endswith_in_normal(str_a: str , list_b: List[str]):
r"""check whether ``str_a`` has a substring that is in list_b.
@ -32,20 +28,6 @@ def endswith_in_normal(str_a: str , list_b: List[str]):
"""
return any(str_a.endswith(str_b) and (str_a==str_b or str_a[-len(str_b)-1] == ".") for str_b in list_b)
def substring_in_normal(str_a: str , list_b: List[str]):
r"""check whether ``str_a`` has a substring that is in list_b.
Args:
Returns:
"""
token_a = str_a.split(".")
for str_b in list_b:
token_b = str_b.split(".")
for i in range(len(token_a)-len(token_b)+1):
if "".join(token_a[i:i+len(token_b)]) == "".join(token_b):
return True
return False
def endswith_in_regex(str_a: str , list_b: List[str]):
r"""check whether ``str_a`` has a substring that is in list_b.
@ -53,7 +35,7 @@ def endswith_in_regex(str_a: str , list_b: List[str]):
Returns:
"""
for str_b in list_b:
ret = re.search(str_b, str_a)
ret = re.search(re.compile(str_b), str_a)
if ret is not None:
b = ret.group()
if ret.span()[1] == len(str_a) and (b == str_a or (str_a==b or str_a[-len(b)-1] == ".")):
@ -61,19 +43,4 @@ def endswith_in_regex(str_a: str , list_b: List[str]):
return True
return False
def substring_in_regex(str_a: str , list_b: List[str]):
r"""check whether ``str_a`` has a substring that is in list_b.
Args:
Returns:
"""
for str_b in list_b:
ret = re.search(str_b, str_a)
if ret is not None:
b = ret.group()
if (ret.span()[0] == 0 or str_a[ret.span()[0]-1] == ".") and \
(ret.span()[1] == len(str_a) or str_a[ret.span()[1]] == "."): #and b == str_a and (str_a==b or str_a[-len(b)-1] == "."):
# the latter is to judge whether it is a full sub key in the str_a, e.g. str_a=`attn.c_attn` and list_b=[`attn`] will given False
return True
return False