From 0c590e7965892585e63ebc800ed56867b926f20d Mon Sep 17 00:00:00 2001 From: Achazwl <323163497@qq.com> Date: Tue, 15 Feb 2022 22:43:28 +0800 Subject: [PATCH] regex docs --- docs/source/notes/keyfeature.md | 10 +++++- examples/tutorial/0_regex.py | 2 +- opendelta/basemodel.py | 12 +++---- opendelta/utils/name_based_addressing.py | 41 +++--------------------- 4 files changed, 18 insertions(+), 47 deletions(-) diff --git a/docs/source/notes/keyfeature.md b/docs/source/notes/keyfeature.md index b79367a..1e4fb9d 100644 --- a/docs/source/notes/keyfeature.md +++ b/docs/source/notes/keyfeature.md @@ -97,7 +97,15 @@ Handcrafting the full names of submodules can be frustrating. We made some simpl 2. Regular Expression. - Unit test and Doc later. + + We also support regex end-matching rules. + We use a beginning `[r]` followed by a regular expression to represent this rule, where `[r]` is used to distinguish it from normal end-matching rules and has no other meanings. + + Taking RoBERTa with an classifier on top as an example: It has two modules named `roberta.encoder.layer.0.attention.output.dense` and `roberta.encoder.layer.0.output.dense`, which both end up with `output.dense`. To distinguish them: + + - set `'[r](\d)+\.output.dense'` using regex rules, where `(\d)+` match any layer numbers. This rule only match `roberta.encoder.layer.0.output.dense`. + + - set `'attention.output.dense'` using ordinary rules, which only match `roberta.encoder.layer.0.attention.output.dense`. 3. Interactive Selection. diff --git a/examples/tutorial/0_regex.py b/examples/tutorial/0_regex.py index b58fc7f..642b920 100644 --- a/examples/tutorial/0_regex.py +++ b/examples/tutorial/0_regex.py @@ -8,7 +8,7 @@ Visualization(model).structure_graph() from opendelta import LoraModel import re -delta_model = LoraModel(backbone_model=model, modified_modules=[re.compile('(\d)+\.output.dense'), 'attention.output.dense']) +delta_model = LoraModel(backbone_model=model, modified_modules=['[r](\d)+\.output.dense', 'attention.output.dense']) print("after modify") delta_model.log() # This will visualize the backbone after modification and other information. \ No newline at end of file diff --git a/opendelta/basemodel.py b/opendelta/basemodel.py index e32301a..a10ca16 100644 --- a/opendelta/basemodel.py +++ b/opendelta/basemodel.py @@ -262,14 +262,14 @@ class DeltaBase(nn.Module, SaveLoadMixin): if is_leaf_module(module): for n, p in module.named_parameters(): - if self.find_key(".".join([prefix,n]), exclude, only_tail=True): + if self.find_key(".".join([prefix,n]), exclude): continue if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))): p.requires_grad = False return else: for n, c in module.named_children(): - if self.find_key(".".join([prefix,n]), exclude, only_tail=True): # if found, untouch the parameters + if self.find_key(".".join([prefix,n]), exclude): # if found, untouch the parameters continue else: # firstly freeze the non module params, then go deeper. params = non_module_param(module) @@ -282,14 +282,13 @@ class DeltaBase(nn.Module, SaveLoadMixin): - def find_key(self, key: str, target_list: List[Union[str, re.Pattern]], only_tail=True): + def find_key(self, key: str, target_list: List[str]): r"""Check whether any target string is in the key or in the tail of the key, i.e., Args: key (:obj:`str`): The key (name) of a submodule in a ancestor module. E.g., model.encoder.layer.0.attention target_list (List[Union[:obj:`str`, :obj:`re.Pattern`]]): The target list that we try to match ``key`` with. E.g., ["attention"] - only_tail (:obj:`bool`): the element in the target_list should be in the tail of key Returns: :obj:`bool` True if the key matchs the target list. @@ -299,10 +298,7 @@ class DeltaBase(nn.Module, SaveLoadMixin): if not key: return False try: - if only_tail: - return endswith_in(key, target_list) - else: - return substring_in(key, target_list) + return endswith_in(key, target_list) except: from IPython import embed embed(header = "find_key exception") diff --git a/opendelta/utils/name_based_addressing.py b/opendelta/utils/name_based_addressing.py index 689ca3e..e0d569f 100644 --- a/opendelta/utils/name_based_addressing.py +++ b/opendelta/utils/name_based_addressing.py @@ -16,13 +16,9 @@ def is_child_key(str_a: str , list_b: List[str]): """ return any(str_b in str_a and (str_b==str_a or str_a[len(str_b)]==".") for str_b in list_b) -def endswith_in(str_a: str, list_b: List[Union[str, re.Pattern]]): - return endswith_in_normal(str_a, [b for b in list_b if isinstance(b, str)]) or \ - endswith_in_regex(str_a, [b for b in list_b if isinstance(b, re.Pattern)]) - -def substring_in(str_a: str, list_b: List[Union[str, re.Pattern]]): - return substring_in_normal(str_a, [b for b in list_b if isinstance(b, str)]) or \ - substring_in_regex(str_a, [b for b in list_b if isinstance(b, re.Pattern)]) +def endswith_in(str_a: str, list_b: List[str]): + return endswith_in_regex(str_a, [b[3:] for b in list_b if b.startswith("[r]")]) or \ + endswith_in_normal(str_a, [b for b in list_b if not b.startswith("[r]")]) def endswith_in_normal(str_a: str , list_b: List[str]): r"""check whether ``str_a`` has a substring that is in list_b. @@ -32,20 +28,6 @@ def endswith_in_normal(str_a: str , list_b: List[str]): """ return any(str_a.endswith(str_b) and (str_a==str_b or str_a[-len(str_b)-1] == ".") for str_b in list_b) -def substring_in_normal(str_a: str , list_b: List[str]): - r"""check whether ``str_a`` has a substring that is in list_b. - - Args: - Returns: - """ - token_a = str_a.split(".") - for str_b in list_b: - token_b = str_b.split(".") - for i in range(len(token_a)-len(token_b)+1): - if "".join(token_a[i:i+len(token_b)]) == "".join(token_b): - return True - return False - def endswith_in_regex(str_a: str , list_b: List[str]): r"""check whether ``str_a`` has a substring that is in list_b. @@ -53,7 +35,7 @@ def endswith_in_regex(str_a: str , list_b: List[str]): Returns: """ for str_b in list_b: - ret = re.search(str_b, str_a) + ret = re.search(re.compile(str_b), str_a) if ret is not None: b = ret.group() if ret.span()[1] == len(str_a) and (b == str_a or (str_a==b or str_a[-len(b)-1] == ".")): @@ -61,19 +43,4 @@ def endswith_in_regex(str_a: str , list_b: List[str]): return True return False -def substring_in_regex(str_a: str , list_b: List[str]): - r"""check whether ``str_a`` has a substring that is in list_b. - - Args: - Returns: - """ - for str_b in list_b: - ret = re.search(str_b, str_a) - if ret is not None: - b = ret.group() - if (ret.span()[0] == 0 or str_a[ret.span()[0]-1] == ".") and \ - (ret.span()[1] == len(str_a) or str_a[ret.span()[1]] == "."): #and b == str_a and (str_a==b or str_a[-len(b)-1] == "."): - # the latter is to judge whether it is a full sub key in the str_a, e.g. str_a=`attn.c_attn` and list_b=[`attn`] will given False - return True - return False \ No newline at end of file