diff --git a/.gitignore b/.gitignore index 22d883c..1ca34e3 100644 --- a/.gitignore +++ b/.gitignore @@ -54,8 +54,13 @@ t.sh **/delta_checkpoints/ **/outputs/ +dist/* **/unittest/** !unittest/**.py !unittest/**.sh +**/tutorial/** +!tutorial/**.py +!tutorial/**.sh +!tutorial/**.md diff --git a/docs/source/conf.py b/docs/source/conf.py index 1be8e51..8a94518 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -31,8 +31,8 @@ copyright = '{}, {}, Licenced under the Apache License, Version 2.0'.format(date # The full version, including alpha/beta/rc tags -release = '0.3.1' -version = "0.3.1" +release = '0.3.2' +version = "0.3.2" html_theme = 'sphinx_rtd_theme' html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] diff --git a/docs/source/notes/acceleration.md b/docs/source/notes/acceleration.md index 6b088b4..4bb258f 100644 --- a/docs/source/notes/acceleration.md +++ b/docs/source/notes/acceleration.md @@ -1,6 +1,14 @@ (acceleration)= # OpenDelta+ - We are working on testing and improving the functionality with work with other acceleration packages for model training and inference. For example, [deepspeed](https://github.com/microsoft/DeepSpeed), [BMInf](https://github.com/OpenBMB/BMInf). -Feel free to contact us via email (shengdinghu@gmail.com) if you have any suggestion. +## BMTrain + +- [BMTrain](https://github.com/OpenBMB/BMTrain) is an efficient large model training toolkit that can be used to train large models with tens of billions of parameters. It can train models in a distributed manner while keeping the code as simple as stand-alone training. +- [ModelCenter](https://github.com/OpenBMB/ModelCenter) implements pre-trained language models (PLMs) based on the backend OpenBMB/BMTrain. ModelCenter supports Efficient, Low-Resource, Extendable model usage and distributed training. + +Now we have the LoraModel, AdapterModel, CompacterModel, ParallelAdapterModel, LowRankAdapterModel fully supported the distributed training with BMTrain and ModelCenter. Please try is out in + + +## Huggingface Accelerate + \ No newline at end of file diff --git a/docs/source/notes/update.md b/docs/source/notes/update.md index 26626d2..e121178 100644 --- a/docs/source/notes/update.md +++ b/docs/source/notes/update.md @@ -1,5 +1,10 @@ # Update Logs and Known Issues +## Version 0.3.2 +- We support BMTrain to accelerate the training, and parallelize the training of models that are hard to fit in a single GPU. Check [tutorial/2_with_bmtrain.py](https://github.com/thunlp/OpenDelta/tree/main/examples/tutorial/2_with_bmtrain.py) +- We add a functionality to [inspect the optimizer](https://github.com/thunlp/OpenDelta/tree/main/opendelta/utils/inspect.py). The user can see the number of trainable parameters in the optimizer and verify that opendelta is being used correctly. +- We move the functions to inspect the delta models into [inspect.py](https://github.com/thunlp/OpenDelta/tree/main/opendelta/utils/inspect.py) + ## Version 0.3.1 - We update [must_try.py](https://github.com/thunlp/OpenDelta/tree/main/examples/unittest/must_try.py) for a simple introduction of the core functionality of OpenDelta. - Thanks to [Weilin Zhao](https://github.com/Achazwl) We merge a long-developed branch parallel_adapter into the main branch. diff --git a/examples/tutorial/2_with_bmtrain.py b/examples/tutorial/2_with_bmtrain.py index d543355..ba202c9 100644 --- a/examples/tutorial/2_with_bmtrain.py +++ b/examples/tutorial/2_with_bmtrain.py @@ -1,50 +1,291 @@ -import bmtrain as bmt -import opendelta as od -from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel +# adapted from https://github.com/OpenBMB/ModelCenter/blob/main/examples/bert/finetune_bert.py + +import time +import os + import torch -import numpy -import random +import numpy as np +from sklearn.metrics import accuracy_score, recall_score, f1_score -def manual_seed(seed): - torch.manual_seed(seed) - numpy.random.seed(seed) - random.seed(seed) +import bmtrain as bmt -from model_center.model import Bert, BertConfig -bmt.init_distributed() -config = BertConfig.from_pretrained("bert-base-uncased") -config.dropout_p = 0 -model = Bert.from_pretrained("bert-base-uncased", config) +from model_center import get_args +from model_center.model import Bert +from model_center.tokenizer import BertTokenizer +from model_center.dataset.bertdataset import DATASET +from model_center.utils import print_inspect +from model_center.layer import Linear +from model_center.dataset import DistributedDataLoader +import opendelta as od +from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel, ParallelAdapterModel +from opendelta.utils.inspect import inspect_optimizer_statistics print("before modify") -od.Visualization(model).structure_graph() -manual_seed(233) -delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k']) -# delta_model = AdapterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn']) -# delta_model = CompacterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn']) -# delta_model = LowRankAdapterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn']) -# delta_model = BitFitModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn', '[r](.*)layernorm(.*)']) +class BertModel(torch.nn.Module): + def __init__(self, args, num_types): + super().__init__() + self.bert : Bert = Bert.from_pretrained(args.model_config) + dim_model = self.bert.input_embedding.dim_model + self.dense = Linear(dim_model, num_types) + bmt.init_parameters(self.dense) -# print(delta_model.delta_modules) + def forward(self, *args, **kwargs): + pooler_output = self.bert(*args, **kwargs, output_pooler_output=True).pooler_output + logits = self.dense(pooler_output) + return logits -print("after modify") -delta_model.log() -# This will visualize the backbone after modification and other information. +def get_tokenizer(args): + tokenizer = BertTokenizer.from_pretrained(args.model_config) + return tokenizer -delta_model.freeze_module(exclude=["deltas"], set_state_dict=True) -print("after freeze") -delta_model.log() -# The set_state_dict=True will tell the method to change the state_dict of the backbone_model to maintaining only the trainable parts. +def get_model(args): + num_types = { + "BoolQ" : 2, + "CB" : 3, + "COPA" : 1, + "RTE" : 2, + "WiC" : 2, + } + model = BertModel(args, num_types[args.dataset_name]) + od.Visualization(model).structure_graph() -manual_seed(233) -inp = torch.randint(0, 30000, (32, 128)).cuda() -length = torch.randint(0, 128, (32,)).cuda() -attention_mask = (torch.arange(inp.shape[1], device=inp.device)[None, :].repeat(inp.shape[0], 1) < length[:, None]) -out = model(inp, attention_mask=attention_mask, output_logits=True).logits -print(out) -if bmt.rank() == 0: - torch.save(model.state_dict(), "test.pt") - ckpt = torch.load("test.pt") - print(ckpt.keys()) \ No newline at end of file + if args.delta_type == "lora": + delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'], backend='bmt') + elif args.delta_type == "bitfit": + delta_model = BitFitModel(backbone_model=model, modified_modules=['self_att', 'ffn', 'layernorm'], backend='bmt') #TODO: fix bug + elif args.delta_type == "adapter": + delta_model = AdapterModel(backbone_model=model, modified_modules=['self_att', 'ffn'], backend='bmt') + elif args.delta_type == "compacter": + delta_model = CompacterModel(backbone_model=model, modified_modules=['self_att', 'ffn'], backend='bmt') + elif args.delta_type == "low_rank_adapter": + delta_model = LowRankAdapterModel(backbone_model=model, modified_modules=['self_att', 'ffn'], backend='bmt') + elif args.delta_type == "parallel_adapter": + delta_model = ParallelAdapterModel(backbone_model=model, modified_modules=['self_att', 'self_att', 'ffn.ffn', 'ffn.ffn'], backend='bmt') + + + + print("after modify") + delta_model.log() + # This will visualize the backbone after modification and other information. + + delta_model.freeze_module(exclude=["deltas"], set_state_dict=True) + print("after freeze") + delta_model.log() + return model + +def get_optimizer(args, model): + optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=args.weight_decay) + return optimizer + +def get_learning_rate_scheduler(args, optimizer): + if args.lr_decay_iters is None: + args.lr_decay_iters = args.train_iters * args.epochs + if args.lr_decay_style == "noam": + lr_scheduler = bmt.lr_scheduler.Noam(optimizer, + start_lr = args.lr, + warmup_iter = args.warmup_iters, + end_iter = args.lr_decay_iters, + num_iter = args.start_step) + elif args.lr_decay_style == "constant": + lr_scheduler = bmt.lr_scheduler.NoDecay(optimizer, + start_lr = args.lr, + warmup_iter = args.warmup_iters, + end_iter = -1, + num_iter = args.start_step) + elif args.lr_decay_style == "linear": + lr_scheduler = bmt.lr_scheduler.Linear(optimizer, + start_lr = args.lr, + warmup_iter = args.warmup_iters, + end_iter = args.lr_decay_iters, + num_iter = args.start_step) + elif args.lr_decay_style == "exponential": + lr_scheduler = bmt.lr_scheduler.Exponential(optimizer, + start_lr = args.lr, + warmup_iter = args.warmup_iters, + end_iter = args.lr_decay_iters, + num_iter = args.start_step) + elif args.lr_decay_style == "cosine": + lr_scheduler = bmt.lr_scheduler.Cosine(optimizer, + start_lr = args.lr, + warmup_iter = args.warmup_iters, + end_iter = args.lr_decay_iters, + num_iter = args.start_step) + else: + raise ValueError(f"lr_scheduler of type {args.lr_decay_style} is not supported yet.") + + return lr_scheduler + +def setup_model_and_optimizer(args): + # get the tokenizer + tokenizer = get_tokenizer(args) + # get the model + model = get_model(args) + bmt.synchronize() + # get the optimizer and lr_scheduler + optimizer = get_optimizer(args, model) + + inspect_optimizer_statistics(optimizer) + lr_scheduler = get_learning_rate_scheduler(args, optimizer) + bmt.synchronize() + # get the memory usage + bmt.print_rank("Model mem\n", torch.cuda.memory_summary()) + bmt.synchronize() + return tokenizer, model, optimizer, lr_scheduler + +def initialize(): + # get arguments + args = get_args() + # init bmt + bmt.init_distributed(seed = args.seed) + # init save folder + if args.save != None: + os.makedirs(args.save, exist_ok=True) + return args + +def prepare_dataset(args, tokenizer, base_path, dataset_name, rank, world_size): + splits = ['train', 'dev', 'test'] + dataset = {} + for split in splits: + dataset[split] = DATASET[dataset_name](base_path, split, rank, world_size, tokenizer, args.max_encoder_length) + return dataset + + +def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset): + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) + + optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale) + optim_manager.add_optimizer(optimizer, lr_scheduler) + + # print_inspect(model, '*') # too much output + + for epoch in range(12): + dataloader = { + "train": DistributedDataLoader(dataset['train'], batch_size=args.batch_size, shuffle=True), + "dev": DistributedDataLoader(dataset['dev'], batch_size=args.batch_size, shuffle=False), + } + + model.train() + for it, data in enumerate(dataloader['train']): + if args.dataset_name == 'COPA': + input_ids0 = data["input_ids0"] + attention_mask0 = data["attention_mask0"] + token_type_ids0 = data["token_type_ids0"] + input_ids1 = data["input_ids1"] + attention_mask1 = data["attention_mask1"] + token_type_ids1 = data["token_type_ids1"] + labels = data["labels"] + else: + input_ids = data["input_ids"] + attention_mask = data["attention_mask"] + token_type_ids = data["token_type_ids"] + labels = data["labels"] + + torch.cuda.synchronize() + st_time = time.time() + + if args.dataset_name == 'COPA': + logits = torch.cat([ + model(input_ids0, attention_mask=attention_mask0, token_type_ids=token_type_ids0), + model(input_ids1, attention_mask=attention_mask1, token_type_ids=token_type_ids1), + ], dim=1) + else: + logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) + loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1)) + + global_loss = bmt.sum_loss(loss).item() + + optim_manager.zero_grad() + + optim_manager.backward(loss) + grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type = 2) + + optim_manager.step() + + torch.cuda.synchronize() + elapsed_time = time.time() - st_time + + # from IPython import embed; embed(header="25252") + + bmt.print_rank( + "train | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.4e}, scale: {:10.4f} | grad_norm: {:.4f} | time: {:.3f}".format( + epoch, + it, + len(dataloader["train"]), + global_loss, + lr_scheduler.current_lr, + int(optim_manager.loss_scale), + grad_norm, + elapsed_time, + ) + ) + + model.eval() + with torch.no_grad(): + for split in ['dev']: + pd = [] + gt = [] + for it, data in enumerate(dataloader[split]): + if args.dataset_name == 'COPA': + input_ids0 = data["input_ids0"] + attention_mask0 = data["attention_mask0"] + token_type_ids0 = data["token_type_ids0"] + input_ids1 = data["input_ids1"] + attention_mask1 = data["attention_mask1"] + token_type_ids1 = data["token_type_ids1"] + labels = data["labels"] + logits = torch.cat([ + model(input_ids0, attention_mask=attention_mask0, token_type_ids=token_type_ids0), + model(input_ids1, attention_mask=attention_mask1, token_type_ids=token_type_ids1), + ], dim=1) + else: + input_ids = data["input_ids"] + attention_mask = data["attention_mask"] + token_type_ids = data["token_type_ids"] + labels = data["labels"] + logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) + + loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1)) + logits = logits.argmax(dim=-1) + pd.extend(logits.cpu().tolist()) + gt.extend(labels.cpu().tolist()) + + bmt.print_rank( + "{} | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f}".format( + split, + epoch, + it, + len(dataloader[split]), + loss, + ) + ) + + pd = bmt.gather_result(torch.tensor(pd).int()).cpu().tolist() + gt = bmt.gather_result(torch.tensor(gt).int()).cpu().tolist() + + bmt.print_rank(f"{split} epoch {epoch}:") + if args.dataset_name in ["BoolQ", "CB", "COPA", "RTE", "WiC", "WSC"]: + acc = accuracy_score(gt, pd) + bmt.print_rank(f"accuracy: {acc*100:.2f}") + if args.dataset_name in ["CB"]: + rcl = f1_score(gt, pd, average="macro") + f1 = recall_score(gt, pd, average="macro") + bmt.print_rank(f"recall: {rcl*100:.2f}") + bmt.print_rank(f"Average F1: {f1*100:.2f}") + + +def main(): + args = initialize() + tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args) + dataset = prepare_dataset( + args, + tokenizer, + f"{args.base_path}/down_data/superglue/", + args.dataset_name, + bmt.rank(), bmt.world_size(), + ) + finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/tutorial/2_with_bmtrain.sh b/examples/tutorial/2_with_bmtrain.sh index 0a5b3bc..f25013d 100644 --- a/examples/tutorial/2_with_bmtrain.sh +++ b/examples/tutorial/2_with_bmtrain.sh @@ -1 +1,37 @@ -python3 -m torch.distributed.launch --master_addr localhost --master_port 34123 --nproc_per_node $1 --nnodes 1 --node_rank 0 2_with_bmtrain.py +#! /bin/bash + +MASTER_ADDR=localhost +MASTER_PORT=12345 +NNODES=1 +NODE_RANK=0 +GPUS_PER_NODE=4 + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT" + +BASE_PATH="./" +VERSION="bert-large-cased" +DATASET="BoolQ" # You can try other dataset listed in https://github.com/OpenBMB/ModelCenter/tree/main/examples/bert + +OPTS="" +OPTS+=" --model-config ${VERSION}" +OPTS+=" --base-path ${BASE_PATH}" +OPTS+=" --dataset_name ${DATASET}" +OPTS+=" --batch-size 64" +OPTS+=" --lr 0.001" # You can use different learning rate to find optimal performance +OPTS+=" --max-encoder-length 512" +OPTS+=" --train-iters 1400" +OPTS+=" --lr-decay-style constant" +OPTS+=" --weight-decay 1e-2" +OPTS+=" --clip-grad 10.0" +OPTS+=" --loss-scale 128" +OPTS+=" --delta_type low_rank_adapter" # You can use different delta type, listed in https://opendelta.readthedocs.io/en/latest/notes/acceleration.html#BMTrain + +CMD="python3 -m torch.distributed.launch ${DISTRIBUTED_ARGS} ${BASE_PATH}2_with_bmtrain.py ${OPTS}" +echo ${CMD} + +${CMD} 2>&1 | tee ${BASE_PATH}/tmp/logs/bmt_bert_boolq_finetune-${VERSION}-${DATASET}.log + diff --git a/examples/tutorial/README.md b/examples/tutorial/README.md index 815a114..e0f2a6e 100644 --- a/examples/tutorial/README.md +++ b/examples/tutorial/README.md @@ -13,4 +13,23 @@ requirement: ``` pip install openprompt -``` \ No newline at end of file +``` + +## 2_with_bmtrain.py +1. install necessary packages: +``` +pip install git+https://github.com/OpenBMB/BMTrain.git +pip install git+git@github.com:OpenBMB/ModelCenter.git +``` +2. download dataset from https://super.gluebenchmark.com/tasks, e.g., +``` +mkdir down_data +cd down_data +wget https://dl.fbaipublicfiles.com/glue/superglue/data/v2/BoolQ.zip +unzip BoolQ.zip +``` +3. Run the shell scripts, change `NNODES`,`GPUS_PER_NODE` according to your computational resources. +``` +bash 2_with_bmtrain.sh +``` + diff --git a/examples/unittest/must_try.py b/examples/unittest/must_try.py index 1837a75..8448e41 100644 --- a/examples/unittest/must_try.py +++ b/examples/unittest/must_try.py @@ -67,6 +67,16 @@ delta2.detach() # say we add lora to the last four layer of the decoder of t5, with lora rank=5 delta_config3 = AutoDeltaConfig.from_dict({"delta_type":"lora", "modified_modules":["[r]decoder.*((20)|(21)|(22)|(23)).*DenseReluDense\.wi"], "lora_r":5}) delta3 = AutoDeltaModel.from_config(delta_config3, backbone_model=wrapped_model) +delta3.freeze_module() delta3.log() +# add optimizer as normal +from transformers import AdamW +optimizer = AdamW(wrapped_model.parameters(), lr=3e-3) + +# inspect_optimizer +from opendelta.utils.inspect import inspect_optimizer_statistics +inspect_optimizer_statistics(optimizer) + + diff --git a/examples/unittest/test_bmtrain.py b/examples/unittest/test_bmtrain.py deleted file mode 100644 index 9096fe4..0000000 --- a/examples/unittest/test_bmtrain.py +++ /dev/null @@ -1,255 +0,0 @@ -import time -import os - -import torch -import numpy as np -from sklearn.metrics import accuracy_score, recall_score, f1_score - -import bmtrain as bmt - -from model_center import get_args -from model_center.model import Bert -from model_center.tokenizer import BertTokenizer -from model_center.dataset.bertdataset import DATASET -from model_center.utils import print_inspect -from model_center.layer import Linear -from model_center.dataset import DistributedDataLoader - -class BertModel(torch.nn.Module): - def __init__(self, args, num_types): - super().__init__() - self.bert : Bert = Bert.from_pretrained(args.model_config) - dim_model = self.bert.input_embedding.dim_model - self.dense = Linear(dim_model, num_types) - bmt.init_parameters(self.dense) - - def forward(self, *args, **kwargs): - pooler_output = self.bert(*args, **kwargs, output_pooler_output=True).pooler_output - logits = self.dense(pooler_output) - return logits - -def get_tokenizer(args): - tokenizer = BertTokenizer.from_pretrained(args.model_config) - return tokenizer - -def get_model(args): - num_types = { - "BoolQ" : 2, - "CB" : 3, - "COPA" : 1, - "RTE" : 2, - "WiC" : 2, - } - model = BertModel(args, num_types[args.dataset_name]) - return model - -def get_optimizer(args, model): - optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=args.weight_decay) - return optimizer - -def get_learning_rate_scheduler(args, optimizer): - if args.lr_decay_iters is None: - args.lr_decay_iters = args.train_iters * args.epochs - if args.lr_decay_style == "noam": - lr_scheduler = bmt.lr_scheduler.Noam(optimizer, - start_lr = args.lr, - warmup_iter = args.warmup_iters, - end_iter = args.lr_decay_iters, - num_iter = args.start_step) - elif args.lr_decay_style == "constant": - lr_scheduler = bmt.lr_scheduler.NoDecay(optimizer, - start_lr = args.lr, - warmup_iter = args.warmup_iters, - end_iter = -1, - num_iter = args.start_step) - elif args.lr_decay_style == "linear": - lr_scheduler = bmt.lr_scheduler.Linear(optimizer, - start_lr = args.lr, - warmup_iter = args.warmup_iters, - end_iter = args.lr_decay_iters, - num_iter = args.start_step) - elif args.lr_decay_style == "exponential": - lr_scheduler = bmt.lr_scheduler.Exponential(optimizer, - start_lr = args.lr, - warmup_iter = args.warmup_iters, - end_iter = args.lr_decay_iters, - num_iter = args.start_step) - elif args.lr_decay_style == "cosine": - lr_scheduler = bmt.lr_scheduler.Cosine(optimizer, - start_lr = args.lr, - warmup_iter = args.warmup_iters, - end_iter = args.lr_decay_iters, - num_iter = args.start_step) - else: - raise ValueError(f"lr_scheduler of type {args.lr_decay_style} is not supported yet.") - - return lr_scheduler - -def setup_model_and_optimizer(args): - # get the tokenizer - tokenizer = get_tokenizer(args) - # get the model - model = get_model(args) - bmt.synchronize() - # get the optimizer and lr_scheduler - optimizer = get_optimizer(args, model) - lr_scheduler = get_learning_rate_scheduler(args, optimizer) - bmt.synchronize() - # get the memory usage - bmt.print_rank("Model mem\n", torch.cuda.memory_summary()) - bmt.synchronize() - return tokenizer, model, optimizer, lr_scheduler - -def initialize(): - # get arguments - args = get_args() - # init bmt - bmt.init_distributed(seed = args.seed) - # init save folder - if args.save != None: - os.makedirs(args.save, exist_ok=True) - return args - -def prepare_dataset(args, tokenizer, base_path, dataset_name, rank, world_size): - splits = ['train', 'dev', 'test'] - dataset = {} - for split in splits: - dataset[split] = DATASET[dataset_name](base_path, split, rank, world_size, tokenizer, args.max_encoder_length) - return dataset - - -def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset): - loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100) - - optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale) - optim_manager.add_optimizer(optimizer, lr_scheduler) - - print_inspect(model, '*') - - for epoch in range(12): - dataloader = { - "train": DistributedDataLoader(dataset['train'], batch_size=args.batch_size, shuffle=True), - "dev": DistributedDataLoader(dataset['dev'], batch_size=args.batch_size, shuffle=False), - } - - model.train() - for it, data in enumerate(dataloader['train']): - if args.dataset_name == 'COPA': - input_ids0 = data["input_ids0"] - attention_mask0 = data["attention_mask0"] - token_type_ids0 = data["token_type_ids0"] - input_ids1 = data["input_ids1"] - attention_mask1 = data["attention_mask1"] - token_type_ids1 = data["token_type_ids1"] - labels = data["labels"] - else: - input_ids = data["input_ids"] - attention_mask = data["attention_mask"] - token_type_ids = data["token_type_ids"] - labels = data["labels"] - - torch.cuda.synchronize() - st_time = time.time() - - if args.dataset_name == 'COPA': - logits = torch.cat([ - model(input_ids0, attention_mask=attention_mask0, token_type_ids=token_type_ids0), - model(input_ids1, attention_mask=attention_mask1, token_type_ids=token_type_ids1), - ], dim=1) - else: - logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) - loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1)) - - global_loss = bmt.sum_loss(loss).item() - - optim_manager.zero_grad() - - optim_manager.backward(loss) - grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type = 2) - - optim_manager.step() - - torch.cuda.synchronize() - elapsed_time = time.time() - st_time - - bmt.print_rank( - "train | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.4e}, scale: {:10.4f} | grad_norm: {:.4f} | time: {:.3f}".format( - epoch, - it, - len(dataloader["train"]), - global_loss, - lr_scheduler.current_lr, - int(optim_manager.loss_scale), - grad_norm, - elapsed_time, - ) - ) - - model.eval() - with torch.no_grad(): - for split in ['dev']: - pd = [] - gt = [] - for it, data in enumerate(dataloader[split]): - if args.dataset_name == 'COPA': - input_ids0 = data["input_ids0"] - attention_mask0 = data["attention_mask0"] - token_type_ids0 = data["token_type_ids0"] - input_ids1 = data["input_ids1"] - attention_mask1 = data["attention_mask1"] - token_type_ids1 = data["token_type_ids1"] - labels = data["labels"] - logits = torch.cat([ - model(input_ids0, attention_mask=attention_mask0, token_type_ids=token_type_ids0), - model(input_ids1, attention_mask=attention_mask1, token_type_ids=token_type_ids1), - ], dim=1) - else: - input_ids = data["input_ids"] - attention_mask = data["attention_mask"] - token_type_ids = data["token_type_ids"] - labels = data["labels"] - logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) - - loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1)) - logits = logits.argmax(dim=-1) - pd.extend(logits.cpu().tolist()) - gt.extend(labels.cpu().tolist()) - - bmt.print_rank( - "{} | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f}".format( - split, - epoch, - it, - len(dataloader[split]), - loss, - ) - ) - - pd = bmt.gather_result(torch.tensor(pd).int()).cpu().tolist() - gt = bmt.gather_result(torch.tensor(gt).int()).cpu().tolist() - - bmt.print_rank(f"{split} epoch {epoch}:") - if args.dataset_name in ["BoolQ", "CB", "COPA", "RTE", "WiC", "WSC"]: - acc = accuracy_score(gt, pd) - bmt.print_rank(f"accuracy: {acc*100:.2f}") - if args.dataset_name in ["CB"]: - rcl = f1_score(gt, pd, average="macro") - f1 = recall_score(gt, pd, average="macro") - bmt.print_rank(f"recall: {rcl*100:.2f}") - bmt.print_rank(f"Average F1: {f1*100:.2f}") - - -def main(): - args = initialize() - tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args) - dataset = prepare_dataset( - args, - tokenizer, - f"{args.base_path}/down_data/superglue/", - args.dataset_name, - bmt.rank(), bmt.world_size(), - ) - finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/opendelta/__init__.py b/opendelta/__init__.py index 431cfa1..1127d70 100644 --- a/opendelta/__init__.py +++ b/opendelta/__init__.py @@ -21,6 +21,7 @@ from .delta_models.adapter import AdapterModel from .delta_models.prefix import PrefixModel from .delta_models.soft_prompt import SoftPromptModel from .delta_models.low_rank_adapter import LowRankAdapterModel +from .delta_models.parallel_adapter import ParallelAdapterModel from .utils.visualization import Visualization diff --git a/opendelta/basemodel.py b/opendelta/basemodel.py index 27ff96d..cd1006b 100644 --- a/opendelta/basemodel.py +++ b/opendelta/basemodel.py @@ -5,6 +5,7 @@ from multiprocessing.sharedctypes import Value import os from turtle import back from opendelta.delta_configs import BaseDeltaConfig +from opendelta.utils.inspect import inspect_module_statistics from opendelta.utils.model_md5 import gen_model_hash from opendelta.utils.signature import get_arg_names, signature from typing import Optional, Union @@ -27,6 +28,7 @@ from opendelta.utils.cuda import move_dict_to_cuda import sys from opendelta.utils.data_parallel import caller_map +from opendelta.utils.backend import BackendMapping logger = logging.get_logger(__name__) def is_leaf_module(module): @@ -94,6 +96,7 @@ class DeltaBase(nn.Module, SaveLoadMixin): config_class = BaseDeltaConfig default_unfrozen_modules = ["deltas"] _need_pseudo_data = True + _supported_backends = ['hf'] def __init__(self, backbone_model: nn.Module, modified_modules: Optional[List[str]] = None, @@ -101,7 +104,7 @@ class DeltaBase(nn.Module, SaveLoadMixin): unfrozen_modules: Optional[List[str]] = None, interactive_modify: Optional[Union[bool, int]] = False, common_structure: Optional[bool] = False, - framework_type: Optional[str]= "hf", # select from ["hf", "bmt"] + backend: Optional[str]= "hf", # select from ["hf", "bmt"] ): nn.Module.__init__(self) # register the backbone model after init using self.__dict__ method to avoid adding backbone_model @@ -139,7 +142,10 @@ class DeltaBase(nn.Module, SaveLoadMixin): self.unfrozen_modules = self.default_unfrozen_modules if self.common_structure and self.structure_mapping is None: raise RuntimeError("Using common structure but the structure mapping is None") - self.framework_type = framework_type + if backend not in self._supported_backends: + raise RuntimeError("Currently, backend `{}` is not supported for `{}`".format(backend, self.__class__.__name__)) + self.backend = backend + self.backend_mapping = BackendMapping(backend) def forward(self, *args, **kwargs) -> RuntimeError: r""" @@ -371,10 +377,11 @@ class DeltaBase(nn.Module, SaveLoadMixin): _auto_dummy_fail = False try: module(**dummy_inputs) - except: + except Exception as e: _auto_dummy_fail = True - if _auto_dummy_fail: - raise AttributeError(f"\n\tThe {self.__class__.__name__} requires a dummy_inputs to be passed through the model to understand the dimensionality of each tensor in the computation graph. \n\t The {module.__class__.__name__} Class has no dummy_inputs, and automatically created dummy_inputs failed.\n\t Refer to `https://opendelta.readthedocs.io/en/latest/notes/faq.html` for detail.") + + if _auto_dummy_fail and _auto_dummy: + raise AttributeError(f"str({e})\n\tThe {self.__class__.__name__} requires a dummy_inputs to be passed through the model to understand the dimensionality of each tensor in the computation graph. \n\t The {module.__class__.__name__} Class has no dummy_inputs, and automatically created dummy_inputs failed.\n\t Refer to `https://opendelta.readthedocs.io/en/latest/notes/faq.html` for detail.") @@ -684,65 +691,16 @@ class DeltaBase(nn.Module, SaveLoadMixin): from opendelta import Visualization Visualization(module).structure_graph() - self.get_statistics(module) + self.stat = inspect_module_statistics(module, verbose=False) if trainable_ratio: - logger.info("Trainable Ratio: {:2f}%".format(self.stat['trainable_ratio']*100)) + logger.info("Trainable Ratio: {}/{}={:.6f}%".format(self.stat['trainable_parameters'], self.stat['total_parameters'], self.stat['trainable_ratio']*100)) if delta_ratio: - logger.info("Delta Parameter Ratio: {:2f}%".format(self.stat['delta_ratio']*100)) + logger.info("Delta Parameter Ratio: {}/{}={:.6f}%".format(self.stat['delta_parameters'], self.stat['total_parameters'],self.stat['delta_ratio']*100)) if cuda_memory: logger.info("Static Memory {:.2f} GB, Max Memory {:.2f} GB".format(self.stat['cudamem'], self.stat['maxcudamem'])) - def get_statistics(self, module=None): - r"""Get the statistics of the parameters in the delta modules. - Args: - module (:obj:`nn.Module`, *optional*): The module to compute the statistics. - - Returns: - :obj:`dict`: The statistics of the parameters in the delta modules. - - """ - if module is None: - module = self.backbone_model - - self.stat = {} - n_trainable = self.num_trainable_parameters(module) - n_total = self.num_total_parameters(module) - - self.stat['trainable_ratio'] = n_trainable/n_total - - n_delta = self.num_delta_parameters(module) - n_total = self.num_total_parameters(module) - self.stat['delta_ratio'] = n_delta/n_total - - cudamem = 0 - maxcudamem = 0 - for device_id in range(torch.cuda.device_count()): - cudamem += torch.cuda.memory_allocated(f"cuda:{device_id}")/1024**3 - maxcudamem += torch.cuda.max_memory_allocated(f"cuda:{device_id}")/1024**3 - self.stat['cudamem'] = cudamem - self.stat['maxcudamem'] = maxcudamem - - - - def num_delta_parameters(self, module: Optional[nn.Module]=None): - r"""[NODOC] A small sugar function to get the number of trainable parameter in the backbone model. Often used to - compute the trainable rate. - - Args: - module (:obj:`nn.Module`): of which module we want to know the number of trainable paramemters. - - Returns: - :obj:`List[nn.Parameter]` - """ - if module is None: - module = self.backbone_model - pnum_tot = 0 - for param in module.parameters(): - if hasattr(param, "_is_delta"): - pnum_tot += param.numel() - return pnum_tot # Two functions for plug and remove the delta model. def attach(self, module: Optional[nn.Module]=None, reset_state_dict=True): diff --git a/opendelta/delta_models/adapter.py b/opendelta/delta_models/adapter.py index bbb8bb4..4f191b4 100644 --- a/opendelta/delta_models/adapter.py +++ b/opendelta/delta_models/adapter.py @@ -60,13 +60,14 @@ class AdapterLayer(nn.Module, InterFaceMixin): def get_layer_count(cls): return cls.layer_count - def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', device=None): + def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', device=None, backend="hf"): super().__init__() InterFaceMixin.__init__(self) self.bottleneck_dim = bottleneck_dim self.init_device = device self.instantiated = False self.non_linearity = non_linearity + self.backend=backend self.layer_id = AdapterLayer.get_layer_count() AdapterLayer.count_layer() @@ -79,14 +80,16 @@ class AdapterLayer(nn.Module, InterFaceMixin): else: return self.init_device - def instantiate(self, hidden_dim): + def instantiate(self, hiddens): + self.hidden_dim = hiddens.shape[-1] + self.hidden_dtype = hiddens.dtype self.modulelist = nn.Sequential() - self.modulelist.add_module("down_proj",nn.Linear(hidden_dim, self.bottleneck_dim, device=self.init_device)) + self.modulelist.add_module("down_proj",nn.Linear(self.hidden_dim, self.bottleneck_dim, device=self.init_device, dtype=self.hidden_dtype)) # select non-linearity self.modulelist.add_module("non_linear", Activations(self.non_linearity.lower())) - self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.init_device)) + self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.init_device, dtype=self.hidden_dtype)) # TODO: # If we want to have a layer norm on output, we apply it later after a separate residual connection @@ -97,11 +100,9 @@ class AdapterLayer(nn.Module, InterFaceMixin): self.instantiated = True # initialize the weight, which is important for fast convergence and better performance. self.apply(self._init_weight) - try: + if self.backend == 'bmt': import bmtrain as bmt self.modulelist = bmt.BMTrainModelWrapper(self.modulelist) - except: - pass def _init_weight(self, module): if isinstance(module, nn.Linear): @@ -123,19 +124,29 @@ class AdapterLayer(nn.Module, InterFaceMixin): raise TypeError hiddens = self._transpose(hiddens) - hiddens = self._convert_data_type(hiddens) + # if self.backend == 'hf': + # hiddens = self._convert_data_type(hiddens) + # elif self.backend == 'bmt': # if bmt, left the convertion to bmt + # pass if not self.instantiated: - self.hidden_dim = hiddens.shape[-1] - logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}") - self.instantiate(hidden_dim=self.hidden_dim) - + # self.hidden_dim = hiddens.shape[-1] + # logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}") + self.instantiate(hiddens=hiddens) + # from IPython import embed; embed(header="14135315") adapter_output = self.modulelist(hiddens) modified_output = adapter_output + hiddens # TODO option: disable residual_connection modified_output = self._reverse_transpose(modified_output) - modified_output = self._reverse_data_type(modified_output) + + # if self.backend == 'hf': + # # print("!"*100) + # modified_output = self._reverse_data_type(modified_output) + # elif self.backend == 'bmt': # if bmt, left the convertion to bmt + # print("!"*100) + # pass + if isinstance(output, tuple): output = (modified_output,) + output[1:] @@ -189,20 +200,24 @@ class AdapterModel(DeltaBase): modified_modules (:obj:`List[str]`): modules to add adapter after them. unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the adapter parameters. common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping. + backend (:obj:`str`): choose the backend of plm, 'hf' for huggingface transformers,'bmt' for bmtrain. """ config_class = AdapterConfig delta_type = "adapter" default_modified_modules = ["attn@.proj@", "ff@.w2@"] + _supported_backends = ['hf', 'bmt'] _need_pseudo_data = True def __init__(self, backbone_model: nn.Module, bottleneck_dim: Optional[int]=24, non_linearity: Optional[str]='gelu_new', - modified_modules: Optional[bool] = None, + modified_modules: Optional[List[str]] = None, + exclude_modules: Optional[List[str]] = None, unfrozen_modules: Optional[bool] = None, common_structure: Optional[bool] = None, interactive_modify: Optional[Union[bool, int]] = False, + backend: Optional[str] = 'hf', ): DeltaBase.__init__(self, backbone_model, @@ -211,6 +226,7 @@ class AdapterModel(DeltaBase): unfrozen_modules=unfrozen_modules, common_structure=common_structure, interactive_modify=interactive_modify, + backend=backend, ) arg_names = get_arg_names_inside_func(self.__init__) for arg_name in arg_names: @@ -231,6 +247,6 @@ class AdapterModel(DeltaBase): def new_module_like(self, module): module_device = get_device(module) - adapterlayer = AdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device) + adapterlayer = AdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device, backend=self.backend) self.delta_modules.append(adapterlayer) return adapterlayer diff --git a/opendelta/delta_models/bitfit.py b/opendelta/delta_models/bitfit.py index 8d26997..9d89548 100644 --- a/opendelta/delta_models/bitfit.py +++ b/opendelta/delta_models/bitfit.py @@ -75,16 +75,6 @@ class BiasLayer(nn.Module): raise TypeError return output -framework_map = {} -framework_map['hf'] = { - "linear": nn.Linear, - "layer_norm": nn.LayerNorm, -} - -framework_map['bmt'] = { - "linear": model_center.layer.Linear, - "layer_norm", model_center.layer.LayerNorm, -} class BitFitModel(DeltaBase): @@ -124,6 +114,7 @@ class BitFitModel(DeltaBase): config_class = BitFitConfig delta_type = "bitfit" default_modified_modules = ["attn@", "ff@", "layer_norm@","lm_head@.proj@"] # modify all the bias parameter in attention and feed-forward layer. + _supported_backends = ['hf'] _need_pseudo_data = False def __init__(self, backbone_model: nn.Module, @@ -132,7 +123,7 @@ class BitFitModel(DeltaBase): unfrozen_modules: Optional[List[str]] = None, common_structure: Optional[bool] = None, interactive_modify: Optional[Union[bool, int]] = False, - framework_type: Optional[str] = "hf", + backend: Optional[str] = "hf", ): DeltaBase.__init__(self, backbone_model, @@ -141,7 +132,7 @@ class BitFitModel(DeltaBase): unfrozen_modules=unfrozen_modules, common_structure=common_structure, interactive_modify=interactive_modify, - framework_type=framework_type, + backend=backend, ) arg_names = get_arg_names_inside_func(self.__init__) for arg_name in arg_names: @@ -153,6 +144,8 @@ class BitFitModel(DeltaBase): self.add_all_delta_to_backbone(self.backbone_model, self.modified_modules) + + def update_module(self, module: nn.Module, key: str): @@ -167,7 +160,8 @@ class BitFitModel(DeltaBase): # if it is a leaf module, add bias to it regardless of its type. # if self.check_linear(module): # self.add_bias_to_linear(module) - if self.check_linear(module) or self.check_layernorm(module, nn.LayerNorm): + if self.backend_mapping.check_type(module, 'linear') or \ + self.backend_mapping.check_type(module, 'layer_norm'): self.add_bias_to_modules_have_bias_or_known_type(module) else: # for example, layer_norms, lm_heads. @@ -202,48 +196,27 @@ class BitFitModel(DeltaBase): c.bias.requires_grad = True self.delta_params.append(c.bias) else: - if self.check_linear(c) or isinstance(c): # todo: bmt layerNorm + if self.backend_mapping.check_type(c, 'linear') or \ + self.backend_mapping.check_type(c, 'layer_norm'): bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True) - self._reset_bias_parameters(c) #? - try: + self._reset_bias_parameters(c) + if self.backend == 'bmt': import bmtrain as bmt bias = bmt.BMTrainModelWrapper(bias) - except: - pass + c.register_parameter('bias', bias) self.delta_params.append(bias) - def add_bias_to_others(self, c): # todo: bmtrain? - new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c)) + def add_bias_to_others(self, c): + new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c)) # TODO: bmtrain? + if self.backend == 'bmt': + import bmtrain as bmt + new_bias = bmt.BMTrainModelWrapper(new_bias) + self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since the name `bias` is reserved for some module such as roberta's LayerNorm. self.delta_modules.append(new_bias) - def check_linear(self, m): - if isinstance(m, nn.Linear): - return True - else: - try: - from model_center.layer import Linear - if isinstance(m, Linear): - return True - except: - pass - return False - - def check_layernorm(self, m): - if isinstance(m, nn.LayerNorm): - return True - else: - try: - from model_center.layer import LayerNorm - if isinstance(m, LayerNorm): - return True - except: - pass - return False - - @staticmethod def _reset_bias_parameters(linear_module): fan_in, _ = init._calculate_fan_in_and_fan_out(linear_module.weight) diff --git a/opendelta/delta_models/compacter.py b/opendelta/delta_models/compacter.py index ca88bf2..9743aca 100644 --- a/opendelta/delta_models/compacter.py +++ b/opendelta/delta_models/compacter.py @@ -36,6 +36,7 @@ class HyperComplexAdapterLayer(nn.Module): device=None, use_bias_up_sampler=True, use_bias_down_sampler=True, + backend = 'hf', ): super().__init__() self.reduction_factor = reduction_factor @@ -55,14 +56,17 @@ class HyperComplexAdapterLayer(nn.Module): self.use_bias_up_sampler=use_bias_up_sampler self.use_bias_down_sampler=use_bias_down_sampler self.device = device + self.backend = backend self.instantiated = False - def instantiate(self, hidden_dim): - self.down_sample_size = hidden_dim // self.reduction_factor + def instantiate(self, hiddens): + self.hidden_dim = hiddens.shape[-1] + self.hidden_dtype = hiddens.dtype + self.down_sample_size = self.hidden_dim // self.reduction_factor self.activation = Activations(self.non_linearity.lower()).to(self.device) - self.down_sampler = PHMLinear(in_features=hidden_dim, + self.down_sampler = PHMLinear(in_features=self.hidden_dim, out_features=self.down_sample_size, bias=self.use_bias_down_sampler, c_init=self.phm_c_init, @@ -76,9 +80,10 @@ class HyperComplexAdapterLayer(nn.Module): factorized_phm_rule=self.factorized_phm_rule, phm_rank=self.phm_rank, phm_init_range=self.phm_init_range, - kronecker_prod=self.kronecker_prod).to(self.device) + kronecker_prod=self.kronecker_prod, + dtype = self.hidden_dtype).to(self.device) self.up_sampler = PHMLinear(in_features=self.down_sample_size, - out_features=hidden_dim, + out_features=self.hidden_dim, bias=self.use_bias_up_sampler, c_init=self.phm_c_init, phm_dim=self.hypercomplex_division, @@ -91,15 +96,14 @@ class HyperComplexAdapterLayer(nn.Module): factorized_phm_rule=self.factorized_phm_rule, phm_rank=self.phm_rank, phm_init_range=self.phm_init_range, - kronecker_prod=self.kronecker_prod).to(self.device) + kronecker_prod=self.kronecker_prod, + dtype = self.hidden_dtype).to(self.device) self.instantiated = True - try: + if self.backend == "bmt": import bmtrain as bmt self.activation = bmt.BMTrainModelWrapper(self.activation) self.down_sampler = bmt.BMTrainModelWrapper(self.down_sampler) self.up_sampler = bmt.BMTrainModelWrapper(self.up_sampler) - except: - pass def post_forward(self, output): @@ -116,9 +120,7 @@ class HyperComplexAdapterLayer(nn.Module): raise TypeError if not self.instantiated: - self.hidden_dim = hiddens.shape[-1] - logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}") - self.instantiate(hidden_dim=self.hidden_dim) + self.instantiate(hiddens=hiddens) z = self.down_sampler(hiddens) @@ -193,6 +195,7 @@ class CompacterModel(DeltaBase): unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the prefix parameters. common_structure (:obj:`bool`, *optional*, default to :obj:`None`): whether using name-based addressing with a common structure mapping. + backend (:obj:`str`): choose the backend of plm, 'hf' for huggingface transformers,'bmt' for bmtrain reduction_factor (:obj:`int`, *optional*, default to ``16``): bottleneck_dim = hidden_dim//reduction_factor non_linearity (:obj:`str`, *optional*, default to ``"gelu_new"``): The non linearity activation used in between the down projecter and the up projecter. @@ -218,6 +221,7 @@ class CompacterModel(DeltaBase): config_class = CompacterConfig delta_type = "compacter" default_modified_modules = ["attn@.proj@", "ff@.w2@"] + _supported_backends = ['hf', 'bmt'] _need_pseudo_data = True def __init__(self, backbone_model, @@ -226,6 +230,7 @@ class CompacterModel(DeltaBase): unfrozen_modules: Optional[List[str]] = None, common_structure: Optional[bool] = None, interactive_modify: Optional[Union[bool, int]] = False, + backend: Optional[str] = 'hf', reduction_factor=16, non_linearity="gelu_new", phm_c_init="normal", @@ -288,22 +293,6 @@ class CompacterModel(DeltaBase): def new_module_like(self, module): module_device = get_device(module) - adapterlayer = HyperComplexAdapterLayer(reduction_factor=self.reduction_factor, - non_linearity=self.non_linearity, - phm_c_init=self.phm_c_init, - hypercomplex_division=self.hypercomplex_division, - learn_phm=self.learn_phm, - hypercomplex_nonlinearity=self.hypercomplex_nonlinearity, - shared_phm_rule=self.shared_phm_rule, - factorized_phm=self.factorized_phm, - shared_W_phm=self.shared_W_phm, - factorized_phm_rule=self.factorized_phm_rule, - phm_rank=self.phm_rank, - phm_init_range=self.phm_init_range, - kronecker_prod=self.kronecker_prod, - use_bias_up_sampler=self.use_bias_up_sampler, - use_bias_down_sampler=self.use_bias_down_sampler, - device=module_device - ) + adapterlayer = HyperComplexAdapterLayer(reduction_factor=self.reduction_factor, non_linearity=self.non_linearity, phm_c_init=self.phm_c_init, hypercomplex_division=self.hypercomplex_division, learn_phm=self.learn_phm, hypercomplex_nonlinearity=self.hypercomplex_nonlinearity, shared_phm_rule=self.shared_phm_rule, factorized_phm=self.factorized_phm, shared_W_phm=self.shared_W_phm, factorized_phm_rule=self.factorized_phm_rule, phm_rank=self.phm_rank, phm_init_range=self.phm_init_range, kronecker_prod=self.kronecker_prod, use_bias_up_sampler=self.use_bias_up_sampler, use_bias_down_sampler=self.use_bias_down_sampler, device=module_device, backend=self.backend) self.delta_modules.append(adapterlayer) return adapterlayer diff --git a/opendelta/delta_models/layers/hypercomplex_linear.py b/opendelta/delta_models/layers/hypercomplex_linear.py index acee0e8..e4f94c2 100644 --- a/opendelta/delta_models/layers/hypercomplex_linear.py +++ b/opendelta/delta_models/layers/hypercomplex_linear.py @@ -84,7 +84,8 @@ class PHMLinear(torch.nn.Module): factorized_phm_rule=False, phm_rank = 1, phm_init_range=0.0001, - kronecker_prod=False) -> None: + kronecker_prod=False, + dtype=torch.float) -> None: super(PHMLinear, self).__init__() assert w_init in ["phm", "glorot-normal", "glorot-uniform", "normal"] assert c_init in ["normal", "uniform"] @@ -104,12 +105,12 @@ class PHMLinear(torch.nn.Module): self.factorized_phm_rule = factorized_phm_rule if not self.shared_phm_rule: if self.factorized_phm_rule: - self.phm_rule_left = nn.Parameter(torch.FloatTensor(phm_dim, phm_dim, 1), + self.phm_rule_left = nn.Parameter(torch.empty((phm_dim, phm_dim, 1), dtype=dtype), requires_grad=learn_phm) - self.phm_rule_right = nn.Parameter(torch.FloatTensor(phm_dim, 1, phm_dim), + self.phm_rule_right = nn.Parameter(torch.empty((phm_dim, 1, phm_dim), dtype=dtype), requires_grad=learn_phm) else: - self.phm_rule = nn.Parameter(torch.FloatTensor(phm_dim, phm_dim, phm_dim), + self.phm_rule = nn.Parameter(torch.empty((phm_dim, phm_dim, phm_dim), dtype=dtype), requires_grad=learn_phm) self.bias_flag = bias self.w_init = w_init @@ -118,15 +119,15 @@ class PHMLinear(torch.nn.Module): self.factorized_phm = factorized_phm if not self.shared_W_phm: if self.factorized_phm: - self.W_left = nn.Parameter(torch.Tensor(size=(phm_dim, self._in_feats_per_axis, self.phm_rank)), + self.W_left = nn.Parameter(torch.empty((phm_dim, self._in_feats_per_axis, self.phm_rank), dtype=dtype), requires_grad=True) - self.W_right = nn.Parameter(torch.Tensor(size=(phm_dim, self.phm_rank, self._out_feats_per_axis)), + self.W_right = nn.Parameter(torch.empty((phm_dim, self.phm_rank, self._out_feats_per_axis), dtype=dtype), requires_grad=True) else: - self.W = nn.Parameter(torch.Tensor(size=(phm_dim, self._in_feats_per_axis, self._out_feats_per_axis)), + self.W = nn.Parameter(torch.empty((phm_dim, self._in_feats_per_axis, self._out_feats_per_axis), dtype=dtype), requires_grad=True) if self.bias_flag: - self.b = nn.Parameter(torch.Tensor(out_features)) + self.b = nn.Parameter(torch.empty(out_features, dtype=dtype), requires_grad=True) else: self.register_parameter("b", None) self.reset_parameters() diff --git a/opendelta/delta_models/layers/low_rank_linear.py b/opendelta/delta_models/layers/low_rank_linear.py index bb2b25d..95c1466 100644 --- a/opendelta/delta_models/layers/low_rank_linear.py +++ b/opendelta/delta_models/layers/low_rank_linear.py @@ -6,17 +6,17 @@ from opendelta.delta_models.layers.init import glorot_uniform, glorot_normal class LowRankLinear(torch.nn.Module): def __init__(self, input_dim: int, output_dim: int, rank: int = 1, - bias: bool = True, w_init: str = "glorot-uniform"): + bias: bool = True, w_init: str = "glorot-uniform", dtype=torch.float): super(LowRankLinear, self).__init__() self.input_dim = input_dim self.output_dim = output_dim self.rank = rank self.bias = bias self.w_init = w_init - self.W_left = nn.Parameter(torch.Tensor(size=(input_dim, rank)), requires_grad=True) - self.W_right = nn.Parameter(torch.Tensor(size=(rank, output_dim)), requires_grad=True) + self.W_left = nn.Parameter(torch.empty((input_dim, rank), dtype=dtype),requires_grad=True) + self.W_right = nn.Parameter(torch.empty((rank, output_dim), dtype=dtype), requires_grad=True) if bias: - self.b = nn.Parameter(torch.Tensor(output_dim)) + self.b = nn.Parameter(torch.empty(output_dim, dtype=dtype)) self.reset_parameters() def reset_parameters(self): diff --git a/opendelta/delta_models/lora.py b/opendelta/delta_models/lora.py index 83bd65a..0446842 100644 --- a/opendelta/delta_models/lora.py +++ b/opendelta/delta_models/lora.py @@ -97,12 +97,14 @@ class LoraModel(DeltaBase): unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the prefix parameters. common_structure (:obj:`bool`): whether using name-based addressing with a common structure mapping. + backend (:obj:`str`): choose the backend of plm, 'hf' for huggingface transformers,'bmt' for bmtrain """ config_class = LoraConfig delta_type = "lora" default_modified_modules = ['attn@.q@', 'attn@.v@'] + _supported_backends = ['hf', 'bmt'] _need_pseudo_data = False def __init__(self, backbone_model: nn.Module, @@ -114,6 +116,7 @@ class LoraModel(DeltaBase): exclude_modules: Optional[List[str]] = None, common_structure: Optional[bool] = None, interactive_modify: Optional[Union[bool, int]] = False, + backend: Optional[str] = "hf", ): DeltaBase.__init__(self, backbone_model, @@ -121,6 +124,7 @@ class LoraModel(DeltaBase): unfrozen_modules=unfrozen_modules, common_structure=common_structure, interactive_modify=interactive_modify, + backend=backend, ) arg_names = get_arg_names_inside_func(self.__init__) for arg_name in arg_names: @@ -151,10 +155,9 @@ class LoraModel(DeltaBase): r=self.lora_r, lora_alpha=self.lora_alpha, lora_dropout=self.lora_dropout) - try: + if self.backend == "bmt": import bmtrain as bmt new_module = bmt.BMTrainModelWrapper(new_module) - except: - pass + self.delta_modules.append(new_module) return new_module diff --git a/opendelta/delta_models/low_rank_adapter.py b/opendelta/delta_models/low_rank_adapter.py index 5946331..eaef90a 100644 --- a/opendelta/delta_models/low_rank_adapter.py +++ b/opendelta/delta_models/low_rank_adapter.py @@ -47,7 +47,8 @@ class LowRankAdapter(nn.Module): non_linearity="gelu_new", low_rank_w_init="glorot-uniform", low_rank_rank=1, - device=None): + device=None, + backend='hf'): super().__init__() self.reduction_factor = reduction_factor self.non_linearity = non_linearity @@ -55,27 +56,31 @@ class LowRankAdapter(nn.Module): self.low_rank_rank = low_rank_rank self.device = device self.instantiated = False + self.backend=backend - def instantiate(self, hidden_dim): + def instantiate(self, hiddens): + self.hidden_dim = hiddens.shape[-1] + self.hidden_dtype = hiddens.dtype - self.down_sample_size = hidden_dim // self.reduction_factor + self.down_sample_size = self.hidden_dim // self.reduction_factor self.activation = Activations(self.non_linearity.lower()).to(self.device) - self.down_sampler = LowRankLinear(hidden_dim, self.down_sample_size, + self.down_sampler = LowRankLinear(self.hidden_dim, self.down_sample_size, w_init=self.low_rank_w_init, - rank=self.low_rank_rank).to(self.device) - self.up_sampler = LowRankLinear(self.down_sample_size, hidden_dim, + rank=self.low_rank_rank, + dtype=self.hidden_dtype).to(self.device) + self.up_sampler = LowRankLinear(self.down_sample_size, self.hidden_dim, w_init=self.low_rank_w_init, - rank=self.low_rank_rank).to(self.device) + rank=self.low_rank_rank, + dtype=self.hidden_dtype).to(self.device) self.instantiated = True - try: + if self.backend == 'bmt': import bmtrain as bmt self.activation = bmt.BMTrainModelWrapper(self.activation) self.down_sampler = bmt.BMTrainModelWrapper(self.down_sampler) self.up_sampler = bmt.BMTrainModelWrapper(self.up_sampler) - except: - pass + def post_forward(self, output): r""" Get the hidden_states from the PLM's layer output, pass it into the low-rank adapter, @@ -91,9 +96,7 @@ class LowRankAdapter(nn.Module): raise TypeError if not self.instantiated: - self.hidden_dim = hiddens.shape[-1] - logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}") - self.instantiate(hidden_dim=self.hidden_dim) + self.instantiate(hiddens = hiddens) z = self.down_sampler(hiddens) z = self.activation(z) @@ -154,6 +157,7 @@ class LowRankAdapterModel(DeltaBase): config_class = LowRankAdapterConfig delta_type = "low_rank_adapter" default_modified_modules = ["attn@.proj@", "ff@.w2@"] + _supported_backends = ['hf', 'bmt'] _need_pseudo_data = True def __init__(self, backbone_model: nn.Module, @@ -166,6 +170,7 @@ class LowRankAdapterModel(DeltaBase): unfrozen_modules: Optional[List[str]] = None, common_structure: Optional[bool] = None, interactive_modify: Optional[Union[bool, int]] = False, + backend: Optional[str] = 'hf', ): DeltaBase.__init__(self, backbone_model, @@ -174,6 +179,7 @@ class LowRankAdapterModel(DeltaBase): unfrozen_modules=unfrozen_modules, common_structure=common_structure, interactive_modify=interactive_modify, + backend=backend, ) arg_names = get_arg_names_inside_func(self.__init__) for arg_name in arg_names: @@ -209,6 +215,6 @@ class LowRankAdapterModel(DeltaBase): non_linearity = self.non_linearity, low_rank_w_init = self.low_rank_w_init, low_rank_rank = self.low_rank_rank, - device=module_device) + device=module_device, backend=self.backend) self.delta_modules.append(adapterlayer) return adapterlayer diff --git a/opendelta/delta_models/parallel_adapter.py b/opendelta/delta_models/parallel_adapter.py index 1024394..d354587 100644 --- a/opendelta/delta_models/parallel_adapter.py +++ b/opendelta/delta_models/parallel_adapter.py @@ -25,30 +25,36 @@ class ParallelAdapterLayer(nn.Module): def get_layer_count(cls): return cls.layer_count - def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', scaled=1, device=None): + def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', scaled=1, device=None, backend='hf'): super().__init__() self.bottleneck_dim = bottleneck_dim self.device = device self.instantiated = False self.non_linearity = non_linearity self.scaled = scaled + self.backend = backend self.layer_id = ParallelAdapterLayer.get_layer_count() ParallelAdapterLayer.count_layer() - def instantiate(self, hidden_dim): + def instantiate(self, hiddens): + self.hidden_dim = hiddens.shape[-1] + self.hidden_dtype = hiddens.dtype self.modulelist = nn.Sequential() - self.modulelist.add_module("down_proj",nn.Linear(hidden_dim, self.bottleneck_dim, device=self.device)) + self.modulelist.add_module("down_proj",nn.Linear(self.hidden_dim, self.bottleneck_dim, device=self.device, dtype=self.hidden_dtype)) # select non-linearity self.modulelist.add_module("non_linear", Activations(self.non_linearity.lower())) - self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device)) + self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device, dtype=self.hidden_dtype)) self.instantiated = True # initialize the weight, which is important for fast convergence and better performance. self.apply(self._init_weight) + if self.backend == 'bmt': + import bmtrain as bmt + self.modulelist = bmt.BMTrainModelWrapper(self.modulelist) def _init_weight(self, module): if isinstance(module, nn.Linear): @@ -71,9 +77,8 @@ class ParallelAdapterLayer(nn.Module): if not self.instantiated: - self.hidden_dim = hiddens.shape[-1] - logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}") - self.instantiate(hidden_dim=self.hidden_dim) + # logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}") + self.instantiate(hiddens = hiddens) self.adapter_output = self.modulelist(hiddens) * self.scaled @@ -141,12 +146,14 @@ class ParallelAdapterModel(DeltaBase): modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired and have the save order in layer. For examples, ["attn", "attn", "ff.w1", "ff.w2"] add one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output. unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the parallel adapter parameters. common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping. + backend (:obj:`str`): choose the backend of plm, 'hf' for huggingface transformers,'bmt' for bmtrain """ config_class = ParallelAdapterConfig delta_type = "parallel_adapter" default_modified_modules = ["attn@", "attn@", "ff@.w1@", "ff@.w2@"] # default_modified_modules = ["attn", "attn", "ff.w1", "ff.w2"] + _supported_backends = ['hf', 'bmt'] _need_pseudo_data = True def __init__(self, backbone_model: nn.Module, @@ -156,7 +163,8 @@ class ParallelAdapterModel(DeltaBase): exclude_modules: Optional[List[str]] = None, unfrozen_modules: Optional[bool] = None, common_structure: Optional[bool] = None, - interactive_modify: Optional[Union[bool, int]] = False, + interactive_modify: Optional[Union[bool, int]] = False, + backend: Optional[str] = "hf", ): DeltaBase.__init__(self, backbone_model, @@ -165,6 +173,7 @@ class ParallelAdapterModel(DeltaBase): unfrozen_modules=unfrozen_modules, common_structure=common_structure, interactive_modify=interactive_modify, + backend=backend, ) arg_names = get_arg_names_inside_func(self.__init__) for arg_name in arg_names: @@ -193,7 +202,7 @@ class ParallelAdapterModel(DeltaBase): def new_module_like(self, module): module_device = get_device(module) - adapterlayer = ParallelAdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device) + adapterlayer = ParallelAdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device, backend=self.backend) self.delta_modules.append(adapterlayer) return adapterlayer \ No newline at end of file diff --git a/opendelta/delta_models/prefix.py b/opendelta/delta_models/prefix.py index f64df2c..777a8e1 100644 --- a/opendelta/delta_models/prefix.py +++ b/opendelta/delta_models/prefix.py @@ -516,6 +516,7 @@ class PrefixModel(DeltaBase): config_class = PrefixConfig delta_type = "prefix" default_modified_modules = ['attn@'] + _supported_backends = ['hf'] _need_pseudo_data = True def __init__(self, backbone_model: nn.Module, diff --git a/opendelta/delta_models/soft_prompt.py b/opendelta/delta_models/soft_prompt.py index ff2346d..6453368 100644 --- a/opendelta/delta_models/soft_prompt.py +++ b/opendelta/delta_models/soft_prompt.py @@ -161,6 +161,7 @@ class SoftPromptModel(DeltaBase): config_class = SoftPromptConfig delta_type = "soft_prompt" default_modified_modules = ["root"] # not used + _supported_backends = ['hf'] #'bmt'] _need_pseudo_data = False def __init__(self, backbone_model: nn.Module, diff --git a/opendelta/utils/backend.py b/opendelta/utils/backend.py new file mode 100644 index 0000000..0b5b124 --- /dev/null +++ b/opendelta/utils/backend.py @@ -0,0 +1,110 @@ + + +import importlib + + +class BackendMapping: + """ + " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed. + + Args: + + - config_mapping: The map model type to config class + - model_mapping: The map model type to model (or tokenizer) class + """ + + def __init__(self, backend): + self.backend = backend + assert backend in ['hf', 'bmt'], "Backend should be one of 'hf', 'bmt'. " + if backend == 'hf': + self.backend_mapping = { + "linear": "torch.nn.Linear", + "layer_norm": "torch.nn.LayerNorm", + "module": "torch.nn.Module", + "parameter": "torch.nn.Parameter" + } + elif backend == 'bmt': + self.backend_mapping = { + "linear": "model_center.layer.Linear", + "layer_norm": "model_center.layer.LayerNorm", + "module": "bmtrain.layer.DistributedModule", + "parameter": "bmtrain.nn.DistributedParameter" + } + self.registered = {} + + def load(self, model_type): + if model_type not in self.registered: + splited = self.backend_mapping[model_type].split(".") + module_name, class_name = ".".join(splited[:-1]), splited[-1] + module = importlib.import_module(module_name) + the_class = getattr(module, class_name) + self.registered[model_type] = the_class + return self.registered[model_type] + + def check_type(self, module, expect_type): + the_class = self.load(expect_type) + if isinstance(module, the_class): + return True + else: + return False + + + # def keys(self): + # mapping_keys = [ + # self._load_attr_from_module(key, name) + # for key, name in self._config_mapping.items() + # if key in self._model_mapping.keys() + # ] + # return mapping_keys + list(self._extra_content.keys()) + + # def get(self, key, default): + # try: + # return self.__getitem__(key) + # except KeyError: + # return default + + # def __bool__(self): + # return bool(self.keys()) + + # def values(self): + # mapping_values = [ + # self._load_attr_from_module(key, name) + # for key, name in self._model_mapping.items() + # if key in self._config_mapping.keys() + # ] + # return mapping_values + list(self._extra_content.values()) + + # def items(self): + # mapping_items = [ + # ( + # self._load_attr_from_module(key, self._config_mapping[key]), + # self._load_attr_from_module(key, self._model_mapping[key]), + # ) + # for key in self._model_mapping.keys() + # if key in self._config_mapping.keys() + # ] + # return mapping_items + list(self._extra_content.items()) + + # def __iter__(self): + # return iter(self.keys()) + + # def __contains__(self, item): + # if item in self._extra_content: + # return True + # if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: + # return False + # model_type = self._reverse_config_mapping[item.__name__] + # return model_type in self._model_mapping + + # def register(self, key, value): + # """ + # Register a new model in this mapping. + # """ + # if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping: + # model_type = self._reverse_config_mapping[key.__name__] + # if model_type in self._model_mapping.keys(): + # raise ValueError(f"'{key}' is already used by a Transformers model.") + + # self._extra_content[key] = value + + diff --git a/opendelta/utils/inspect.py b/opendelta/utils/inspect.py new file mode 100644 index 0000000..830298e --- /dev/null +++ b/opendelta/utils/inspect.py @@ -0,0 +1,112 @@ + +import torch +import torch.nn as nn +from typing import Optional +import opendelta.utils.logging as logging + +logger = logging.get_logger(__name__) + + +def inspect_module_statistics(module: Optional[nn.Module]=None, verbose=True): + r"""Get the statistics of the parameters in the delta modules. + + Args: + module (:obj:`nn.Module`, *optional*): The module to compute the statistics. + + Returns: + :obj:`dict`: The statistics of the parameters in the delta modules. + + """ + + stat = {} + n_trainable = num_trainable_parameters(module) + n_total = num_total_parameters(module) + + stat['total_parameters'] = n_total + stat['trainable_parameters'] = n_trainable + + stat['trainable_ratio'] = n_trainable/n_total + + n_delta = num_delta_parameters(module) + n_total = num_total_parameters(module) + stat['delta_parameters'] = n_delta + stat['delta_ratio'] = n_delta/n_total + + cudamem = 0 + maxcudamem = 0 + for device_id in range(torch.cuda.device_count()): + cudamem += torch.cuda.memory_allocated(f"cuda:{device_id}")/1024**3 + maxcudamem += torch.cuda.max_memory_allocated(f"cuda:{device_id}")/1024**3 + stat['cudamem'] = cudamem + stat['maxcudamem'] = maxcudamem + + if verbose: + logger.info(stat) + + return stat + +def num_trainable_parameters(module: Optional[nn.Module]=None): + r"""[NODOC] A small sugar function to get the number of trainable parameter in the backbone model. Often used to + compute the trainable rate. + + Args: + module (:obj:`nn.Module`): of which module we want to know the number of trainable paramemters. + + Returns: + :obj:`List[nn.Parameter]` + """ + pnum_tot = 0 + for param in module.parameters(): + if param.requires_grad: + pnum_tot += param.numel() + return pnum_tot + + +def num_total_parameters(module: Optional[nn.Module]=None): + r"""[NODOC] A small sugar function to get the number of trainable parameter in the backbone model. Often used to + compute the trainable rate. + + Args: + module (:obj:`nn.Module`): of which module we want to know the number of trainable paramemters. + + Returns: + :obj:`List[nn.Parameter]` + """ + pnum_tot = 0 + for param in module.parameters(): + pnum_tot += param.numel() + return pnum_tot + +def num_delta_parameters(module: Optional[nn.Module]=None): + r"""[NODOC] A small sugar function to get the number of trainable parameter in the backbone model. Often used to + compute the trainable rate. + + Args: + module (:obj:`nn.Module`): of which module we want to know the number of trainable paramemters. + + Returns: + :obj:`List[nn.Parameter]` + """ + pnum_tot = 0 + for param in module.parameters(): + if hasattr(param, "_is_delta"): + pnum_tot += param.numel() + return pnum_tot + +def inspect_optimizer_statistics(optimizer, verbose=True): + stats = {} + for id, param_group in enumerate(optimizer.param_groups): + stat = {} + fine_grain_info = [(p.numel(), p.requires_grad) for p in param_group['params']] + stat['total_parameters'] = sum(n for n, r in fine_grain_info) + stat['trainable_parameters'] = sum(n for n, r in fine_grain_info if r) + stat['trainable_ratio'] = "{:.6f}%".format(stat['trainable_parameters']/stat['total_parameters']*100) + for key in param_group: + if key != 'params': + stat[key] = param_group[key] + stats[f'param_group_{id}'] = stat + + if verbose: + logger.info(f"optimizer info: {stats}") + + return stat diff --git a/setup.py b/setup.py index 27c0313..711f5a4 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ def get_requirements(): with open('README.md', 'r') as f: setuptools.setup( name = 'opendelta', - version = "0.3.1", + version = "0.3.2", description = "An open source framework for delta learning (parameter efficient learning).", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown",