OpenDeltaMirror/examples/unittest/user_defined.py

183 lines
4.8 KiB
Python
Raw Normal View History

2022-10-21 14:46:25 +08:00
from argparse import ArgumentParser
import logging
import os
import sys
import torch.nn as nn
logger = logging.getLogger(__name__)
class UnitTest:
    """Smoke tests that attach OpenDelta delta models to backbone models.

    Each test prints the module structure through ``opendelta.Visualization``
    before and after a delta model is created, so the outcome is inspected by
    eye rather than asserted.

    Attributes:
        models: Backbone models given to the constructor. ``unitTest0`` reads
            ``models[0]``; ``unitTest1`` reads ``models[0]`` and ``models[1]``.
        Configs: Mapping from integer config id to the dict that is passed to
            ``AutoDeltaConfig.from_dict``.
    """

    # Delta types under test; the position in this tuple is the config id.
    _DELTA_TYPES = (
        "lora",
        "bitfit",
        "adapter",
        "compacter",
        "prefix",
        "soft_prompt",
        "low_rank_adapter",
    )

    def __init__(self, models):
        """Store the backbone models and build the config table.

        Args:
            models: Indexable collection of backbone models.
        """
        self.models = models
        self.Configs = {
            config_id: {"delta_type": delta_type}
            for config_id, delta_type in enumerate(self._DELTA_TYPES)
        }

    def get_delta_config(self, config_id):
        """Return the delta config dict registered under ``config_id``.

        Raises:
            KeyError: If ``config_id`` is not a key of ``Configs``.
        """
        return self.Configs[config_id]

    def unitTest0(self, delta_config_dict):
        """Create a delta model on ``models[0]`` and print its structure.

        Args:
            delta_config_dict: Dict accepted by ``AutoDeltaConfig.from_dict``.
        """
        # Imported lazily so this class can be imported without opendelta.
        from opendelta import AutoDeltaConfig, AutoDeltaModel, Visualization

        model = self.models[0]
        Visualization(model).structure_graph()

        delta_config = AutoDeltaConfig.from_dict(delta_config_dict)
        # The returned delta model is not needed here; the structure is
        # printed again afterwards (presumably the backbone is modified in
        # place -- confirm against the opendelta documentation).
        AutoDeltaModel.from_config(delta_config, backbone_model=model)
        Visualization(model).structure_graph()

    def unitTest1(self, delta_config_dict):
        """Run a create / save / reload round trip on a two-backbone wrapper.

        ``models[0]`` and ``models[1]`` are wrapped in a small ``nn.Module``,
        a delta model is created on the wrapper and saved to ``./tmp``, and
        the checkpoint is then loaded onto a freshly built wrapper around the
        same two backbone objects.

        Args:
            delta_config_dict: Dict accepted by ``AutoDeltaConfig.from_dict``.
        """
        # Imported lazily so this class can be imported without opendelta.
        from opendelta import AutoDeltaConfig, AutoDeltaModel, Visualization

        class Mymodel(nn.Module):
            """Container module exposing two backbones as ``a`` and ``b``."""

            def __init__(self, a, b):
                super().__init__()
                self.a = a
                self.b = b

        model = Mymodel(self.models[0], self.models[1])
        Visualization(model).structure_graph()

        delta_config = AutoDeltaConfig.from_dict(delta_config_dict)
        delta_model = AutoDeltaModel.from_config(delta_config, backbone_model=model)
        Visualization(model).structure_graph()

        # Saved once before and once after freezing everything except the
        # modules matched by 'deltas'; both writes target the same directory.
        delta_model.save_finetuned("./tmp")
        delta_model.freeze_module(exclude=['deltas'])
        delta_model.save_finetuned("./tmp")

        # Build a new wrapper around the same backbone objects and load the
        # saved checkpoint onto it.
        model = Mymodel(self.models[0], self.models[1])
        Visualization(model).structure_graph()
        AutoDeltaModel.from_finetuned("./tmp", backbone_model=model)
        Visualization(model).structure_graph()

    def unit_test(self, test_id, config_id):
        """Run the test selected by ``test_id`` with the config ``config_id``.

        Args:
            test_id: ``0`` runs ``unitTest0``; ``1`` runs ``unitTest1``.
            config_id: Key into ``Configs``.

        Raises:
            KeyError: If ``config_id`` is not a key of ``Configs``.
            ValueError: If ``test_id`` does not name a test. Without this
                check an unknown id would run nothing and look like success.
        """
        # Resolve the config first so a bad config_id is reported before a
        # bad test_id, as it was when the lookup was the first statement.
        delta_config_dict = self.get_delta_config(config_id)

        tests = {0: self.unitTest0, 1: self.unitTest1}
        test = tests.get(test_id)
        if test is None:
            raise ValueError(
                f"Unknown test_id {test_id!r}; expected one of {sorted(tests)}."
            )
        test(delta_config_dict)
from dataclasses import dataclass, field
@dataclass
class UnitTestArguments:
    """Command-line options selecting the unit test, delta config and backbone.

    The ``help`` entries in each field's metadata are what an argument parser
    built from this dataclass shows for the corresponding option.
    """

    # Key into UnitTest.Configs; selects the delta type to test.
    config_id: int = field(
        default=0,
        metadata={
            "help": "Delta config to test: 0=lora, 1=bitfit, 2=adapter, "
            "3=compacter, 4=prefix, 5=soft_prompt, 6=low_rank_adapter."
        },
    )
    # Selects which UnitTest method runs.
    test_id: int = field(
        default=0,
        metadata={
            "help": "Test to run: 0 for UnitTest.unitTest0, "
            "1 for UnitTest.unitTest1."
        },
    )
    # Passed to AutoModel.from_pretrained in main().
    model_name_or_path: str = field(
        default='bert-base-cased',
        metadata={"help": "tested: bert-base-cased, roberta-base, rinna/japanese-gpt2-small, t5-small, facebook/opt-125m"},
    )
from transformers import HfArgumentParser,TrainingArguments, AutoModel, GPT2Model
def main():
    """Parse arguments, configure logging, load a backbone and run one test.

    Arguments come from a JSON file when the only command-line argument ends
    in ``.json``; otherwise they are parsed from the command line. The model
    named by ``model_name_or_path`` is loaded, deep-copied once, and both
    copies are handed to ``UnitTest``.

    Raises:
        ValueError: If the output directory exists and is non-empty while
            ``do_train`` is set and ``overwrite_output_dir`` is not.
    """
    parser = HfArgumentParser((TrainingArguments, UnitTestArguments))

    cli_args = sys.argv[1:]
    if len(cli_args) == 1 and cli_args[0].endswith(".json"):
        json_path = os.path.abspath(cli_args[0])
        training_args, unit_test_args = parser.parse_json_file(json_file=json_path)
    else:
        training_args, unit_test_args = parser.parse_args_into_dataclasses()
    training_args: TrainingArguments

    # Evaluated in the same order as the checks are listed: existence,
    # contents, then the two training flags.
    has_leftovers = (
        os.path.exists(training_args.output_dir)
        and bool(os.listdir(training_args.output_dir))
    )
    if has_leftovers and training_args.do_train and not training_args.overwrite_output_dir:
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging: INFO when local_rank is -1 or 0, WARNING otherwise.
    if training_args.local_rank in [-1, 0]:
        log_level = logging.INFO
    else:
        log_level = logging.WARNING
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=log_level,
    )

    is_distributed = bool(training_args.local_rank != -1)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        is_distributed,
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    backbone = AutoModel.from_pretrained(unit_test_args.model_name_or_path)

    # NOTE(review): torch is not referenced below; the import is kept because
    # the original function performed it at this point.
    import torch
    import copy

    # Two independent copies of the backbone for the tests to use.
    backbone_copy = copy.deepcopy(backbone)
    runner = UnitTest([backbone, backbone_copy])
    runner.unit_test(unit_test_args.test_id, unit_test_args.config_id)
# Run the unit test only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()