Merge branch 'dev' into main
This commit is contained in:
commit
1ac42c2240
|
@ -17,4 +17,6 @@ _build/
|
|||
outputs/
|
||||
log.txt
|
||||
**/DeltaHub/
|
||||
*beans
|
||||
*beans/
|
||||
**/examples/*/configs/
|
||||
!examples/*/configs/config_gen.py
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
![version](https://img.shields.io/badge/version-0.0.1-blue)
|
||||
|
||||
|
||||
## Overview
|
||||
|
||||
OpenDelta is a toolkit for parameter efficient methods (we dub it as *delta tuning*), by which users could flexibly assign (or add) a small amount parameters to update while keeping the most paramters frozen. By using OpenDelta, users could easily implement prefix-tuning, adapters, Lora, or any other types of delta tuning with preferred PTMs.
|
||||
|
@ -30,6 +31,9 @@ OpenDelta is a toolkit for parameter efficient methods (we dub it as *delta tuni
|
|||
- **A demo of using Opendelta to modify the PLM (E.g., BART).**
|
||||
![How PLM changes using Delta-tuning](docs/source/imgs/demo.gif)
|
||||
|
||||
## Updates
|
||||
- 2022.02.16 support [regular expression](docs/source/notes/namebasedaddr#regexexpr) in named-based addressing.
|
||||
|
||||
## Installation
|
||||
create a virtualenv (optional)
|
||||
```shell
|
||||
|
@ -60,7 +64,7 @@ cd OpenDelta
|
|||
python setup.py install
|
||||
```
|
||||
|
||||
#### Option 2: If you want to modify the code, run
|
||||
#### Option 2: If you want to modify the code or keep the repo updated by git clone, run
|
||||
```shell
|
||||
python setup.py develop
|
||||
```
|
||||
|
|
|
@ -20,6 +20,7 @@ OpenDelta is a **Plug-and-play** Library of the parameter-efficient fine-tuning
|
|||
notes/overview.md
|
||||
notes/installation.md
|
||||
notes/usage.md
|
||||
notes/namebasedaddr.md
|
||||
notes/visualization.md
|
||||
notes/saveload.md
|
||||
|
||||
|
|
|
@ -15,141 +15,9 @@ Here is how we achieve it.
|
|||
|
||||
<img src="../imgs/pointing-right-finger.png" height="30px"> **Read through it will also help you to implement your own delta models in a sustainable way.**
|
||||
|
||||
(namebasedaddr)=
|
||||
|
||||
## 1. Name-based submodule addressing.
|
||||
We locate the submodules that we want to apply a delta layer via name-based addressing.
|
||||
|
||||
In pytorch fashion, a submodule can be accessed from a root model via 'dot' addressing. For example, we define a toy language model
|
||||
|
||||
```python
|
||||
import torch.nn as nn
|
||||
class MyNet1(nn.Module):
|
||||
def __init__(self,):
|
||||
super().__init__()
|
||||
self.name_a = nn.Linear(5,5)
|
||||
def forward(self, hiddens):
|
||||
return self.name_a(hiddens)
|
||||
|
||||
class MyNet2(nn.Module):
|
||||
def __init__(self,):
|
||||
super().__init__()
|
||||
self.embedding = nn.Embedding(10,5)
|
||||
self.name_b = nn.Sequential(MyNet1(), MyNet1())
|
||||
def forward(self, input_ids):
|
||||
hiddens = self.embedding(input_ids)
|
||||
return self.name_b(hiddens)
|
||||
|
||||
root = MyNet2()
|
||||
print(root.name_b[0].name_a)
|
||||
# Linear(in_features=5, out_features=5, bias=True)
|
||||
```
|
||||
|
||||
We can visualize the model (For details, see [visualization](visualization))
|
||||
|
||||
```python
|
||||
from opendelta import Visualization
|
||||
Visualization(root).structure_graph()
|
||||
```
|
||||
|
||||
````{collapse} <span style="color:rgb(141, 99, 224);font-weight:bold;font-style:italic">Click to view output</span>
|
||||
```{figure} ../imgs/name_based_addressing.png
|
||||
---
|
||||
width: 500px
|
||||
name: name_based_addressing
|
||||
---
|
||||
```
|
||||
````
|
||||
|
||||
In this case, string `"name_b.0.name_a"` will be the name to address the submodule from the root model.
|
||||
|
||||
Thus when applying a delta model to this toy net.
|
||||
|
||||
```
|
||||
from opendelta import AdapterModel
|
||||
AdapterModel(backbone_model=root, modified_modules=['name_b.0.name_a'])
|
||||
Visualization(root).structure_graph()
|
||||
```
|
||||
|
||||
````{collapse} <span style="color:rgb(141, 99, 224);font-weight:bold;font-style:italic">Click to view output</span>
|
||||
```{figure} ../imgs/toy-delta.png
|
||||
---
|
||||
width: 500px
|
||||
name: toy-delta
|
||||
---
|
||||
```
|
||||
````
|
||||
|
||||
### Makes addressing easier.
|
||||
|
||||
Handcrafting the full names of submodules can be frustrating. We made some simplifications
|
||||
|
||||
1. End-matching Rules.
|
||||
|
||||
OpenDelta will take every modules that
|
||||
**ends with** the provided name suffix as the modification [target module](target_module).
|
||||
:::{admonition} Example
|
||||
:class: tip
|
||||
Taking DistilBert with an classifier on top as an example:
|
||||
- set to `["0.attention.out_lin"]` will add delta modules to the attention output of distilbert's
|
||||
ayer 0, i.e., `distilbert.transformer.layer.0.attention.out_lin`.
|
||||
- set to `["attention.out_lin"]` will add the delta modules in every layer's `attention.out_lin`.
|
||||
:::
|
||||
|
||||
|
||||
2. Regular Expression.
|
||||
<img src="../imgs/todo-icon.jpeg" height="30px"> Unit test and Doc later.
|
||||
|
||||
3. Interactive Selection.
|
||||
|
||||
We provide a way to interact visually to select modules needed.
|
||||
|
||||
```python
|
||||
from transformers import BertForMaskedLM
|
||||
model = BertForMaskedLM.from_pretrained("bert-base-cased")
|
||||
# suppose we load BERT
|
||||
|
||||
from opendelta import LoraModel # use lora as an example, others are same
|
||||
delta_model = LoraModel(backbone_model=model, interactive_modify=True)
|
||||
```
|
||||
|
||||
by setting `interactive_modify`, a web server will be opened on local host, and the link will be print in the terminal.
|
||||
|
||||
```
|
||||
http://0.0.0.0:8888/
|
||||
```
|
||||
|
||||
If on your local machine, click to open the link for interactive modification.
|
||||
|
||||
If on remote host, you could use port mapping. For example, vscode terminal will automatically do port mapping for you, you can simply use `control/command + click` to open the link.
|
||||
|
||||
You can change the port number in case the default port number is occupied by other program by setting `interactive_modify=port_number`, in which port_number is an integer.
|
||||
|
||||
The web page looks like the following figure.
|
||||
|
||||
```{figure} ../imgs/interact.jpg
|
||||
---
|
||||
width: 500px
|
||||
name: interact web page
|
||||
---
|
||||
```
|
||||
|
||||
- By clicking on `[+]`/`[-]` to expand / collapse tree nodes.
|
||||
|
||||
- By clicking on text to select tree nodes, **yellow dotted** box indicates the selection.
|
||||
|
||||
- **Double** click on the pink `[*]` is an advanced option to unfold the repeated nodes. By default, modules with the same architecture are folded into one node and are marked in red, for example, the `BertLayer` of layers 0~11 in the above figure are in the same structure. Regular model changes will make the same changes to each layers.
|
||||
|
||||
- If you want to change only a few of them, first double-click on `[*]`, then select the parts you want in the unfolded structure.
|
||||
|
||||
- If you want to make the same change to all but a few of them, first select the common parts you want in the folded structure, then double-click on `[*]` to remove the few positions you don't need to change in the expanded structure.
|
||||
|
||||
Click `submit` button on the top-right corner, then go back to your terminal, you can get a list of name-based addresses printed in the terminal in the following format, and these modules are being "delta".
|
||||
|
||||
```
|
||||
modified_modules:
|
||||
[bert.encoder.layer.0.output.dense, ..., bert.encoder.layer.11.output.dense]
|
||||
```
|
||||
|
||||
See [name based addressing](namebasedaddr)
|
||||
## 2. Three basic submodule-level delta operations.
|
||||
We use three key functions to achieve the modifications to the backbone model outside the backbone model's code.
|
||||
|
||||
|
|
|
@ -0,0 +1,185 @@
|
|||
(namebasedaddr)=
|
||||
# Name-based Addressing
|
||||
|
||||
Named based addressing is what set OpenDelta apart from other packages and provide the possibility to be used to a broader range of models (even emerging ones).
|
||||
|
||||
|
||||
## Name of a submodule.
|
||||
We locate the submodules that we want to apply a delta layer via name-based addressing.
|
||||
|
||||
In pytorch fashion, a submodule can be accessed from a root model via 'dot' addressing. For example, we define a toy language model
|
||||
|
||||
```python
|
||||
import torch.nn as nn
|
||||
class MyNet1(nn.Module):
|
||||
def __init__(self,):
|
||||
super().__init__()
|
||||
self.name_a = nn.Linear(5,5)
|
||||
def forward(self, hiddens):
|
||||
return self.name_a(hiddens)
|
||||
|
||||
class MyNet2(nn.Module):
|
||||
def __init__(self,):
|
||||
super().__init__()
|
||||
self.embedding = nn.Embedding(10,5)
|
||||
self.name_b = nn.Sequential(MyNet1(), MyNet1())
|
||||
def forward(self, input_ids):
|
||||
hiddens = self.embedding(input_ids)
|
||||
return self.name_b(hiddens)
|
||||
|
||||
root = MyNet2()
|
||||
print(root.name_b[0].name_a)
|
||||
# Linear(in_features=5, out_features=5, bias=True)
|
||||
```
|
||||
|
||||
We can visualize the model (For details, see [visualization](visualization))
|
||||
|
||||
```python
|
||||
from opendelta import Visualization
|
||||
Visualization(root).structure_graph()
|
||||
```
|
||||
|
||||
````{collapse} <span style="color:rgb(141, 99, 224);font-weight:bold;font-style:italic">Click to view output</span>
|
||||
```{figure} ../imgs/name_based_addressing.png
|
||||
---
|
||||
width: 500px
|
||||
name: name_based_addressing
|
||||
---
|
||||
```
|
||||
````
|
||||
|
||||
In this case, string `"name_b.0.name_a"` will be the name to address the submodule from the root model.
|
||||
|
||||
Thus when applying a delta model to this toy net.
|
||||
|
||||
```
|
||||
from opendelta import AdapterModel
|
||||
AdapterModel(backbone_model=root, modified_modules=['name_b.0.name_a'])
|
||||
Visualization(root).structure_graph()
|
||||
```
|
||||
|
||||
````{collapse} <span style="color:rgb(141, 99, 224);font-weight:bold;font-style:italic">Click to view output</span>
|
||||
```{figure} ../imgs/toy-delta.png
|
||||
---
|
||||
width: 500px
|
||||
name: toy-delta
|
||||
---
|
||||
```
|
||||
````
|
||||
|
||||
|
||||
## Target modules.
|
||||
|
||||
For different delta methods, the operation for the modification target is different.
|
||||
- Adapter based method: Insert at the target module's forward function.
|
||||
- BitFit: Add bias to all allowed position of the target module.
|
||||
- Lora: Substitute the all the linear layers of the target module with [Lora.Linear](https://github.com/microsoft/LoRA/blob/main/loralib/layers.py#L92).
|
||||
- Prefix Tuning: the target module must be an attention module.
|
||||
|
||||
:::{admonition} Auto Searching
|
||||
:class: note
|
||||
We are working on unifying operations to automatically search within a given module for its submodules that can be applied using a specific delta method.
|
||||
:::
|
||||
|
||||
## Makes addressing easier.
|
||||
|
||||
Handcrafting the full names of submodules can be frustrating. We made some simplifications
|
||||
|
||||
1. **End-matching** Rules.
|
||||
|
||||
OpenDelta will take every modules that
|
||||
**ends with** the provided name suffix as the modification [target module](target_module).
|
||||
:::{admonition} Example
|
||||
:class: tip
|
||||
Taking DistilBert with an classifier on top as an example:
|
||||
- set to `["0.attention.out_lin"]` will add delta modules to the attention output of distilbert's
|
||||
ayer 0, i.e., `distilbert.transformer.layer.0.attention.out_lin`.
|
||||
- set to `["attention.out_lin"]` will add the delta modules in every layer's `attention.out_lin`.
|
||||
:::
|
||||
|
||||
|
||||
(regexexpr)=
|
||||
2. Regular Expression.
|
||||
|
||||
We also support regex end-matching rules.
|
||||
We use a beginning `[r]` followed by a regular expression to represent this rule, where `[r]` is used to distinguish it from normal string matching rules and has no other meanings.
|
||||
|
||||
Taking RoBERTa with an classifier on top as an example: It has two modules named `roberta.encoder.layer.0.attention.output.dense` and `roberta.encoder.layer.0.output.dense`, which both end up with `output.dense`. To distinguish them:
|
||||
|
||||
- set `'[r](\d)+\.output.dense'` using regex rules, where `(\d)+` match any layer numbers. This rule will match all `roberta.encoder.layer.$.output.dense`. where `$` represents all integer numbers, here in a 12-layer RoBERTa, it's 0-11.
|
||||
|
||||
- set `'[r][0-5]\.attention'` will match only the 0-5 layers' attention submodule.
|
||||
|
||||
- set `'attention.output.dense'` using ordinary rules, which only match `roberta.encoder.layer.0.attention.output.dense`.
|
||||
|
||||
:::{admonition} Regex in Json Configs
|
||||
:class: warning
|
||||
In json, you should write `"\\."` instead of `"\."` for a real dot due to json parsing rules. That is
|
||||
```json
|
||||
{
|
||||
...
|
||||
"modified_moduls": ['[r][0-5]\\.attention'],
|
||||
...
|
||||
}
|
||||
```
|
||||
:::
|
||||
|
||||
|
||||
3. Interactive Selection.
|
||||
|
||||
We provide a way to interact visually to select modules needed.
|
||||
|
||||
```python
|
||||
from transformers import BertForMaskedLM
|
||||
model = BertForMaskedLM.from_pretrained("bert-base-cased")
|
||||
# suppose we load BERT
|
||||
|
||||
from opendelta import LoraModel # use lora as an example, others are same
|
||||
delta_model = LoraModel(backbone_model=model, interactive_modify=True)
|
||||
```
|
||||
|
||||
by setting `interactive_modify`, a web server will be opened on local host, and the link will be print in the terminal.
|
||||
|
||||
```
|
||||
http://0.0.0.0:8888/
|
||||
```
|
||||
|
||||
If on your local machine, click to open the link for interactive modification.
|
||||
|
||||
If on remote host, you could use port mapping. For example, vscode terminal will automatically do port mapping for you, you can simply use `control/command + click` to open the link.
|
||||
|
||||
You can change the port number in case the default port number is occupied by other program by setting `interactive_modify=port_number`, in which port_number is an integer.
|
||||
|
||||
The web page looks like the following figure.
|
||||
|
||||
```{figure} ../imgs/interact.jpg
|
||||
---
|
||||
width: 500px
|
||||
name: interact web page
|
||||
---
|
||||
```
|
||||
|
||||
- By clicking on `[+]`/`[-]` to expand / collapse tree nodes.
|
||||
|
||||
- By clicking on text to select tree nodes, **yellow dotted** box indicates the selection.
|
||||
|
||||
- **Double** click on the pink `[*]` is an advanced option to unfold the repeated nodes. By default, modules with the same architecture are folded into one node and are marked in red, for example, the `BertLayer` of layers 0~11 in the above figure are in the same structure. Regular model changes will make the same changes to each layers.
|
||||
|
||||
- If you want to change only a few of them, first double-click on `[*]`, then select the parts you want in the unfolded structure.
|
||||
|
||||
- If you want to make the same change to all but a few of them, first select the common parts you want in the folded structure, then double-click on `[*]` to remove the few positions you don't need to change in the expanded structure.
|
||||
|
||||
Click `submit` button on the top-right corner, then go back to your terminal, you can get a list of name-based addresses printed in the terminal in the following format, and these modules are being "delta".
|
||||
|
||||
```
|
||||
modified_modules:
|
||||
[bert.encoder.layer.0.output.dense, ..., bert.encoder.layer.11.output.dense]
|
||||
```
|
||||
|
||||
|
||||
## Examples
|
||||
Nothing works better than a few lively examples.
|
||||
Comming Soon...
|
||||
|
||||
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
(unifyname)=
|
||||
|
||||
# Unified Name Convention
|
||||
# Common Structure Mapping
|
||||
|
||||
```{figure} ../imgs/transformers_structure.png
|
||||
:width: 400px
|
||||
|
@ -44,8 +44,8 @@ Visualize bert-base using a common structure name: The submodules that are not c
|
|||
:name: transformers_structure
|
||||
```
|
||||
|
||||
(commonstructure)=
|
||||
## Mappings
|
||||
(mappingexample)=
|
||||
## Example
|
||||
|
||||
Example of bert mapping: a tree with node names specified by <span style="font-weight:bold;color:rgb(55, 125, 34);" >"\_\_name\_\_"</span>
|
||||
```json
|
||||
|
|
|
@ -43,19 +43,12 @@ delta_model = AdapterModel(backbone_model=model, modified_modules=['fc2'], bottl
|
|||
delta_model.log() # This will visualize the backbone after modification and other information.
|
||||
```
|
||||
|
||||
(target_module)=
|
||||
:::{admonition} Target module
|
||||
:class: note
|
||||
For different delta methods, the operation for the modification target is different.
|
||||
- Adapter based method: Insert at the target module's forward function.
|
||||
- BitFit: Add bias to all allowed position of the target module.
|
||||
- Lora: Substitute the all the linear layers of the target module with [Lora.Linear](https://github.com/microsoft/LoRA/blob/main/loralib/layers.py#L92).
|
||||
:::
|
||||
|
||||
|
||||
### 2.2 Use the default modification.
|
||||
We also provide the default modifications of each delta methods for some commonly used PTMs (e.g., BERT, RoBERTA, DistilBERT, T5, GPT2), so the users don't need to specify the submodules to modify.
|
||||
|
||||
The default modifications is achieved by a [common_structure mapping](commonstructure), that is, use the mapping a name of a module to the it's name on a common transformer structure. <img src="../imgs/hint-icon-2.jpg" height="30px"> *For details about the default modification, see [Unified Name Convention](unifyname)*
|
||||
The default modifications is achieved by mapping a name of a submodule to it's name on a common transformer structure. <img src="../imgs/hint-icon-2.jpg" height="30px"> *For details about the common structure mapping, see [Common Structure Mapping](unifyname)*
|
||||
|
||||
|
||||
|
||||
|
@ -97,7 +90,7 @@ The performance may vary due to positional differences, but there is no academic
|
|||
|
||||
:::{admonition} Favored Configurations
|
||||
:class: tip
|
||||
Feel confused about the flexibility that OpenDelta brings? NO WORRY! We will add [Favored Configurations](favoredconfiguration) soon.
|
||||
Feel confused about the flexibility that OpenDelta brings? Currently you can refer to the papers for their configuration. And We will add [Favored Configurations](favoredconfiguration) soon.
|
||||
:::
|
||||
|
||||
## STEP 3: Freezing parameters
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
"per_device_eval_batch_size": 32,
|
||||
"per_device_train_batch_size": 32,
|
||||
"predict_with_generate": true,
|
||||
"push_to_hub": true,
|
||||
"push_to_hub": false,
|
||||
"save_steps": 200,
|
||||
"save_strategy": "steps",
|
||||
"save_total_limit": 1,
|
||||
|
@ -39,5 +39,9 @@
|
|||
"deltas",
|
||||
"classifier"
|
||||
],
|
||||
"modified_modules":[
|
||||
"[r][0-5]\\.attention"
|
||||
],
|
||||
"reparameterize": false,
|
||||
"warmup_steps": 0
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
from transformers import AutoModelForSequenceClassification
|
||||
model = AutoModelForSequenceClassification.from_pretrained("roberta-base")
|
||||
# suppose we load BART
|
||||
|
||||
from opendelta import Visualization
|
||||
print("before modify")
|
||||
Visualization(model).structure_graph()
|
||||
|
||||
from opendelta import LoraModel
|
||||
import re
|
||||
delta_model = LoraModel(backbone_model=model, modified_modules=['[r](\d)+\.output.dense', 'attention.output.dense'])
|
||||
print("after modify")
|
||||
delta_model.log()
|
||||
# This will visualize the backbone after modification and other information.
|
|
@ -262,14 +262,14 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
|
||||
if is_leaf_module(module):
|
||||
for n, p in module.named_parameters():
|
||||
if self.find_key(".".join([prefix,n]), exclude, only_tail=True):
|
||||
if self.find_key(".".join([prefix,n]), exclude):
|
||||
continue
|
||||
if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
|
||||
p.requires_grad = False
|
||||
return
|
||||
else:
|
||||
for n, c in module.named_children():
|
||||
if self.find_key(".".join([prefix,n]), exclude, only_tail=True): # if found, untouch the parameters
|
||||
if self.find_key(".".join([prefix,n]), exclude): # if found, untouch the parameters
|
||||
continue
|
||||
else: # firstly freeze the non module params, then go deeper.
|
||||
params = non_module_param(module)
|
||||
|
@ -282,14 +282,13 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
|
||||
|
||||
|
||||
def find_key(self, key: Union[str, re.Pattern], target_list: List[str], only_tail=True):
|
||||
def find_key(self, key: str, target_list: List[str]):
|
||||
r"""Check whether any target string is in the key or in the tail of the key, i.e.,
|
||||
|
||||
Args:
|
||||
key (Union[:obj:`str`, :obj:`re.Pattern`]): The key (name) of a submodule in a ancestor module.
|
||||
key (:obj:`str`): The key (name) of a submodule in a ancestor module.
|
||||
E.g., model.encoder.layer.0.attention
|
||||
target_list (List[:obj:`str`]): The target list that we try to match ``key`` with. E.g., ["attention"]
|
||||
only_tail (:obj:`bool`): the element in the target_list should be in the tail of key
|
||||
target_list (List[Union[:obj:`str`, :obj:`re.Pattern`]]): The target list that we try to match ``key`` with. E.g., ["attention"]
|
||||
|
||||
Returns:
|
||||
:obj:`bool` True if the key matchs the target list.
|
||||
|
@ -299,19 +298,9 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
if not key:
|
||||
return False
|
||||
try:
|
||||
if isinstance(key, re.Pattern): # TODO: unit test needed ERROR
|
||||
if only_tail:
|
||||
return endswith_in_regex(key, target_list)
|
||||
else:
|
||||
return substring_in_regex(key, target_list)
|
||||
else:
|
||||
if only_tail:
|
||||
return endswith_in(key, target_list)
|
||||
else:
|
||||
return substring_in(key, target_list)
|
||||
return endswith_in(key, target_list)
|
||||
except:
|
||||
from IPython import embed
|
||||
embed(header = "exception")
|
||||
raise RuntimeError("find_key exception")
|
||||
|
||||
def _pseudo_data_to_instantiate(self, module: Optional[nn.Module]=None):
|
||||
r"""Create a pseudo_data into the module to know the dimemsion of each tensor in the computation graph.
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List
|
||||
from typing import List, Union
|
||||
import re
|
||||
def superstring_in(str_a: str , list_b: List[str]):
|
||||
r"""check whether there is any string in list b containing str_a.
|
||||
|
@ -16,7 +16,11 @@ def is_child_key(str_a: str , list_b: List[str]):
|
|||
"""
|
||||
return any(str_b in str_a and (str_b==str_a or str_a[len(str_b)]==".") for str_b in list_b)
|
||||
|
||||
def endswith_in(str_a: str , list_b: List[str]):
|
||||
def endswith_in(str_a: str, list_b: List[str]):
|
||||
return endswith_in_regex(str_a, [b[3:] for b in list_b if b.startswith("[r]")]) or \
|
||||
endswith_in_normal(str_a, [b for b in list_b if not b.startswith("[r]")])
|
||||
|
||||
def endswith_in_normal(str_a: str , list_b: List[str]):
|
||||
r"""check whether ``str_a`` has a substring that is in list_b.
|
||||
|
||||
Args:
|
||||
|
@ -24,20 +28,6 @@ def endswith_in(str_a: str , list_b: List[str]):
|
|||
"""
|
||||
return any(str_a.endswith(str_b) and (str_a==str_b or str_a[-len(str_b)-1] == ".") for str_b in list_b)
|
||||
|
||||
def substring_in(str_a: str , list_b: List[str]):
|
||||
r"""check whether ``str_a`` has a substring that is in list_b.
|
||||
|
||||
Args:
|
||||
Returns:
|
||||
"""
|
||||
token_a = str_a.split(".")
|
||||
for str_b in list_b:
|
||||
token_b = str_b.split(".")
|
||||
for i in range(len(token_a)-len(token_b)+1):
|
||||
if "".join(token_a[i:i+len(token_b)]) == "".join(token_b):
|
||||
return True
|
||||
return False
|
||||
|
||||
def endswith_in_regex(str_a: str , list_b: List[str]):
|
||||
r"""check whether ``str_a`` has a substring that is in list_b.
|
||||
|
||||
|
@ -45,7 +35,7 @@ def endswith_in_regex(str_a: str , list_b: List[str]):
|
|||
Returns:
|
||||
"""
|
||||
for str_b in list_b:
|
||||
ret = re.search(str_b, str_a)
|
||||
ret = re.search(re.compile(str_b), str_a)
|
||||
if ret is not None:
|
||||
b = ret.group()
|
||||
if ret.span()[1] == len(str_a) and (b == str_a or (str_a==b or str_a[-len(b)-1] == ".")):
|
||||
|
@ -53,19 +43,4 @@ def endswith_in_regex(str_a: str , list_b: List[str]):
|
|||
return True
|
||||
return False
|
||||
|
||||
def substring_in_regex(str_a: str , list_b: List[str]):
|
||||
r"""check whether ``str_a`` has a substring that is in list_b.
|
||||
|
||||
Args:
|
||||
Returns:
|
||||
"""
|
||||
for str_b in list_b:
|
||||
ret = re.search(str_b, str_a)
|
||||
if ret is not None:
|
||||
b = ret.group()
|
||||
if (ret.span()[0] == 0 or str_a[ret.span()[0]-1] == ".") and \
|
||||
(ret.span()[1] == len(str_a) or str_a[ret.span()[1]] == "."): #and b == str_a and (str_a==b or str_a[-len(b)-1] == "."):
|
||||
# the latter is to judge whether it is a full sub key in the str_a, e.g. str_a=`attn.c_attn` and list_b=[`attn`] will given False
|
||||
return True
|
||||
return False
|
||||
|
Loading…
Reference in New Issue