support bmtrain

This commit is contained in:
shengdinghu 2022-10-23 08:42:21 +00:00
parent e0de6b02ad
commit ce7f992864
24 changed files with 734 additions and 485 deletions

5
.gitignore vendored
View File

@ -54,8 +54,13 @@ t.sh
**/delta_checkpoints/
**/outputs/
dist/*
**/unittest/**
!unittest/**.py
!unittest/**.sh
**/tutorial/**
!tutorial/**.py
!tutorial/**.sh
!tutorial/**.md

View File

@ -31,8 +31,8 @@ copyright = '{}, {}, Licenced under the Apache License, Version 2.0'.format(date
# The full version, including alpha/beta/rc tags
release = '0.3.1'
version = "0.3.1"
release = '0.3.2'
version = "0.3.2"
html_theme = 'sphinx_rtd_theme'
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]

View File

@ -1,6 +1,14 @@
(acceleration)=
# OpenDelta+
<img src="../imgs/todo-icon.jpeg" height="30px"> We are working on testing and improving the functionality with work with other acceleration packages for model training and inference. For example, [deepspeed](https://github.com/microsoft/DeepSpeed), [BMInf](https://github.com/OpenBMB/BMInf).
Feel free to contact us via email (shengdinghu@gmail.com) if you have any suggestion.
## BMTrain
- [BMTrain](https://github.com/OpenBMB/BMTrain) is an efficient large model training toolkit that can be used to train large models with tens of billions of parameters. It can train models in a distributed manner while keeping the code as simple as stand-alone training.
- [ModelCenter](https://github.com/OpenBMB/ModelCenter) implements pre-trained language models (PLMs) based on the backend OpenBMB/BMTrain. ModelCenter supports Efficient, Low-Resource, Extendable model usage and distributed training.
Now we have the LoraModel, AdapterModel, CompacterModel, ParallelAdapterModel, LowRankAdapterModel fully supported the distributed training with BMTrain and ModelCenter. Please try is out in
## Huggingface Accelerate
<img src="../imgs/todo-icon.jpeg" height="30px">

View File

@ -1,5 +1,10 @@
# Update Logs and Known Issues
## Version 0.3.2
- We support BMTrain to accelerate the training, and parallelize the training of models that are hard to fit in a single GPU. Check [tutorial/2_with_bmtrain.py](https://github.com/thunlp/OpenDelta/tree/main/examples/tutorial/2_with_bmtrain.py)
- We add a functionality to [inspect the optimizer](https://github.com/thunlp/OpenDelta/tree/main/opendelta/utils/inspect.py). The user can see the number of trainable parameters in the optimizer and verify that opendelta is being used correctly.
- We move the functions to inspect the delta models into [inspect.py](https://github.com/thunlp/OpenDelta/tree/main/opendelta/utils/inspect.py)
## Version 0.3.1
- We update [must_try.py](https://github.com/thunlp/OpenDelta/tree/main/examples/unittest/must_try.py) for a simple introduction of the core functionality of OpenDelta.
- Thanks to [Weilin Zhao](https://github.com/Achazwl) We merge a long-developed branch parallel_adapter into the main branch.

View File

@ -1,50 +1,291 @@
import bmtrain as bmt
import opendelta as od
from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel
# adapted from https://github.com/OpenBMB/ModelCenter/blob/main/examples/bert/finetune_bert.py
import time
import os
import torch
import numpy
import random
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, f1_score
def manual_seed(seed):
torch.manual_seed(seed)
numpy.random.seed(seed)
random.seed(seed)
import bmtrain as bmt
from model_center.model import Bert, BertConfig
bmt.init_distributed()
config = BertConfig.from_pretrained("bert-base-uncased")
config.dropout_p = 0
model = Bert.from_pretrained("bert-base-uncased", config)
from model_center import get_args
from model_center.model import Bert
from model_center.tokenizer import BertTokenizer
from model_center.dataset.bertdataset import DATASET
from model_center.utils import print_inspect
from model_center.layer import Linear
from model_center.dataset import DistributedDataLoader
import opendelta as od
from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel, ParallelAdapterModel
from opendelta.utils.inspect import inspect_optimizer_statistics
print("before modify")
od.Visualization(model).structure_graph()
manual_seed(233)
delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'])
# delta_model = AdapterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn'])
# delta_model = CompacterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn'])
# delta_model = LowRankAdapterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn'])
# delta_model = BitFitModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn', '[r](.*)layernorm(.*)'])
class BertModel(torch.nn.Module):
def __init__(self, args, num_types):
super().__init__()
self.bert : Bert = Bert.from_pretrained(args.model_config)
dim_model = self.bert.input_embedding.dim_model
self.dense = Linear(dim_model, num_types)
bmt.init_parameters(self.dense)
# print(delta_model.delta_modules)
def forward(self, *args, **kwargs):
pooler_output = self.bert(*args, **kwargs, output_pooler_output=True).pooler_output
logits = self.dense(pooler_output)
return logits
print("after modify")
delta_model.log()
# This will visualize the backbone after modification and other information.
def get_tokenizer(args):
tokenizer = BertTokenizer.from_pretrained(args.model_config)
return tokenizer
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
print("after freeze")
delta_model.log()
# The set_state_dict=True will tell the method to change the state_dict of the backbone_model to maintaining only the trainable parts.
def get_model(args):
num_types = {
"BoolQ" : 2,
"CB" : 3,
"COPA" : 1,
"RTE" : 2,
"WiC" : 2,
}
model = BertModel(args, num_types[args.dataset_name])
od.Visualization(model).structure_graph()
manual_seed(233)
inp = torch.randint(0, 30000, (32, 128)).cuda()
length = torch.randint(0, 128, (32,)).cuda()
attention_mask = (torch.arange(inp.shape[1], device=inp.device)[None, :].repeat(inp.shape[0], 1) < length[:, None])
out = model(inp, attention_mask=attention_mask, output_logits=True).logits
print(out)
if bmt.rank() == 0:
torch.save(model.state_dict(), "test.pt")
ckpt = torch.load("test.pt")
print(ckpt.keys())
if args.delta_type == "lora":
delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'], backend='bmt')
elif args.delta_type == "bitfit":
delta_model = BitFitModel(backbone_model=model, modified_modules=['self_att', 'ffn', 'layernorm'], backend='bmt') #TODO: fix bug
elif args.delta_type == "adapter":
delta_model = AdapterModel(backbone_model=model, modified_modules=['self_att', 'ffn'], backend='bmt')
elif args.delta_type == "compacter":
delta_model = CompacterModel(backbone_model=model, modified_modules=['self_att', 'ffn'], backend='bmt')
elif args.delta_type == "low_rank_adapter":
delta_model = LowRankAdapterModel(backbone_model=model, modified_modules=['self_att', 'ffn'], backend='bmt')
elif args.delta_type == "parallel_adapter":
delta_model = ParallelAdapterModel(backbone_model=model, modified_modules=['self_att', 'self_att', 'ffn.ffn', 'ffn.ffn'], backend='bmt')
print("after modify")
delta_model.log()
# This will visualize the backbone after modification and other information.
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
print("after freeze")
delta_model.log()
return model
def get_optimizer(args, model):
optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=args.weight_decay)
return optimizer
def get_learning_rate_scheduler(args, optimizer):
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters * args.epochs
if args.lr_decay_style == "noam":
lr_scheduler = bmt.lr_scheduler.Noam(optimizer,
start_lr = args.lr,
warmup_iter = args.warmup_iters,
end_iter = args.lr_decay_iters,
num_iter = args.start_step)
elif args.lr_decay_style == "constant":
lr_scheduler = bmt.lr_scheduler.NoDecay(optimizer,
start_lr = args.lr,
warmup_iter = args.warmup_iters,
end_iter = -1,
num_iter = args.start_step)
elif args.lr_decay_style == "linear":
lr_scheduler = bmt.lr_scheduler.Linear(optimizer,
start_lr = args.lr,
warmup_iter = args.warmup_iters,
end_iter = args.lr_decay_iters,
num_iter = args.start_step)
elif args.lr_decay_style == "exponential":
lr_scheduler = bmt.lr_scheduler.Exponential(optimizer,
start_lr = args.lr,
warmup_iter = args.warmup_iters,
end_iter = args.lr_decay_iters,
num_iter = args.start_step)
elif args.lr_decay_style == "cosine":
lr_scheduler = bmt.lr_scheduler.Cosine(optimizer,
start_lr = args.lr,
warmup_iter = args.warmup_iters,
end_iter = args.lr_decay_iters,
num_iter = args.start_step)
else:
raise ValueError(f"lr_scheduler of type {args.lr_decay_style} is not supported yet.")
return lr_scheduler
def setup_model_and_optimizer(args):
# get the tokenizer
tokenizer = get_tokenizer(args)
# get the model
model = get_model(args)
bmt.synchronize()
# get the optimizer and lr_scheduler
optimizer = get_optimizer(args, model)
inspect_optimizer_statistics(optimizer)
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
bmt.synchronize()
# get the memory usage
bmt.print_rank("Model mem\n", torch.cuda.memory_summary())
bmt.synchronize()
return tokenizer, model, optimizer, lr_scheduler
def initialize():
# get arguments
args = get_args()
# init bmt
bmt.init_distributed(seed = args.seed)
# init save folder
if args.save != None:
os.makedirs(args.save, exist_ok=True)
return args
def prepare_dataset(args, tokenizer, base_path, dataset_name, rank, world_size):
splits = ['train', 'dev', 'test']
dataset = {}
for split in splits:
dataset[split] = DATASET[dataset_name](base_path, split, rank, world_size, tokenizer, args.max_encoder_length)
return dataset
def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset):
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale)
optim_manager.add_optimizer(optimizer, lr_scheduler)
# print_inspect(model, '*') # too much output
for epoch in range(12):
dataloader = {
"train": DistributedDataLoader(dataset['train'], batch_size=args.batch_size, shuffle=True),
"dev": DistributedDataLoader(dataset['dev'], batch_size=args.batch_size, shuffle=False),
}
model.train()
for it, data in enumerate(dataloader['train']):
if args.dataset_name == 'COPA':
input_ids0 = data["input_ids0"]
attention_mask0 = data["attention_mask0"]
token_type_ids0 = data["token_type_ids0"]
input_ids1 = data["input_ids1"]
attention_mask1 = data["attention_mask1"]
token_type_ids1 = data["token_type_ids1"]
labels = data["labels"]
else:
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
token_type_ids = data["token_type_ids"]
labels = data["labels"]
torch.cuda.synchronize()
st_time = time.time()
if args.dataset_name == 'COPA':
logits = torch.cat([
model(input_ids0, attention_mask=attention_mask0, token_type_ids=token_type_ids0),
model(input_ids1, attention_mask=attention_mask1, token_type_ids=token_type_ids1),
], dim=1)
else:
logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1))
global_loss = bmt.sum_loss(loss).item()
optim_manager.zero_grad()
optim_manager.backward(loss)
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type = 2)
optim_manager.step()
torch.cuda.synchronize()
elapsed_time = time.time() - st_time
# from IPython import embed; embed(header="25252")
bmt.print_rank(
"train | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.4e}, scale: {:10.4f} | grad_norm: {:.4f} | time: {:.3f}".format(
epoch,
it,
len(dataloader["train"]),
global_loss,
lr_scheduler.current_lr,
int(optim_manager.loss_scale),
grad_norm,
elapsed_time,
)
)
model.eval()
with torch.no_grad():
for split in ['dev']:
pd = []
gt = []
for it, data in enumerate(dataloader[split]):
if args.dataset_name == 'COPA':
input_ids0 = data["input_ids0"]
attention_mask0 = data["attention_mask0"]
token_type_ids0 = data["token_type_ids0"]
input_ids1 = data["input_ids1"]
attention_mask1 = data["attention_mask1"]
token_type_ids1 = data["token_type_ids1"]
labels = data["labels"]
logits = torch.cat([
model(input_ids0, attention_mask=attention_mask0, token_type_ids=token_type_ids0),
model(input_ids1, attention_mask=attention_mask1, token_type_ids=token_type_ids1),
], dim=1)
else:
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
token_type_ids = data["token_type_ids"]
labels = data["labels"]
logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1))
logits = logits.argmax(dim=-1)
pd.extend(logits.cpu().tolist())
gt.extend(labels.cpu().tolist())
bmt.print_rank(
"{} | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f}".format(
split,
epoch,
it,
len(dataloader[split]),
loss,
)
)
pd = bmt.gather_result(torch.tensor(pd).int()).cpu().tolist()
gt = bmt.gather_result(torch.tensor(gt).int()).cpu().tolist()
bmt.print_rank(f"{split} epoch {epoch}:")
if args.dataset_name in ["BoolQ", "CB", "COPA", "RTE", "WiC", "WSC"]:
acc = accuracy_score(gt, pd)
bmt.print_rank(f"accuracy: {acc*100:.2f}")
if args.dataset_name in ["CB"]:
rcl = f1_score(gt, pd, average="macro")
f1 = recall_score(gt, pd, average="macro")
bmt.print_rank(f"recall: {rcl*100:.2f}")
bmt.print_rank(f"Average F1: {f1*100:.2f}")
def main():
args = initialize()
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
dataset = prepare_dataset(
args,
tokenizer,
f"{args.base_path}/down_data/superglue/",
args.dataset_name,
bmt.rank(), bmt.world_size(),
)
finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset)
if __name__ == "__main__":
main()

View File

@ -1 +1,37 @@
python3 -m torch.distributed.launch --master_addr localhost --master_port 34123 --nproc_per_node $1 --nnodes 1 --node_rank 0 2_with_bmtrain.py
#! /bin/bash
MASTER_ADDR=localhost
MASTER_PORT=12345
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=4
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"
BASE_PATH="./"
VERSION="bert-large-cased"
DATASET="BoolQ" # You can try other dataset listed in https://github.com/OpenBMB/ModelCenter/tree/main/examples/bert
OPTS=""
OPTS+=" --model-config ${VERSION}"
OPTS+=" --base-path ${BASE_PATH}"
OPTS+=" --dataset_name ${DATASET}"
OPTS+=" --batch-size 64"
OPTS+=" --lr 0.001" # You can use different learning rate to find optimal performance
OPTS+=" --max-encoder-length 512"
OPTS+=" --train-iters 1400"
OPTS+=" --lr-decay-style constant"
OPTS+=" --weight-decay 1e-2"
OPTS+=" --clip-grad 10.0"
OPTS+=" --loss-scale 128"
OPTS+=" --delta_type low_rank_adapter" # You can use different delta type, listed in https://opendelta.readthedocs.io/en/latest/notes/acceleration.html#BMTrain
CMD="python3 -m torch.distributed.launch ${DISTRIBUTED_ARGS} ${BASE_PATH}2_with_bmtrain.py ${OPTS}"
echo ${CMD}
${CMD} 2>&1 | tee ${BASE_PATH}/tmp/logs/bmt_bert_boolq_finetune-${VERSION}-${DATASET}.log

View File

@ -13,4 +13,23 @@ requirement:
```
pip install openprompt
```
```
## 2_with_bmtrain.py
1. install necessary packages:
```
pip install git+https://github.com/OpenBMB/BMTrain.git
pip install git+git@github.com:OpenBMB/ModelCenter.git
```
2. download dataset from https://super.gluebenchmark.com/tasks, e.g.,
```
mkdir down_data
cd down_data
wget https://dl.fbaipublicfiles.com/glue/superglue/data/v2/BoolQ.zip
unzip BoolQ.zip
```
3. Run the shell scripts, change `NNODES`,`GPUS_PER_NODE` according to your computational resources.
```
bash 2_with_bmtrain.sh
```

View File

@ -67,6 +67,16 @@ delta2.detach()
# say we add lora to the last four layer of the decoder of t5, with lora rank=5
delta_config3 = AutoDeltaConfig.from_dict({"delta_type":"lora", "modified_modules":["[r]decoder.*((20)|(21)|(22)|(23)).*DenseReluDense\.wi"], "lora_r":5})
delta3 = AutoDeltaModel.from_config(delta_config3, backbone_model=wrapped_model)
delta3.freeze_module()
delta3.log()
# add optimizer as normal
from transformers import AdamW
optimizer = AdamW(wrapped_model.parameters(), lr=3e-3)
# inspect_optimizer
from opendelta.utils.inspect import inspect_optimizer_statistics
inspect_optimizer_statistics(optimizer)

View File

@ -1,255 +0,0 @@
import time
import os
import torch
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, f1_score
import bmtrain as bmt
from model_center import get_args
from model_center.model import Bert
from model_center.tokenizer import BertTokenizer
from model_center.dataset.bertdataset import DATASET
from model_center.utils import print_inspect
from model_center.layer import Linear
from model_center.dataset import DistributedDataLoader
class BertModel(torch.nn.Module):
def __init__(self, args, num_types):
super().__init__()
self.bert : Bert = Bert.from_pretrained(args.model_config)
dim_model = self.bert.input_embedding.dim_model
self.dense = Linear(dim_model, num_types)
bmt.init_parameters(self.dense)
def forward(self, *args, **kwargs):
pooler_output = self.bert(*args, **kwargs, output_pooler_output=True).pooler_output
logits = self.dense(pooler_output)
return logits
def get_tokenizer(args):
tokenizer = BertTokenizer.from_pretrained(args.model_config)
return tokenizer
def get_model(args):
num_types = {
"BoolQ" : 2,
"CB" : 3,
"COPA" : 1,
"RTE" : 2,
"WiC" : 2,
}
model = BertModel(args, num_types[args.dataset_name])
return model
def get_optimizer(args, model):
optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=args.weight_decay)
return optimizer
def get_learning_rate_scheduler(args, optimizer):
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters * args.epochs
if args.lr_decay_style == "noam":
lr_scheduler = bmt.lr_scheduler.Noam(optimizer,
start_lr = args.lr,
warmup_iter = args.warmup_iters,
end_iter = args.lr_decay_iters,
num_iter = args.start_step)
elif args.lr_decay_style == "constant":
lr_scheduler = bmt.lr_scheduler.NoDecay(optimizer,
start_lr = args.lr,
warmup_iter = args.warmup_iters,
end_iter = -1,
num_iter = args.start_step)
elif args.lr_decay_style == "linear":
lr_scheduler = bmt.lr_scheduler.Linear(optimizer,
start_lr = args.lr,
warmup_iter = args.warmup_iters,
end_iter = args.lr_decay_iters,
num_iter = args.start_step)
elif args.lr_decay_style == "exponential":
lr_scheduler = bmt.lr_scheduler.Exponential(optimizer,
start_lr = args.lr,
warmup_iter = args.warmup_iters,
end_iter = args.lr_decay_iters,
num_iter = args.start_step)
elif args.lr_decay_style == "cosine":
lr_scheduler = bmt.lr_scheduler.Cosine(optimizer,
start_lr = args.lr,
warmup_iter = args.warmup_iters,
end_iter = args.lr_decay_iters,
num_iter = args.start_step)
else:
raise ValueError(f"lr_scheduler of type {args.lr_decay_style} is not supported yet.")
return lr_scheduler
def setup_model_and_optimizer(args):
# get the tokenizer
tokenizer = get_tokenizer(args)
# get the model
model = get_model(args)
bmt.synchronize()
# get the optimizer and lr_scheduler
optimizer = get_optimizer(args, model)
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
bmt.synchronize()
# get the memory usage
bmt.print_rank("Model mem\n", torch.cuda.memory_summary())
bmt.synchronize()
return tokenizer, model, optimizer, lr_scheduler
def initialize():
# get arguments
args = get_args()
# init bmt
bmt.init_distributed(seed = args.seed)
# init save folder
if args.save != None:
os.makedirs(args.save, exist_ok=True)
return args
def prepare_dataset(args, tokenizer, base_path, dataset_name, rank, world_size):
splits = ['train', 'dev', 'test']
dataset = {}
for split in splits:
dataset[split] = DATASET[dataset_name](base_path, split, rank, world_size, tokenizer, args.max_encoder_length)
return dataset
def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset):
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale)
optim_manager.add_optimizer(optimizer, lr_scheduler)
print_inspect(model, '*')
for epoch in range(12):
dataloader = {
"train": DistributedDataLoader(dataset['train'], batch_size=args.batch_size, shuffle=True),
"dev": DistributedDataLoader(dataset['dev'], batch_size=args.batch_size, shuffle=False),
}
model.train()
for it, data in enumerate(dataloader['train']):
if args.dataset_name == 'COPA':
input_ids0 = data["input_ids0"]
attention_mask0 = data["attention_mask0"]
token_type_ids0 = data["token_type_ids0"]
input_ids1 = data["input_ids1"]
attention_mask1 = data["attention_mask1"]
token_type_ids1 = data["token_type_ids1"]
labels = data["labels"]
else:
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
token_type_ids = data["token_type_ids"]
labels = data["labels"]
torch.cuda.synchronize()
st_time = time.time()
if args.dataset_name == 'COPA':
logits = torch.cat([
model(input_ids0, attention_mask=attention_mask0, token_type_ids=token_type_ids0),
model(input_ids1, attention_mask=attention_mask1, token_type_ids=token_type_ids1),
], dim=1)
else:
logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1))
global_loss = bmt.sum_loss(loss).item()
optim_manager.zero_grad()
optim_manager.backward(loss)
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type = 2)
optim_manager.step()
torch.cuda.synchronize()
elapsed_time = time.time() - st_time
bmt.print_rank(
"train | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.4e}, scale: {:10.4f} | grad_norm: {:.4f} | time: {:.3f}".format(
epoch,
it,
len(dataloader["train"]),
global_loss,
lr_scheduler.current_lr,
int(optim_manager.loss_scale),
grad_norm,
elapsed_time,
)
)
model.eval()
with torch.no_grad():
for split in ['dev']:
pd = []
gt = []
for it, data in enumerate(dataloader[split]):
if args.dataset_name == 'COPA':
input_ids0 = data["input_ids0"]
attention_mask0 = data["attention_mask0"]
token_type_ids0 = data["token_type_ids0"]
input_ids1 = data["input_ids1"]
attention_mask1 = data["attention_mask1"]
token_type_ids1 = data["token_type_ids1"]
labels = data["labels"]
logits = torch.cat([
model(input_ids0, attention_mask=attention_mask0, token_type_ids=token_type_ids0),
model(input_ids1, attention_mask=attention_mask1, token_type_ids=token_type_ids1),
], dim=1)
else:
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
token_type_ids = data["token_type_ids"]
labels = data["labels"]
logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1))
logits = logits.argmax(dim=-1)
pd.extend(logits.cpu().tolist())
gt.extend(labels.cpu().tolist())
bmt.print_rank(
"{} | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f}".format(
split,
epoch,
it,
len(dataloader[split]),
loss,
)
)
pd = bmt.gather_result(torch.tensor(pd).int()).cpu().tolist()
gt = bmt.gather_result(torch.tensor(gt).int()).cpu().tolist()
bmt.print_rank(f"{split} epoch {epoch}:")
if args.dataset_name in ["BoolQ", "CB", "COPA", "RTE", "WiC", "WSC"]:
acc = accuracy_score(gt, pd)
bmt.print_rank(f"accuracy: {acc*100:.2f}")
if args.dataset_name in ["CB"]:
rcl = f1_score(gt, pd, average="macro")
f1 = recall_score(gt, pd, average="macro")
bmt.print_rank(f"recall: {rcl*100:.2f}")
bmt.print_rank(f"Average F1: {f1*100:.2f}")
def main():
args = initialize()
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
dataset = prepare_dataset(
args,
tokenizer,
f"{args.base_path}/down_data/superglue/",
args.dataset_name,
bmt.rank(), bmt.world_size(),
)
finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset)
if __name__ == "__main__":
main()

View File

@ -21,6 +21,7 @@ from .delta_models.adapter import AdapterModel
from .delta_models.prefix import PrefixModel
from .delta_models.soft_prompt import SoftPromptModel
from .delta_models.low_rank_adapter import LowRankAdapterModel
from .delta_models.parallel_adapter import ParallelAdapterModel
from .utils.visualization import Visualization

View File

@ -5,6 +5,7 @@ from multiprocessing.sharedctypes import Value
import os
from turtle import back
from opendelta.delta_configs import BaseDeltaConfig
from opendelta.utils.inspect import inspect_module_statistics
from opendelta.utils.model_md5 import gen_model_hash
from opendelta.utils.signature import get_arg_names, signature
from typing import Optional, Union
@ -27,6 +28,7 @@ from opendelta.utils.cuda import move_dict_to_cuda
import sys
from opendelta.utils.data_parallel import caller_map
from opendelta.utils.backend import BackendMapping
logger = logging.get_logger(__name__)
def is_leaf_module(module):
@ -94,6 +96,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
config_class = BaseDeltaConfig
default_unfrozen_modules = ["deltas"]
_need_pseudo_data = True
_supported_backends = ['hf']
def __init__(self,
backbone_model: nn.Module,
modified_modules: Optional[List[str]] = None,
@ -101,7 +104,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
unfrozen_modules: Optional[List[str]] = None,
interactive_modify: Optional[Union[bool, int]] = False,
common_structure: Optional[bool] = False,
framework_type: Optional[str]= "hf", # select from ["hf", "bmt"]
backend: Optional[str]= "hf", # select from ["hf", "bmt"]
):
nn.Module.__init__(self)
# register the backbone model after init using self.__dict__ method to avoid adding backbone_model
@ -139,7 +142,10 @@ class DeltaBase(nn.Module, SaveLoadMixin):
self.unfrozen_modules = self.default_unfrozen_modules
if self.common_structure and self.structure_mapping is None:
raise RuntimeError("Using common structure but the structure mapping is None")
self.framework_type = framework_type
if backend not in self._supported_backends:
raise RuntimeError("Currently, backend `{}` is not supported for `{}`".format(backend, self.__class__.__name__))
self.backend = backend
self.backend_mapping = BackendMapping(backend)
def forward(self, *args, **kwargs) -> RuntimeError:
r"""
@ -371,10 +377,11 @@ class DeltaBase(nn.Module, SaveLoadMixin):
_auto_dummy_fail = False
try:
module(**dummy_inputs)
except:
except Exception as e:
_auto_dummy_fail = True
if _auto_dummy_fail:
raise AttributeError(f"\n\tThe {self.__class__.__name__} requires a dummy_inputs to be passed through the model to understand the dimensionality of each tensor in the computation graph. \n\t The {module.__class__.__name__} Class has no dummy_inputs, and automatically created dummy_inputs failed.\n\t Refer to `https://opendelta.readthedocs.io/en/latest/notes/faq.html` for detail.")
if _auto_dummy_fail and _auto_dummy:
raise AttributeError(f"str({e})\n\tThe {self.__class__.__name__} requires a dummy_inputs to be passed through the model to understand the dimensionality of each tensor in the computation graph. \n\t The {module.__class__.__name__} Class has no dummy_inputs, and automatically created dummy_inputs failed.\n\t Refer to `https://opendelta.readthedocs.io/en/latest/notes/faq.html` for detail.")
@ -684,65 +691,16 @@ class DeltaBase(nn.Module, SaveLoadMixin):
from opendelta import Visualization
Visualization(module).structure_graph()
self.get_statistics(module)
self.stat = inspect_module_statistics(module, verbose=False)
if trainable_ratio:
logger.info("Trainable Ratio: {:2f}%".format(self.stat['trainable_ratio']*100))
logger.info("Trainable Ratio: {}/{}={:.6f}%".format(self.stat['trainable_parameters'], self.stat['total_parameters'], self.stat['trainable_ratio']*100))
if delta_ratio:
logger.info("Delta Parameter Ratio: {:2f}%".format(self.stat['delta_ratio']*100))
logger.info("Delta Parameter Ratio: {}/{}={:.6f}%".format(self.stat['delta_parameters'], self.stat['total_parameters'],self.stat['delta_ratio']*100))
if cuda_memory:
logger.info("Static Memory {:.2f} GB, Max Memory {:.2f} GB".format(self.stat['cudamem'], self.stat['maxcudamem']))
def get_statistics(self, module=None):
r"""Get the statistics of the parameters in the delta modules.
Args:
module (:obj:`nn.Module`, *optional*): The module to compute the statistics.
Returns:
:obj:`dict`: The statistics of the parameters in the delta modules.
"""
if module is None:
module = self.backbone_model
self.stat = {}
n_trainable = self.num_trainable_parameters(module)
n_total = self.num_total_parameters(module)
self.stat['trainable_ratio'] = n_trainable/n_total
n_delta = self.num_delta_parameters(module)
n_total = self.num_total_parameters(module)
self.stat['delta_ratio'] = n_delta/n_total
cudamem = 0
maxcudamem = 0
for device_id in range(torch.cuda.device_count()):
cudamem += torch.cuda.memory_allocated(f"cuda:{device_id}")/1024**3
maxcudamem += torch.cuda.max_memory_allocated(f"cuda:{device_id}")/1024**3
self.stat['cudamem'] = cudamem
self.stat['maxcudamem'] = maxcudamem
def num_delta_parameters(self, module: Optional[nn.Module]=None):
r"""[NODOC] A small sugar function to get the number of trainable parameter in the backbone model. Often used to
compute the trainable rate.
Args:
module (:obj:`nn.Module`): of which module we want to know the number of trainable paramemters.
Returns:
:obj:`List[nn.Parameter]`
"""
if module is None:
module = self.backbone_model
pnum_tot = 0
for param in module.parameters():
if hasattr(param, "_is_delta"):
pnum_tot += param.numel()
return pnum_tot
# Two functions for plug and remove the delta model.
def attach(self, module: Optional[nn.Module]=None, reset_state_dict=True):

View File

@ -60,13 +60,14 @@ class AdapterLayer(nn.Module, InterFaceMixin):
def get_layer_count(cls):
return cls.layer_count
def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', device=None):
def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', device=None, backend="hf"):
super().__init__()
InterFaceMixin.__init__(self)
self.bottleneck_dim = bottleneck_dim
self.init_device = device
self.instantiated = False
self.non_linearity = non_linearity
self.backend=backend
self.layer_id = AdapterLayer.get_layer_count()
AdapterLayer.count_layer()
@ -79,14 +80,16 @@ class AdapterLayer(nn.Module, InterFaceMixin):
else:
return self.init_device
def instantiate(self, hidden_dim):
def instantiate(self, hiddens):
self.hidden_dim = hiddens.shape[-1]
self.hidden_dtype = hiddens.dtype
self.modulelist = nn.Sequential()
self.modulelist.add_module("down_proj",nn.Linear(hidden_dim, self.bottleneck_dim, device=self.init_device))
self.modulelist.add_module("down_proj",nn.Linear(self.hidden_dim, self.bottleneck_dim, device=self.init_device, dtype=self.hidden_dtype))
# select non-linearity
self.modulelist.add_module("non_linear", Activations(self.non_linearity.lower()))
self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.init_device))
self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.init_device, dtype=self.hidden_dtype))
# TODO:
# If we want to have a layer norm on output, we apply it later after a separate residual connection
@ -97,11 +100,9 @@ class AdapterLayer(nn.Module, InterFaceMixin):
self.instantiated = True
# initialize the weight, which is important for fast convergence and better performance.
self.apply(self._init_weight)
try:
if self.backend == 'bmt':
import bmtrain as bmt
self.modulelist = bmt.BMTrainModelWrapper(self.modulelist)
except:
pass
def _init_weight(self, module):
if isinstance(module, nn.Linear):
@ -123,19 +124,29 @@ class AdapterLayer(nn.Module, InterFaceMixin):
raise TypeError
hiddens = self._transpose(hiddens)
hiddens = self._convert_data_type(hiddens)
# if self.backend == 'hf':
# hiddens = self._convert_data_type(hiddens)
# elif self.backend == 'bmt': # if bmt, left the convertion to bmt
# pass
if not self.instantiated:
self.hidden_dim = hiddens.shape[-1]
logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
self.instantiate(hidden_dim=self.hidden_dim)
# self.hidden_dim = hiddens.shape[-1]
# logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
self.instantiate(hiddens=hiddens)
# from IPython import embed; embed(header="14135315")
adapter_output = self.modulelist(hiddens)
modified_output = adapter_output + hiddens # TODO option: disable residual_connection
modified_output = self._reverse_transpose(modified_output)
modified_output = self._reverse_data_type(modified_output)
# if self.backend == 'hf':
# # print("!"*100)
# modified_output = self._reverse_data_type(modified_output)
# elif self.backend == 'bmt': # if bmt, left the convertion to bmt
# print("!"*100)
# pass
if isinstance(output, tuple):
output = (modified_output,) + output[1:]
@ -189,20 +200,24 @@ class AdapterModel(DeltaBase):
modified_modules (:obj:`List[str]`): modules to add adapter after them.
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the adapter parameters.
common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping.
backend (:obj:`str`): choose the backend of plm, 'hf' for huggingface transformers,'bmt' for bmtrain.
"""
config_class = AdapterConfig
delta_type = "adapter"
default_modified_modules = ["attn@.proj@", "ff@.w2@"]
_supported_backends = ['hf', 'bmt']
_need_pseudo_data = True
def __init__(self,
backbone_model: nn.Module,
bottleneck_dim: Optional[int]=24,
non_linearity: Optional[str]='gelu_new',
modified_modules: Optional[bool] = None,
modified_modules: Optional[List[str]] = None,
exclude_modules: Optional[List[str]] = None,
unfrozen_modules: Optional[bool] = None,
common_structure: Optional[bool] = None,
interactive_modify: Optional[Union[bool, int]] = False,
backend: Optional[str] = 'hf',
):
DeltaBase.__init__(self,
backbone_model,
@ -211,6 +226,7 @@ class AdapterModel(DeltaBase):
unfrozen_modules=unfrozen_modules,
common_structure=common_structure,
interactive_modify=interactive_modify,
backend=backend,
)
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
@ -231,6 +247,6 @@ class AdapterModel(DeltaBase):
def new_module_like(self, module):
module_device = get_device(module)
adapterlayer = AdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device)
adapterlayer = AdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device, backend=self.backend)
self.delta_modules.append(adapterlayer)
return adapterlayer

View File

@ -75,16 +75,6 @@ class BiasLayer(nn.Module):
raise TypeError
return output
framework_map = {}
framework_map['hf'] = {
"linear": nn.Linear,
"layer_norm": nn.LayerNorm,
}
framework_map['bmt'] = {
"linear": model_center.layer.Linear,
"layer_norm", model_center.layer.LayerNorm,
}
class BitFitModel(DeltaBase):
@ -124,6 +114,7 @@ class BitFitModel(DeltaBase):
config_class = BitFitConfig
delta_type = "bitfit"
default_modified_modules = ["attn@", "ff@", "layer_norm@","lm_head@.proj@"] # modify all the bias parameter in attention and feed-forward layer.
_supported_backends = ['hf']
_need_pseudo_data = False
def __init__(self,
backbone_model: nn.Module,
@ -132,7 +123,7 @@ class BitFitModel(DeltaBase):
unfrozen_modules: Optional[List[str]] = None,
common_structure: Optional[bool] = None,
interactive_modify: Optional[Union[bool, int]] = False,
framework_type: Optional[str] = "hf",
backend: Optional[str] = "hf",
):
DeltaBase.__init__(self,
backbone_model,
@ -141,7 +132,7 @@ class BitFitModel(DeltaBase):
unfrozen_modules=unfrozen_modules,
common_structure=common_structure,
interactive_modify=interactive_modify,
framework_type=framework_type,
backend=backend,
)
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
@ -153,6 +144,8 @@ class BitFitModel(DeltaBase):
self.add_all_delta_to_backbone(self.backbone_model,
self.modified_modules)
def update_module(self, module: nn.Module, key: str):
@ -167,7 +160,8 @@ class BitFitModel(DeltaBase):
# if it is a leaf module, add bias to it regardless of its type.
# if self.check_linear(module):
# self.add_bias_to_linear(module)
if self.check_linear(module) or self.check_layernorm(module, nn.LayerNorm):
if self.backend_mapping.check_type(module, 'linear') or \
self.backend_mapping.check_type(module, 'layer_norm'):
self.add_bias_to_modules_have_bias_or_known_type(module)
else:
# for example, layer_norms, lm_heads.
@ -202,48 +196,27 @@ class BitFitModel(DeltaBase):
c.bias.requires_grad = True
self.delta_params.append(c.bias)
else:
if self.check_linear(c) or isinstance(c): # todo: bmt layerNorm
if self.backend_mapping.check_type(c, 'linear') or \
self.backend_mapping.check_type(c, 'layer_norm'):
bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
self._reset_bias_parameters(c) #?
try:
self._reset_bias_parameters(c)
if self.backend == 'bmt':
import bmtrain as bmt
bias = bmt.BMTrainModelWrapper(bias)
except:
pass
c.register_parameter('bias', bias)
self.delta_params.append(bias)
def add_bias_to_others(self, c): # todo: bmtrain?
new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c))
def add_bias_to_others(self, c):
new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c)) # TODO: bmtrain?
if self.backend == 'bmt':
import bmtrain as bmt
new_bias = bmt.BMTrainModelWrapper(new_bias)
self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since the name `bias` is reserved for some module such as roberta's LayerNorm.
self.delta_modules.append(new_bias)
def check_linear(self, m):
if isinstance(m, nn.Linear):
return True
else:
try:
from model_center.layer import Linear
if isinstance(m, Linear):
return True
except:
pass
return False
def check_layernorm(self, m):
if isinstance(m, nn.LayerNorm):
return True
else:
try:
from model_center.layer import LayerNorm
if isinstance(m, LayerNorm):
return True
except:
pass
return False
@staticmethod
def _reset_bias_parameters(linear_module):
fan_in, _ = init._calculate_fan_in_and_fan_out(linear_module.weight)

View File

@ -36,6 +36,7 @@ class HyperComplexAdapterLayer(nn.Module):
device=None,
use_bias_up_sampler=True,
use_bias_down_sampler=True,
backend = 'hf',
):
super().__init__()
self.reduction_factor = reduction_factor
@ -55,14 +56,17 @@ class HyperComplexAdapterLayer(nn.Module):
self.use_bias_up_sampler=use_bias_up_sampler
self.use_bias_down_sampler=use_bias_down_sampler
self.device = device
self.backend = backend
self.instantiated = False
def instantiate(self, hidden_dim):
self.down_sample_size = hidden_dim // self.reduction_factor
def instantiate(self, hiddens):
self.hidden_dim = hiddens.shape[-1]
self.hidden_dtype = hiddens.dtype
self.down_sample_size = self.hidden_dim // self.reduction_factor
self.activation = Activations(self.non_linearity.lower()).to(self.device)
self.down_sampler = PHMLinear(in_features=hidden_dim,
self.down_sampler = PHMLinear(in_features=self.hidden_dim,
out_features=self.down_sample_size,
bias=self.use_bias_down_sampler,
c_init=self.phm_c_init,
@ -76,9 +80,10 @@ class HyperComplexAdapterLayer(nn.Module):
factorized_phm_rule=self.factorized_phm_rule,
phm_rank=self.phm_rank,
phm_init_range=self.phm_init_range,
kronecker_prod=self.kronecker_prod).to(self.device)
kronecker_prod=self.kronecker_prod,
dtype = self.hidden_dtype).to(self.device)
self.up_sampler = PHMLinear(in_features=self.down_sample_size,
out_features=hidden_dim,
out_features=self.hidden_dim,
bias=self.use_bias_up_sampler,
c_init=self.phm_c_init,
phm_dim=self.hypercomplex_division,
@ -91,15 +96,14 @@ class HyperComplexAdapterLayer(nn.Module):
factorized_phm_rule=self.factorized_phm_rule,
phm_rank=self.phm_rank,
phm_init_range=self.phm_init_range,
kronecker_prod=self.kronecker_prod).to(self.device)
kronecker_prod=self.kronecker_prod,
dtype = self.hidden_dtype).to(self.device)
self.instantiated = True
try:
if self.backend == "bmt":
import bmtrain as bmt
self.activation = bmt.BMTrainModelWrapper(self.activation)
self.down_sampler = bmt.BMTrainModelWrapper(self.down_sampler)
self.up_sampler = bmt.BMTrainModelWrapper(self.up_sampler)
except:
pass
def post_forward(self, output):
@ -116,9 +120,7 @@ class HyperComplexAdapterLayer(nn.Module):
raise TypeError
if not self.instantiated:
self.hidden_dim = hiddens.shape[-1]
logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
self.instantiate(hidden_dim=self.hidden_dim)
self.instantiate(hiddens=hiddens)
z = self.down_sampler(hiddens)
@ -193,6 +195,7 @@ class CompacterModel(DeltaBase):
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
together with the prefix parameters.
common_structure (:obj:`bool`, *optional*, default to :obj:`None`): whether using name-based addressing with a common structure mapping.
backend (:obj:`str`): choose the backend of plm, 'hf' for huggingface transformers,'bmt' for bmtrain
reduction_factor (:obj:`int`, *optional*, default to ``16``): bottleneck_dim = hidden_dim//reduction_factor
non_linearity (:obj:`str`, *optional*, default to ``"gelu_new"``): The non linearity activation used in between the down
projecter and the up projecter.
@ -218,6 +221,7 @@ class CompacterModel(DeltaBase):
config_class = CompacterConfig
delta_type = "compacter"
default_modified_modules = ["attn@.proj@", "ff@.w2@"]
_supported_backends = ['hf', 'bmt']
_need_pseudo_data = True
def __init__(self,
backbone_model,
@ -226,6 +230,7 @@ class CompacterModel(DeltaBase):
unfrozen_modules: Optional[List[str]] = None,
common_structure: Optional[bool] = None,
interactive_modify: Optional[Union[bool, int]] = False,
backend: Optional[str] = 'hf',
reduction_factor=16,
non_linearity="gelu_new",
phm_c_init="normal",
@ -288,22 +293,6 @@ class CompacterModel(DeltaBase):
def new_module_like(self, module):
module_device = get_device(module)
adapterlayer = HyperComplexAdapterLayer(reduction_factor=self.reduction_factor,
non_linearity=self.non_linearity,
phm_c_init=self.phm_c_init,
hypercomplex_division=self.hypercomplex_division,
learn_phm=self.learn_phm,
hypercomplex_nonlinearity=self.hypercomplex_nonlinearity,
shared_phm_rule=self.shared_phm_rule,
factorized_phm=self.factorized_phm,
shared_W_phm=self.shared_W_phm,
factorized_phm_rule=self.factorized_phm_rule,
phm_rank=self.phm_rank,
phm_init_range=self.phm_init_range,
kronecker_prod=self.kronecker_prod,
use_bias_up_sampler=self.use_bias_up_sampler,
use_bias_down_sampler=self.use_bias_down_sampler,
device=module_device
)
adapterlayer = HyperComplexAdapterLayer(reduction_factor=self.reduction_factor, non_linearity=self.non_linearity, phm_c_init=self.phm_c_init, hypercomplex_division=self.hypercomplex_division, learn_phm=self.learn_phm, hypercomplex_nonlinearity=self.hypercomplex_nonlinearity, shared_phm_rule=self.shared_phm_rule, factorized_phm=self.factorized_phm, shared_W_phm=self.shared_W_phm, factorized_phm_rule=self.factorized_phm_rule, phm_rank=self.phm_rank, phm_init_range=self.phm_init_range, kronecker_prod=self.kronecker_prod, use_bias_up_sampler=self.use_bias_up_sampler, use_bias_down_sampler=self.use_bias_down_sampler, device=module_device, backend=self.backend)
self.delta_modules.append(adapterlayer)
return adapterlayer

View File

@ -84,7 +84,8 @@ class PHMLinear(torch.nn.Module):
factorized_phm_rule=False,
phm_rank = 1,
phm_init_range=0.0001,
kronecker_prod=False) -> None:
kronecker_prod=False,
dtype=torch.float) -> None:
super(PHMLinear, self).__init__()
assert w_init in ["phm", "glorot-normal", "glorot-uniform", "normal"]
assert c_init in ["normal", "uniform"]
@ -104,12 +105,12 @@ class PHMLinear(torch.nn.Module):
self.factorized_phm_rule = factorized_phm_rule
if not self.shared_phm_rule:
if self.factorized_phm_rule:
self.phm_rule_left = nn.Parameter(torch.FloatTensor(phm_dim, phm_dim, 1),
self.phm_rule_left = nn.Parameter(torch.empty((phm_dim, phm_dim, 1), dtype=dtype),
requires_grad=learn_phm)
self.phm_rule_right = nn.Parameter(torch.FloatTensor(phm_dim, 1, phm_dim),
self.phm_rule_right = nn.Parameter(torch.empty((phm_dim, 1, phm_dim), dtype=dtype),
requires_grad=learn_phm)
else:
self.phm_rule = nn.Parameter(torch.FloatTensor(phm_dim, phm_dim, phm_dim),
self.phm_rule = nn.Parameter(torch.empty((phm_dim, phm_dim, phm_dim), dtype=dtype),
requires_grad=learn_phm)
self.bias_flag = bias
self.w_init = w_init
@ -118,15 +119,15 @@ class PHMLinear(torch.nn.Module):
self.factorized_phm = factorized_phm
if not self.shared_W_phm:
if self.factorized_phm:
self.W_left = nn.Parameter(torch.Tensor(size=(phm_dim, self._in_feats_per_axis, self.phm_rank)),
self.W_left = nn.Parameter(torch.empty((phm_dim, self._in_feats_per_axis, self.phm_rank), dtype=dtype),
requires_grad=True)
self.W_right = nn.Parameter(torch.Tensor(size=(phm_dim, self.phm_rank, self._out_feats_per_axis)),
self.W_right = nn.Parameter(torch.empty((phm_dim, self.phm_rank, self._out_feats_per_axis), dtype=dtype),
requires_grad=True)
else:
self.W = nn.Parameter(torch.Tensor(size=(phm_dim, self._in_feats_per_axis, self._out_feats_per_axis)),
self.W = nn.Parameter(torch.empty((phm_dim, self._in_feats_per_axis, self._out_feats_per_axis), dtype=dtype),
requires_grad=True)
if self.bias_flag:
self.b = nn.Parameter(torch.Tensor(out_features))
self.b = nn.Parameter(torch.empty(out_features, dtype=dtype), requires_grad=True)
else:
self.register_parameter("b", None)
self.reset_parameters()

View File

@ -6,17 +6,17 @@ from opendelta.delta_models.layers.init import glorot_uniform, glorot_normal
class LowRankLinear(torch.nn.Module):
def __init__(self, input_dim: int, output_dim: int, rank: int = 1,
bias: bool = True, w_init: str = "glorot-uniform"):
bias: bool = True, w_init: str = "glorot-uniform", dtype=torch.float):
super(LowRankLinear, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.rank = rank
self.bias = bias
self.w_init = w_init
self.W_left = nn.Parameter(torch.Tensor(size=(input_dim, rank)), requires_grad=True)
self.W_right = nn.Parameter(torch.Tensor(size=(rank, output_dim)), requires_grad=True)
self.W_left = nn.Parameter(torch.empty((input_dim, rank), dtype=dtype),requires_grad=True)
self.W_right = nn.Parameter(torch.empty((rank, output_dim), dtype=dtype), requires_grad=True)
if bias:
self.b = nn.Parameter(torch.Tensor(output_dim))
self.b = nn.Parameter(torch.empty(output_dim, dtype=dtype))
self.reset_parameters()
def reset_parameters(self):

View File

@ -97,12 +97,14 @@ class LoraModel(DeltaBase):
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
together with the prefix parameters.
common_structure (:obj:`bool`): whether using name-based addressing with a common structure mapping.
backend (:obj:`str`): choose the backend of plm, 'hf' for huggingface transformers,'bmt' for bmtrain
"""
config_class = LoraConfig
delta_type = "lora"
default_modified_modules = ['attn@.q@', 'attn@.v@']
_supported_backends = ['hf', 'bmt']
_need_pseudo_data = False
def __init__(self,
backbone_model: nn.Module,
@ -114,6 +116,7 @@ class LoraModel(DeltaBase):
exclude_modules: Optional[List[str]] = None,
common_structure: Optional[bool] = None,
interactive_modify: Optional[Union[bool, int]] = False,
backend: Optional[str] = "hf",
):
DeltaBase.__init__(self,
backbone_model,
@ -121,6 +124,7 @@ class LoraModel(DeltaBase):
unfrozen_modules=unfrozen_modules,
common_structure=common_structure,
interactive_modify=interactive_modify,
backend=backend,
)
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
@ -151,10 +155,9 @@ class LoraModel(DeltaBase):
r=self.lora_r,
lora_alpha=self.lora_alpha,
lora_dropout=self.lora_dropout)
try:
if self.backend == "bmt":
import bmtrain as bmt
new_module = bmt.BMTrainModelWrapper(new_module)
except:
pass
self.delta_modules.append(new_module)
return new_module

View File

@ -47,7 +47,8 @@ class LowRankAdapter(nn.Module):
non_linearity="gelu_new",
low_rank_w_init="glorot-uniform",
low_rank_rank=1,
device=None):
device=None,
backend='hf'):
super().__init__()
self.reduction_factor = reduction_factor
self.non_linearity = non_linearity
@ -55,27 +56,31 @@ class LowRankAdapter(nn.Module):
self.low_rank_rank = low_rank_rank
self.device = device
self.instantiated = False
self.backend=backend
def instantiate(self, hidden_dim):
def instantiate(self, hiddens):
self.hidden_dim = hiddens.shape[-1]
self.hidden_dtype = hiddens.dtype
self.down_sample_size = hidden_dim // self.reduction_factor
self.down_sample_size = self.hidden_dim // self.reduction_factor
self.activation = Activations(self.non_linearity.lower()).to(self.device)
self.down_sampler = LowRankLinear(hidden_dim, self.down_sample_size,
self.down_sampler = LowRankLinear(self.hidden_dim, self.down_sample_size,
w_init=self.low_rank_w_init,
rank=self.low_rank_rank).to(self.device)
self.up_sampler = LowRankLinear(self.down_sample_size, hidden_dim,
rank=self.low_rank_rank,
dtype=self.hidden_dtype).to(self.device)
self.up_sampler = LowRankLinear(self.down_sample_size, self.hidden_dim,
w_init=self.low_rank_w_init,
rank=self.low_rank_rank).to(self.device)
rank=self.low_rank_rank,
dtype=self.hidden_dtype).to(self.device)
self.instantiated = True
try:
if self.backend == 'bmt':
import bmtrain as bmt
self.activation = bmt.BMTrainModelWrapper(self.activation)
self.down_sampler = bmt.BMTrainModelWrapper(self.down_sampler)
self.up_sampler = bmt.BMTrainModelWrapper(self.up_sampler)
except:
pass
def post_forward(self, output):
r""" Get the hidden_states from the PLM's layer output, pass it into the low-rank adapter,
@ -91,9 +96,7 @@ class LowRankAdapter(nn.Module):
raise TypeError
if not self.instantiated:
self.hidden_dim = hiddens.shape[-1]
logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
self.instantiate(hidden_dim=self.hidden_dim)
self.instantiate(hiddens = hiddens)
z = self.down_sampler(hiddens)
z = self.activation(z)
@ -154,6 +157,7 @@ class LowRankAdapterModel(DeltaBase):
config_class = LowRankAdapterConfig
delta_type = "low_rank_adapter"
default_modified_modules = ["attn@.proj@", "ff@.w2@"]
_supported_backends = ['hf', 'bmt']
_need_pseudo_data = True
def __init__(self,
backbone_model: nn.Module,
@ -166,6 +170,7 @@ class LowRankAdapterModel(DeltaBase):
unfrozen_modules: Optional[List[str]] = None,
common_structure: Optional[bool] = None,
interactive_modify: Optional[Union[bool, int]] = False,
backend: Optional[str] = 'hf',
):
DeltaBase.__init__(self,
backbone_model,
@ -174,6 +179,7 @@ class LowRankAdapterModel(DeltaBase):
unfrozen_modules=unfrozen_modules,
common_structure=common_structure,
interactive_modify=interactive_modify,
backend=backend,
)
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
@ -209,6 +215,6 @@ class LowRankAdapterModel(DeltaBase):
non_linearity = self.non_linearity,
low_rank_w_init = self.low_rank_w_init,
low_rank_rank = self.low_rank_rank,
device=module_device)
device=module_device, backend=self.backend)
self.delta_modules.append(adapterlayer)
return adapterlayer

View File

@ -25,30 +25,36 @@ class ParallelAdapterLayer(nn.Module):
def get_layer_count(cls):
return cls.layer_count
def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', scaled=1, device=None):
def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', scaled=1, device=None, backend='hf'):
super().__init__()
self.bottleneck_dim = bottleneck_dim
self.device = device
self.instantiated = False
self.non_linearity = non_linearity
self.scaled = scaled
self.backend = backend
self.layer_id = ParallelAdapterLayer.get_layer_count()
ParallelAdapterLayer.count_layer()
def instantiate(self, hidden_dim):
def instantiate(self, hiddens):
self.hidden_dim = hiddens.shape[-1]
self.hidden_dtype = hiddens.dtype
self.modulelist = nn.Sequential()
self.modulelist.add_module("down_proj",nn.Linear(hidden_dim, self.bottleneck_dim, device=self.device))
self.modulelist.add_module("down_proj",nn.Linear(self.hidden_dim, self.bottleneck_dim, device=self.device, dtype=self.hidden_dtype))
# select non-linearity
self.modulelist.add_module("non_linear", Activations(self.non_linearity.lower()))
self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device))
self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device, dtype=self.hidden_dtype))
self.instantiated = True
# initialize the weight, which is important for fast convergence and better performance.
self.apply(self._init_weight)
if self.backend == 'bmt':
import bmtrain as bmt
self.modulelist = bmt.BMTrainModelWrapper(self.modulelist)
def _init_weight(self, module):
if isinstance(module, nn.Linear):
@ -71,9 +77,8 @@ class ParallelAdapterLayer(nn.Module):
if not self.instantiated:
self.hidden_dim = hiddens.shape[-1]
logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
self.instantiate(hidden_dim=self.hidden_dim)
# logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
self.instantiate(hiddens = hiddens)
self.adapter_output = self.modulelist(hiddens) * self.scaled
@ -141,12 +146,14 @@ class ParallelAdapterModel(DeltaBase):
modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired and have the save order in layer. For examples, ["attn", "attn", "ff.w1", "ff.w2"] add one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output.
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the parallel adapter parameters.
common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping.
backend (:obj:`str`): choose the backend of plm, 'hf' for huggingface transformers,'bmt' for bmtrain
"""
config_class = ParallelAdapterConfig
delta_type = "parallel_adapter"
default_modified_modules = ["attn@", "attn@", "ff@.w1@", "ff@.w2@"]
# default_modified_modules = ["attn", "attn", "ff.w1", "ff.w2"]
_supported_backends = ['hf', 'bmt']
_need_pseudo_data = True
def __init__(self,
backbone_model: nn.Module,
@ -156,7 +163,8 @@ class ParallelAdapterModel(DeltaBase):
exclude_modules: Optional[List[str]] = None,
unfrozen_modules: Optional[bool] = None,
common_structure: Optional[bool] = None,
interactive_modify: Optional[Union[bool, int]] = False,
interactive_modify: Optional[Union[bool, int]] = False,
backend: Optional[str] = "hf",
):
DeltaBase.__init__(self,
backbone_model,
@ -165,6 +173,7 @@ class ParallelAdapterModel(DeltaBase):
unfrozen_modules=unfrozen_modules,
common_structure=common_structure,
interactive_modify=interactive_modify,
backend=backend,
)
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
@ -193,7 +202,7 @@ class ParallelAdapterModel(DeltaBase):
def new_module_like(self, module):
module_device = get_device(module)
adapterlayer = ParallelAdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device)
adapterlayer = ParallelAdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device, backend=self.backend)
self.delta_modules.append(adapterlayer)
return adapterlayer

View File

@ -516,6 +516,7 @@ class PrefixModel(DeltaBase):
config_class = PrefixConfig
delta_type = "prefix"
default_modified_modules = ['attn@']
_supported_backends = ['hf']
_need_pseudo_data = True
def __init__(self,
backbone_model: nn.Module,

View File

@ -161,6 +161,7 @@ class SoftPromptModel(DeltaBase):
config_class = SoftPromptConfig
delta_type = "soft_prompt"
default_modified_modules = ["root"] # not used
_supported_backends = ['hf'] #'bmt']
_need_pseudo_data = False
def __init__(self,
backbone_model: nn.Module,

110
opendelta/utils/backend.py Normal file
View File

@ -0,0 +1,110 @@
import importlib
class BackendMapping:
"""
" A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
Args:
- config_mapping: The map model type to config class
- model_mapping: The map model type to model (or tokenizer) class
"""
def __init__(self, backend):
self.backend = backend
assert backend in ['hf', 'bmt'], "Backend should be one of 'hf', 'bmt'. "
if backend == 'hf':
self.backend_mapping = {
"linear": "torch.nn.Linear",
"layer_norm": "torch.nn.LayerNorm",
"module": "torch.nn.Module",
"parameter": "torch.nn.Parameter"
}
elif backend == 'bmt':
self.backend_mapping = {
"linear": "model_center.layer.Linear",
"layer_norm": "model_center.layer.LayerNorm",
"module": "bmtrain.layer.DistributedModule",
"parameter": "bmtrain.nn.DistributedParameter"
}
self.registered = {}
def load(self, model_type):
if model_type not in self.registered:
splited = self.backend_mapping[model_type].split(".")
module_name, class_name = ".".join(splited[:-1]), splited[-1]
module = importlib.import_module(module_name)
the_class = getattr(module, class_name)
self.registered[model_type] = the_class
return self.registered[model_type]
def check_type(self, module, expect_type):
the_class = self.load(expect_type)
if isinstance(module, the_class):
return True
else:
return False
# def keys(self):
# mapping_keys = [
# self._load_attr_from_module(key, name)
# for key, name in self._config_mapping.items()
# if key in self._model_mapping.keys()
# ]
# return mapping_keys + list(self._extra_content.keys())
# def get(self, key, default):
# try:
# return self.__getitem__(key)
# except KeyError:
# return default
# def __bool__(self):
# return bool(self.keys())
# def values(self):
# mapping_values = [
# self._load_attr_from_module(key, name)
# for key, name in self._model_mapping.items()
# if key in self._config_mapping.keys()
# ]
# return mapping_values + list(self._extra_content.values())
# def items(self):
# mapping_items = [
# (
# self._load_attr_from_module(key, self._config_mapping[key]),
# self._load_attr_from_module(key, self._model_mapping[key]),
# )
# for key in self._model_mapping.keys()
# if key in self._config_mapping.keys()
# ]
# return mapping_items + list(self._extra_content.items())
# def __iter__(self):
# return iter(self.keys())
# def __contains__(self, item):
# if item in self._extra_content:
# return True
# if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
# return False
# model_type = self._reverse_config_mapping[item.__name__]
# return model_type in self._model_mapping
# def register(self, key, value):
# """
# Register a new model in this mapping.
# """
# if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
# model_type = self._reverse_config_mapping[key.__name__]
# if model_type in self._model_mapping.keys():
# raise ValueError(f"'{key}' is already used by a Transformers model.")
# self._extra_content[key] = value

112
opendelta/utils/inspect.py Normal file
View File

@ -0,0 +1,112 @@
import torch
import torch.nn as nn
from typing import Optional
import opendelta.utils.logging as logging
logger = logging.get_logger(__name__)
def inspect_module_statistics(module: Optional[nn.Module]=None, verbose=True):
r"""Get the statistics of the parameters in the delta modules.
Args:
module (:obj:`nn.Module`, *optional*): The module to compute the statistics.
Returns:
:obj:`dict`: The statistics of the parameters in the delta modules.
"""
stat = {}
n_trainable = num_trainable_parameters(module)
n_total = num_total_parameters(module)
stat['total_parameters'] = n_total
stat['trainable_parameters'] = n_trainable
stat['trainable_ratio'] = n_trainable/n_total
n_delta = num_delta_parameters(module)
n_total = num_total_parameters(module)
stat['delta_parameters'] = n_delta
stat['delta_ratio'] = n_delta/n_total
cudamem = 0
maxcudamem = 0
for device_id in range(torch.cuda.device_count()):
cudamem += torch.cuda.memory_allocated(f"cuda:{device_id}")/1024**3
maxcudamem += torch.cuda.max_memory_allocated(f"cuda:{device_id}")/1024**3
stat['cudamem'] = cudamem
stat['maxcudamem'] = maxcudamem
if verbose:
logger.info(stat)
return stat
def num_trainable_parameters(module: Optional[nn.Module]=None):
r"""[NODOC] A small sugar function to get the number of trainable parameter in the backbone model. Often used to
compute the trainable rate.
Args:
module (:obj:`nn.Module`): of which module we want to know the number of trainable paramemters.
Returns:
:obj:`List[nn.Parameter]`
"""
pnum_tot = 0
for param in module.parameters():
if param.requires_grad:
pnum_tot += param.numel()
return pnum_tot
def num_total_parameters(module: Optional[nn.Module]=None):
r"""[NODOC] A small sugar function to get the number of trainable parameter in the backbone model. Often used to
compute the trainable rate.
Args:
module (:obj:`nn.Module`): of which module we want to know the number of trainable paramemters.
Returns:
:obj:`List[nn.Parameter]`
"""
pnum_tot = 0
for param in module.parameters():
pnum_tot += param.numel()
return pnum_tot
def num_delta_parameters(module: Optional[nn.Module]=None):
r"""[NODOC] A small sugar function to get the number of trainable parameter in the backbone model. Often used to
compute the trainable rate.
Args:
module (:obj:`nn.Module`): of which module we want to know the number of trainable paramemters.
Returns:
:obj:`List[nn.Parameter]`
"""
pnum_tot = 0
for param in module.parameters():
if hasattr(param, "_is_delta"):
pnum_tot += param.numel()
return pnum_tot
def inspect_optimizer_statistics(optimizer, verbose=True):
stats = {}
for id, param_group in enumerate(optimizer.param_groups):
stat = {}
fine_grain_info = [(p.numel(), p.requires_grad) for p in param_group['params']]
stat['total_parameters'] = sum(n for n, r in fine_grain_info)
stat['trainable_parameters'] = sum(n for n, r in fine_grain_info if r)
stat['trainable_ratio'] = "{:.6f}%".format(stat['trainable_parameters']/stat['total_parameters']*100)
for key in param_group:
if key != 'params':
stat[key] = param_group[key]
stats[f'param_group_{id}'] = stat
if verbose:
logger.info(f"optimizer info: {stats}")
return stat

View File

@ -31,7 +31,7 @@ def get_requirements():
with open('README.md', 'r') as f:
setuptools.setup(
name = 'opendelta',
version = "0.3.1",
version = "0.3.2",
description = "An open source framework for delta learning (parameter efficient learning).",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",