diff --git a/examples/tutorial/2_with_bmtrain.py b/examples/tutorial/2_with_bmtrain.py index 072a0cb..45b0d99 100644 --- a/examples/tutorial/2_with_bmtrain.py +++ b/examples/tutorial/2_with_bmtrain.py @@ -61,7 +61,7 @@ def get_model(args): if args.delta_type == "lora": delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'], backend='bmt') elif args.delta_type == "bitfit": - delta_model = BitFitModel(backbone_model=model, modified_modules=['self_att', 'ffn', 'layernorm'], backend='bmt') #TODO: fix bug + delta_model = BitFitModel(backbone_model=model, modified_modules=['self_att', 'ffn', 'layernorm'], backend='bmt') elif args.delta_type == "adapter": delta_model = AdapterModel(backbone_model=model, modified_modules=['self_att', 'ffn'], backend='bmt') elif args.delta_type == "compacter":