fix bmtrain tutorial
This commit is contained in:
parent
ecbfc0f244
commit
4315b83c8e
|
@ -9,7 +9,15 @@ from sklearn.metrics import accuracy_score, recall_score, f1_score
|
||||||
|
|
||||||
import bmtrain as bmt
|
import bmtrain as bmt
|
||||||
|
|
||||||
from model_center import get_args
|
from model_center.arguments import add_model_config_args, add_training_args, argparse
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = add_model_config_args(parser)
|
||||||
|
parser = add_training_args(parser)
|
||||||
|
group = parser.add_argument_group('delta', 'delta configurations')
|
||||||
|
group.add_argument('--delta-type', '--delta_type', type=str, help='delta type')
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
from model_center.model import Bert
|
from model_center.model import Bert
|
||||||
from model_center.tokenizer import BertTokenizer
|
from model_center.tokenizer import BertTokenizer
|
||||||
from model_center.dataset.bertdataset import DATASET
|
from model_center.dataset.bertdataset import DATASET
|
||||||
|
|
Loading…
Reference in New Issue