fix tutorial visualize
This commit is contained in:
parent
62e31a69ff
commit
798312d72d
|
@ -34,7 +34,7 @@ Then use the config to add a delta model to the backbone model
|
||||||
delta_model = AutoDeltaModel.from_config(delta_config, backbone_model=backbone_model)
|
delta_model = AutoDeltaModel.from_config(delta_config, backbone_model=backbone_model)
|
||||||
|
|
||||||
# now visualize the modified backbone_model
|
# now visualize the modified backbone_model
|
||||||
from opendelta import Visualization
|
from bigmodelvis import Visualization
|
||||||
Visualizaiton(backbone_model).structure_graph()
|
Visualizaiton(backbone_model).structure_graph()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ Delta tuning's core change in the structure of the base model is to decorate (mo
|
||||||
We should first know the name of the feedforward layer in the BART model by visualization. <img src="../imgs/hint-icon-2.jpg" height="30px"> *For more about visualization, see [Visualization](visualization).*
|
We should first know the name of the feedforward layer in the BART model by visualization. <img src="../imgs/hint-icon-2.jpg" height="30px"> *For more about visualization, see [Visualization](visualization).*
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from opendelta import Visualization
|
from bigmodelvis import Visualization
|
||||||
Visualization(model).structure_graph()
|
Visualization(model).structure_graph()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -30,12 +30,12 @@ name: raw_print
|
||||||
The original presentation of models is **not tailored for repeated structures, big models, or parameters-centric tasks**.
|
The original presentation of models is **not tailored for repeated structures, big models, or parameters-centric tasks**.
|
||||||
|
|
||||||
|
|
||||||
## Using visualization from opendelta.
|
## Using visualization from bigmodelvis.
|
||||||
|
|
||||||
First let's visualize all the parameters in the bert model. As we can see, structure inside a bert model, and the all the paramters location of the model are neatly represented in tree structure. (See [color scheme](color_schema) for the colors)
|
First let's visualize all the parameters in the bert model. As we can see, structure inside a bert model, and the all the paramters location of the model are neatly represented in tree structure. (See [color scheme](color_schema) for the colors)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from opendelta import Visualization
|
from bigmodelvis import Visualization
|
||||||
model_vis = Visualization(backbone_model)
|
model_vis = Visualization(backbone_model)
|
||||||
model_vis.structure_graph()
|
model_vis.structure_graph()
|
||||||
```
|
```
|
||||||
|
|
|
@ -35,7 +35,7 @@ print(root.name_b[0].name_a)
|
||||||
We can visualize the model (For details, see [visualization](visualization))
|
We can visualize the model (For details, see [visualization](visualization))
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from opendelta import Visualization
|
from bigmodelvis import Visualization
|
||||||
Visualization(root).structure_graph()
|
Visualization(root).structure_graph()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ from model_center.dataset import DistributedDataLoader
|
||||||
import opendelta as od
|
import opendelta as od
|
||||||
from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel, ParallelAdapterModel
|
from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel, ParallelAdapterModel
|
||||||
from opendelta.utils.inspect import inspect_optimizer_statistics
|
from opendelta.utils.inspect import inspect_optimizer_statistics
|
||||||
|
from bigmodelvis import Visualization
|
||||||
print("before modify")
|
print("before modify")
|
||||||
|
|
||||||
class BertModel(torch.nn.Module):
|
class BertModel(torch.nn.Module):
|
||||||
|
@ -56,7 +57,7 @@ def get_model(args):
|
||||||
"WiC" : 2,
|
"WiC" : 2,
|
||||||
}
|
}
|
||||||
model = BertModel(args, num_types[args.dataset_name])
|
model = BertModel(args, num_types[args.dataset_name])
|
||||||
# od.Visualization(model).structure_graph()
|
Visualization(model).structure_graph()
|
||||||
|
|
||||||
if args.delta_type == "lora":
|
if args.delta_type == "lora":
|
||||||
delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'], backend='bmt')
|
delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'], backend='bmt')
|
||||||
|
|
Loading…
Reference in New Issue