fix bmtrain tutorial

This commit is contained in:
Achazwl 2022-11-17 14:59:23 +00:00
parent ecbfc0f244
commit 4315b83c8e
1 changed files with 9 additions and 1 deletions

View File

@ -9,7 +9,15 @@ from sklearn.metrics import accuracy_score, recall_score, f1_score
import bmtrain as bmt import bmtrain as bmt
from model_center import get_args from model_center.arguments import add_model_config_args, add_training_args, argparse
def get_args():
parser = argparse.ArgumentParser()
parser = add_model_config_args(parser)
parser = add_training_args(parser)
group = parser.add_argument_group('delta', 'delta configurations')
group.add_argument('--delta-type', '--delta_type', type=str, help='delta type')
args = parser.parse_args()
return args
from model_center.model import Bert from model_center.model import Bert
from model_center.tokenizer import BertTokenizer from model_center.tokenizer import BertTokenizer
from model_center.dataset.bertdataset import DATASET from model_center.dataset.bertdataset import DATASET