regex docs
This commit is contained in:
parent
b64cb5f145
commit
0c590e7965
|
@ -97,7 +97,15 @@ Handcrafting the full names of submodules can be frustrating. We made some simpl
|
|||
|
||||
|
||||
2. Regular Expression.
|
||||
<img src="../imgs/todo-icon.jpeg" height="30px"> Unit test and Doc later.
|
||||
|
||||
We also support regex end-matching rules.
|
||||
We use a beginning `[r]` followed by a regular expression to represent this rule, where `[r]` is used to distinguish it from normal end-matching rules and has no other meanings.
|
||||
|
||||
Taking RoBERTa with an classifier on top as an example: It has two modules named `roberta.encoder.layer.0.attention.output.dense` and `roberta.encoder.layer.0.output.dense`, which both end up with `output.dense`. To distinguish them:
|
||||
|
||||
- set `'[r](\d)+\.output.dense'` using regex rules, where `(\d)+` match any layer numbers. This rule only match `roberta.encoder.layer.0.output.dense`.
|
||||
|
||||
- set `'attention.output.dense'` using ordinary rules, which only match `roberta.encoder.layer.0.attention.output.dense`.
|
||||
|
||||
3. Interactive Selection.
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ Visualization(model).structure_graph()
|
|||
|
||||
from opendelta import LoraModel
|
||||
import re
|
||||
delta_model = LoraModel(backbone_model=model, modified_modules=[re.compile('(\d)+\.output.dense'), 'attention.output.dense'])
|
||||
delta_model = LoraModel(backbone_model=model, modified_modules=['[r](\d)+\.output.dense', 'attention.output.dense'])
|
||||
print("after modify")
|
||||
delta_model.log()
|
||||
# This will visualize the backbone after modification and other information.
|
|
@ -262,14 +262,14 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
|
||||
if is_leaf_module(module):
|
||||
for n, p in module.named_parameters():
|
||||
if self.find_key(".".join([prefix,n]), exclude, only_tail=True):
|
||||
if self.find_key(".".join([prefix,n]), exclude):
|
||||
continue
|
||||
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
|
||||
p.requires_grad = False
|
||||
return
|
||||
else:
|
||||
for n, c in module.named_children():
|
||||
if self.find_key(".".join([prefix,n]), exclude, only_tail=True): # if found, untouch the parameters
|
||||
if self.find_key(".".join([prefix,n]), exclude): # if found, untouch the parameters
|
||||
continue
|
||||
else: # firstly freeze the non module params, then go deeper.
|
||||
params = non_module_param(module)
|
||||
|
@ -282,14 +282,13 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
|
||||
|
||||
|
||||
def find_key(self, key: str, target_list: List[Union[str, re.Pattern]], only_tail=True):
|
||||
def find_key(self, key: str, target_list: List[str]):
|
||||
r"""Check whether any target string is in the key or in the tail of the key, i.e.,
|
||||
|
||||
Args:
|
||||
key (:obj:`str`): The key (name) of a submodule in a ancestor module.
|
||||
E.g., model.encoder.layer.0.attention
|
||||
target_list (List[Union[:obj:`str`, :obj:`re.Pattern`]]): The target list that we try to match ``key`` with. E.g., ["attention"]
|
||||
only_tail (:obj:`bool`): the element in the target_list should be in the tail of key
|
||||
|
||||
Returns:
|
||||
:obj:`bool` True if the key matchs the target list.
|
||||
|
@ -299,10 +298,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
if not key:
|
||||
return False
|
||||
try:
|
||||
if only_tail:
|
||||
return endswith_in(key, target_list)
|
||||
else:
|
||||
return substring_in(key, target_list)
|
||||
return endswith_in(key, target_list)
|
||||
except:
|
||||
from IPython import embed
|
||||
embed(header = "find_key exception")
|
||||
|
|
|
@ -16,13 +16,9 @@ def is_child_key(str_a: str , list_b: List[str]):
|
|||
"""
|
||||
return any(str_b in str_a and (str_b==str_a or str_a[len(str_b)]==".") for str_b in list_b)
|
||||
|
||||
def endswith_in(str_a: str, list_b: List[Union[str, re.Pattern]]):
|
||||
return endswith_in_normal(str_a, [b for b in list_b if isinstance(b, str)]) or \
|
||||
endswith_in_regex(str_a, [b for b in list_b if isinstance(b, re.Pattern)])
|
||||
|
||||
def substring_in(str_a: str, list_b: List[Union[str, re.Pattern]]):
|
||||
return substring_in_normal(str_a, [b for b in list_b if isinstance(b, str)]) or \
|
||||
substring_in_regex(str_a, [b for b in list_b if isinstance(b, re.Pattern)])
|
||||
def endswith_in(str_a: str, list_b: List[str]):
|
||||
return endswith_in_regex(str_a, [b[3:] for b in list_b if b.startswith("[r]")]) or \
|
||||
endswith_in_normal(str_a, [b for b in list_b if not b.startswith("[r]")])
|
||||
|
||||
def endswith_in_normal(str_a: str , list_b: List[str]):
|
||||
r"""check whether ``str_a`` has a substring that is in list_b.
|
||||
|
@ -32,20 +28,6 @@ def endswith_in_normal(str_a: str , list_b: List[str]):
|
|||
"""
|
||||
return any(str_a.endswith(str_b) and (str_a==str_b or str_a[-len(str_b)-1] == ".") for str_b in list_b)
|
||||
|
||||
def substring_in_normal(str_a: str , list_b: List[str]):
|
||||
r"""check whether ``str_a`` has a substring that is in list_b.
|
||||
|
||||
Args:
|
||||
Returns:
|
||||
"""
|
||||
token_a = str_a.split(".")
|
||||
for str_b in list_b:
|
||||
token_b = str_b.split(".")
|
||||
for i in range(len(token_a)-len(token_b)+1):
|
||||
if "".join(token_a[i:i+len(token_b)]) == "".join(token_b):
|
||||
return True
|
||||
return False
|
||||
|
||||
def endswith_in_regex(str_a: str , list_b: List[str]):
|
||||
r"""check whether ``str_a`` has a substring that is in list_b.
|
||||
|
||||
|
@ -53,7 +35,7 @@ def endswith_in_regex(str_a: str , list_b: List[str]):
|
|||
Returns:
|
||||
"""
|
||||
for str_b in list_b:
|
||||
ret = re.search(str_b, str_a)
|
||||
ret = re.search(re.compile(str_b), str_a)
|
||||
if ret is not None:
|
||||
b = ret.group()
|
||||
if ret.span()[1] == len(str_a) and (b == str_a or (str_a==b or str_a[-len(b)-1] == ".")):
|
||||
|
@ -61,19 +43,4 @@ def endswith_in_regex(str_a: str , list_b: List[str]):
|
|||
return True
|
||||
return False
|
||||
|
||||
def substring_in_regex(str_a: str , list_b: List[str]):
|
||||
r"""check whether ``str_a`` has a substring that is in list_b.
|
||||
|
||||
Args:
|
||||
Returns:
|
||||
"""
|
||||
for str_b in list_b:
|
||||
ret = re.search(str_b, str_a)
|
||||
if ret is not None:
|
||||
b = ret.group()
|
||||
if (ret.span()[0] == 0 or str_a[ret.span()[0]-1] == ".") and \
|
||||
(ret.span()[1] == len(str_a) or str_a[ret.span()[1]] == "."): #and b == str_a and (str_a==b or str_a[-len(b)-1] == "."):
|
||||
# the latter is to judge whether it is a full sub key in the str_a, e.g. str_a=`attn.c_attn` and list_b=[`attn`] will given False
|
||||
return True
|
||||
return False
|
||||
|
Loading…
Reference in New Issue