fix bmtrain tutorial

This commit is contained in:
Achazwl 2022-11-21 03:17:33 +00:00
parent 62e31a69ff
commit ef2189ddb1
1 changed files with 5 additions and 1 deletions

View File

@ -1,5 +1,8 @@
# adapted from https://github.com/OpenBMB/ModelCenter/blob/main/examples/bert/finetune_bert.py
# For ModelCenter, `pip install model_center >= 1.0.0`
# For BMTrain, `git clone https://github.com/OpenBMB/BMTrain.git` and `python3 setup.py install` locally, as it has not been released currently.
import time
import os
@ -56,7 +59,8 @@ def get_model(args):
"WiC" : 2,
}
model = BertModel(args, num_types[args.dataset_name])
# od.Visualization(model).structure_graph()
from bigmodelvis import Visualization
Visualization(model).structure_graph()
if args.delta_type == "lora":
delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'], backend='bmt')