support bmtrain

parent: e0de6b02ad
commit: ce7f992864
@@ -54,8 +54,13 @@ t.sh
 **/delta_checkpoints/
 **/outputs/
 
+dist/*
 
 **/unittest/**
 !unittest/**.py
 !unittest/**.sh
 
+**/tutorial/**
+!tutorial/**.py
+!tutorial/**.sh
+!tutorial/**.md

@@ -31,8 +31,8 @@ copyright = '{}, {}, Licenced under the Apache License, Version 2.0'.format(date
 
 
 # The full version, including alpha/beta/rc tags
-release = '0.3.1'
-version = "0.3.1"
+release = '0.3.2'
+version = "0.3.2"
 
 html_theme = 'sphinx_rtd_theme'
 html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]

@@ -1,6 +1,14 @@
 (acceleration)=
 # OpenDelta+
-<img src="../imgs/todo-icon.jpeg" height="30px"> We are working on testing and improving the functionality with work with other acceleration packages for model training and inference. For example, [deepspeed](https://github.com/microsoft/DeepSpeed), [BMInf](https://github.com/OpenBMB/BMInf).
 
-Feel free to contact us via email (shengdinghu@gmail.com) if you have any suggestion.
+## BMTrain
+
+- [BMTrain](https://github.com/OpenBMB/BMTrain) is an efficient large model training toolkit that can be used to train models with tens of billions of parameters. It trains models in a distributed manner while keeping the code as simple as stand-alone training.
+- [ModelCenter](https://github.com/OpenBMB/ModelCenter) implements pre-trained language models (PLMs) on top of the OpenBMB/BMTrain backend and supports efficient, low-resource, extendable model usage and distributed training.
+
+LoraModel, AdapterModel, CompacterModel, ParallelAdapterModel, and LowRankAdapterModel now fully support distributed training with BMTrain and ModelCenter. Please try it out in [examples/tutorial/2_with_bmtrain.py](https://github.com/thunlp/OpenDelta/tree/main/examples/tutorial/2_with_bmtrain.py).
+
+## Huggingface Accelerate
+<img src="../imgs/todo-icon.jpeg" height="30px">
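
As a quick illustration of the new backend switch, here is a minimal sketch of attaching a delta model to a ModelCenter backbone under BMTrain, condensed from the tutorial code added in this commit (it assumes BMTrain and ModelCenter are installed and the script is launched through `torch.distributed.launch`):

```
import bmtrain as bmt
from opendelta import LoraModel
from model_center.model import Bert, BertConfig

bmt.init_distributed()  # BMTrain must be initialized before the model is built

# a ModelCenter backbone, parallelized by BMTrain
config = BertConfig.from_pretrained("bert-base-uncased")
model = Bert.from_pretrained("bert-base-uncased", config)

# backend='bmt' tells OpenDelta to wrap the injected delta modules for BMTrain
delta_model = LoraModel(backbone_model=model,
                        modified_modules=['project_q', 'project_k'],
                        backend='bmt')
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
delta_model.log()  # visualize the modified backbone and the trainable ratio
```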

@@ -1,5 +1,10 @@
 # Update Logs and Known Issues
 
+## Version 0.3.2
+- We support BMTrain to accelerate training and to parallelize the training of models that are hard to fit into a single GPU. Check [tutorial/2_with_bmtrain.py](https://github.com/thunlp/OpenDelta/tree/main/examples/tutorial/2_with_bmtrain.py).
+- We add functionality to [inspect the optimizer](https://github.com/thunlp/OpenDelta/tree/main/opendelta/utils/inspect.py). The user can see the number of trainable parameters in the optimizer and verify that OpenDelta is being used correctly.
+- We move the functions that inspect the delta models into [inspect.py](https://github.com/thunlp/OpenDelta/tree/main/opendelta/utils/inspect.py).
+
 ## Version 0.3.1
 - We update [must_try.py](https://github.com/thunlp/OpenDelta/tree/main/examples/unittest/must_try.py) for a simple introduction of the core functionality of OpenDelta.
 - Thanks to [Weilin Zhao](https://github.com/Achazwl), we merge a long-developed branch parallel_adapter into the main branch.
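
The optimizer inspection mentioned above is a one-liner at the call site; a small sketch of the intended usage, mirroring the snippet this commit adds to the basic-usage docs (`backbone_model` is a placeholder for a model that already has a delta model attached and frozen):

```
from transformers import AdamW
from opendelta.utils.inspect import inspect_optimizer_statistics

# build the optimizer as usual; after freeze_module() only the delta
# (and other unfrozen) parameters still require gradients
optimizer = AdamW(backbone_model.parameters(), lr=3e-3)

# report how many of the optimizer's parameters are actually trainable,
# a quick sanity check that OpenDelta froze the backbone as intended
inspect_optimizer_statistics(optimizer)
```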
@ -1,50 +1,291 @@
|
||||||
import bmtrain as bmt
|
# adapted from https://github.com/OpenBMB/ModelCenter/blob/main/examples/bert/finetune_bert.py
|
||||||
import opendelta as od
|
|
||||||
from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel
|
import time
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy
|
import numpy as np
|
||||||
import random
|
from sklearn.metrics import accuracy_score, recall_score, f1_score
|
||||||
|
|
||||||
def manual_seed(seed):
|
import bmtrain as bmt
|
||||||
torch.manual_seed(seed)
|
|
||||||
numpy.random.seed(seed)
|
|
||||||
random.seed(seed)
|
|
||||||
|
|
||||||
from model_center.model import Bert, BertConfig
|
from model_center import get_args
|
||||||
bmt.init_distributed()
|
from model_center.model import Bert
|
||||||
config = BertConfig.from_pretrained("bert-base-uncased")
|
from model_center.tokenizer import BertTokenizer
|
||||||
config.dropout_p = 0
|
from model_center.dataset.bertdataset import DATASET
|
||||||
model = Bert.from_pretrained("bert-base-uncased", config)
|
from model_center.utils import print_inspect
|
||||||
|
from model_center.layer import Linear
|
||||||
|
from model_center.dataset import DistributedDataLoader
|
||||||
|
|
||||||
|
import opendelta as od
|
||||||
|
from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel, ParallelAdapterModel
|
||||||
|
from opendelta.utils.inspect import inspect_optimizer_statistics
|
||||||
print("before modify")
|
print("before modify")
|
||||||
od.Visualization(model).structure_graph()
|
|
||||||
|
|
||||||
manual_seed(233)
|
class BertModel(torch.nn.Module):
|
||||||
delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'])
|
def __init__(self, args, num_types):
|
||||||
# delta_model = AdapterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn'])
|
super().__init__()
|
||||||
# delta_model = CompacterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn'])
|
self.bert : Bert = Bert.from_pretrained(args.model_config)
|
||||||
# delta_model = LowRankAdapterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn'])
|
dim_model = self.bert.input_embedding.dim_model
|
||||||
# delta_model = BitFitModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn', '[r](.*)layernorm(.*)'])
|
self.dense = Linear(dim_model, num_types)
|
||||||
|
bmt.init_parameters(self.dense)
|
||||||
|
|
||||||
# print(delta_model.delta_modules)
|
def forward(self, *args, **kwargs):
|
||||||
|
pooler_output = self.bert(*args, **kwargs, output_pooler_output=True).pooler_output
|
||||||
|
logits = self.dense(pooler_output)
|
||||||
|
return logits
|
||||||
|
|
||||||
print("after modify")
|
def get_tokenizer(args):
|
||||||
delta_model.log()
|
tokenizer = BertTokenizer.from_pretrained(args.model_config)
|
||||||
# This will visualize the backbone after modification and other information.
|
return tokenizer
|
||||||
|
|
||||||
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
|
def get_model(args):
|
||||||
print("after freeze")
|
num_types = {
|
||||||
delta_model.log()
|
"BoolQ" : 2,
|
||||||
# The set_state_dict=True will tell the method to change the state_dict of the backbone_model to maintaining only the trainable parts.
|
"CB" : 3,
|
||||||
|
"COPA" : 1,
|
||||||
|
"RTE" : 2,
|
||||||
|
"WiC" : 2,
|
||||||
|
}
|
||||||
|
model = BertModel(args, num_types[args.dataset_name])
|
||||||
|
od.Visualization(model).structure_graph()
|
||||||
|
|
||||||
manual_seed(233)
|
|
||||||
inp = torch.randint(0, 30000, (32, 128)).cuda()
|
|
||||||
length = torch.randint(0, 128, (32,)).cuda()
|
|
||||||
attention_mask = (torch.arange(inp.shape[1], device=inp.device)[None, :].repeat(inp.shape[0], 1) < length[:, None])
|
|
||||||
out = model(inp, attention_mask=attention_mask, output_logits=True).logits
|
|
||||||
print(out)
|
|
||||||
|
|
||||||
if bmt.rank() == 0:
|
if args.delta_type == "lora":
|
||||||
torch.save(model.state_dict(), "test.pt")
|
delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'], backend='bmt')
|
||||||
ckpt = torch.load("test.pt")
|
elif args.delta_type == "bitfit":
|
||||||
print(ckpt.keys())
|
delta_model = BitFitModel(backbone_model=model, modified_modules=['self_att', 'ffn', 'layernorm'], backend='bmt') #TODO: fix bug
|
||||||
|
elif args.delta_type == "adapter":
|
||||||
|
delta_model = AdapterModel(backbone_model=model, modified_modules=['self_att', 'ffn'], backend='bmt')
|
||||||
|
elif args.delta_type == "compacter":
|
||||||
|
delta_model = CompacterModel(backbone_model=model, modified_modules=['self_att', 'ffn'], backend='bmt')
|
||||||
|
elif args.delta_type == "low_rank_adapter":
|
||||||
|
delta_model = LowRankAdapterModel(backbone_model=model, modified_modules=['self_att', 'ffn'], backend='bmt')
|
||||||
|
elif args.delta_type == "parallel_adapter":
|
||||||
|
delta_model = ParallelAdapterModel(backbone_model=model, modified_modules=['self_att', 'self_att', 'ffn.ffn', 'ffn.ffn'], backend='bmt')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
print("after modify")
|
||||||
|
delta_model.log()
|
||||||
|
# This will visualize the backbone after modification and other information.
|
||||||
|
|
||||||
|
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
|
||||||
|
print("after freeze")
|
||||||
|
delta_model.log()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def get_optimizer(args, model):
|
||||||
|
optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=args.weight_decay)
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
def get_learning_rate_scheduler(args, optimizer):
|
||||||
|
if args.lr_decay_iters is None:
|
||||||
|
args.lr_decay_iters = args.train_iters * args.epochs
|
||||||
|
if args.lr_decay_style == "noam":
|
||||||
|
lr_scheduler = bmt.lr_scheduler.Noam(optimizer,
|
||||||
|
start_lr = args.lr,
|
||||||
|
warmup_iter = args.warmup_iters,
|
||||||
|
end_iter = args.lr_decay_iters,
|
||||||
|
num_iter = args.start_step)
|
||||||
|
elif args.lr_decay_style == "constant":
|
||||||
|
lr_scheduler = bmt.lr_scheduler.NoDecay(optimizer,
|
||||||
|
start_lr = args.lr,
|
||||||
|
warmup_iter = args.warmup_iters,
|
||||||
|
end_iter = -1,
|
||||||
|
num_iter = args.start_step)
|
||||||
|
elif args.lr_decay_style == "linear":
|
||||||
|
lr_scheduler = bmt.lr_scheduler.Linear(optimizer,
|
||||||
|
start_lr = args.lr,
|
||||||
|
warmup_iter = args.warmup_iters,
|
||||||
|
end_iter = args.lr_decay_iters,
|
||||||
|
num_iter = args.start_step)
|
||||||
|
elif args.lr_decay_style == "exponential":
|
||||||
|
lr_scheduler = bmt.lr_scheduler.Exponential(optimizer,
|
||||||
|
start_lr = args.lr,
|
||||||
|
warmup_iter = args.warmup_iters,
|
||||||
|
end_iter = args.lr_decay_iters,
|
||||||
|
num_iter = args.start_step)
|
||||||
|
elif args.lr_decay_style == "cosine":
|
||||||
|
lr_scheduler = bmt.lr_scheduler.Cosine(optimizer,
|
||||||
|
start_lr = args.lr,
|
||||||
|
warmup_iter = args.warmup_iters,
|
||||||
|
end_iter = args.lr_decay_iters,
|
||||||
|
num_iter = args.start_step)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"lr_scheduler of type {args.lr_decay_style} is not supported yet.")
|
||||||
|
|
||||||
|
return lr_scheduler
|
||||||
|
|
||||||
|
def setup_model_and_optimizer(args):
|
||||||
|
# get the tokenizer
|
||||||
|
tokenizer = get_tokenizer(args)
|
||||||
|
# get the model
|
||||||
|
model = get_model(args)
|
||||||
|
bmt.synchronize()
|
||||||
|
# get the optimizer and lr_scheduler
|
||||||
|
optimizer = get_optimizer(args, model)
|
||||||
|
|
||||||
|
inspect_optimizer_statistics(optimizer)
|
||||||
|
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
||||||
|
bmt.synchronize()
|
||||||
|
# get the memory usage
|
||||||
|
bmt.print_rank("Model mem\n", torch.cuda.memory_summary())
|
||||||
|
bmt.synchronize()
|
||||||
|
return tokenizer, model, optimizer, lr_scheduler
|
||||||
|
|
||||||
|
def initialize():
|
||||||
|
# get arguments
|
||||||
|
args = get_args()
|
||||||
|
# init bmt
|
||||||
|
bmt.init_distributed(seed = args.seed)
|
||||||
|
# init save folder
|
||||||
|
if args.save != None:
|
||||||
|
os.makedirs(args.save, exist_ok=True)
|
||||||
|
return args
|
||||||
|
|
||||||
|
def prepare_dataset(args, tokenizer, base_path, dataset_name, rank, world_size):
|
||||||
|
splits = ['train', 'dev', 'test']
|
||||||
|
dataset = {}
|
||||||
|
for split in splits:
|
||||||
|
dataset[split] = DATASET[dataset_name](base_path, split, rank, world_size, tokenizer, args.max_encoder_length)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset):
|
||||||
|
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
|
||||||
|
|
||||||
|
optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale)
|
||||||
|
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
||||||
|
|
||||||
|
# print_inspect(model, '*') # too much output
|
||||||
|
|
||||||
|
for epoch in range(12):
|
||||||
|
dataloader = {
|
||||||
|
"train": DistributedDataLoader(dataset['train'], batch_size=args.batch_size, shuffle=True),
|
||||||
|
"dev": DistributedDataLoader(dataset['dev'], batch_size=args.batch_size, shuffle=False),
|
||||||
|
}
|
||||||
|
|
||||||
|
model.train()
|
||||||
|
for it, data in enumerate(dataloader['train']):
|
||||||
|
if args.dataset_name == 'COPA':
|
||||||
|
input_ids0 = data["input_ids0"]
|
||||||
|
attention_mask0 = data["attention_mask0"]
|
||||||
|
token_type_ids0 = data["token_type_ids0"]
|
||||||
|
input_ids1 = data["input_ids1"]
|
||||||
|
attention_mask1 = data["attention_mask1"]
|
||||||
|
token_type_ids1 = data["token_type_ids1"]
|
||||||
|
labels = data["labels"]
|
||||||
|
else:
|
||||||
|
input_ids = data["input_ids"]
|
||||||
|
attention_mask = data["attention_mask"]
|
||||||
|
token_type_ids = data["token_type_ids"]
|
||||||
|
labels = data["labels"]
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
st_time = time.time()
|
||||||
|
|
||||||
|
if args.dataset_name == 'COPA':
|
||||||
|
logits = torch.cat([
|
||||||
|
model(input_ids0, attention_mask=attention_mask0, token_type_ids=token_type_ids0),
|
||||||
|
model(input_ids1, attention_mask=attention_mask1, token_type_ids=token_type_ids1),
|
||||||
|
], dim=1)
|
||||||
|
else:
|
||||||
|
logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
|
||||||
|
loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
||||||
|
|
||||||
|
global_loss = bmt.sum_loss(loss).item()
|
||||||
|
|
||||||
|
optim_manager.zero_grad()
|
||||||
|
|
||||||
|
optim_manager.backward(loss)
|
||||||
|
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type = 2)
|
||||||
|
|
||||||
|
optim_manager.step()
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
elapsed_time = time.time() - st_time
|
||||||
|
|
||||||
|
# from IPython import embed; embed(header="25252")
|
||||||
|
|
||||||
|
bmt.print_rank(
|
||||||
|
"train | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.4e}, scale: {:10.4f} | grad_norm: {:.4f} | time: {:.3f}".format(
|
||||||
|
epoch,
|
||||||
|
it,
|
||||||
|
len(dataloader["train"]),
|
||||||
|
global_loss,
|
||||||
|
lr_scheduler.current_lr,
|
||||||
|
int(optim_manager.loss_scale),
|
||||||
|
grad_norm,
|
||||||
|
elapsed_time,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
for split in ['dev']:
|
||||||
|
pd = []
|
||||||
|
gt = []
|
||||||
|
for it, data in enumerate(dataloader[split]):
|
||||||
|
if args.dataset_name == 'COPA':
|
||||||
|
input_ids0 = data["input_ids0"]
|
||||||
|
attention_mask0 = data["attention_mask0"]
|
||||||
|
token_type_ids0 = data["token_type_ids0"]
|
||||||
|
input_ids1 = data["input_ids1"]
|
||||||
|
attention_mask1 = data["attention_mask1"]
|
||||||
|
token_type_ids1 = data["token_type_ids1"]
|
||||||
|
labels = data["labels"]
|
||||||
|
logits = torch.cat([
|
||||||
|
model(input_ids0, attention_mask=attention_mask0, token_type_ids=token_type_ids0),
|
||||||
|
model(input_ids1, attention_mask=attention_mask1, token_type_ids=token_type_ids1),
|
||||||
|
], dim=1)
|
||||||
|
else:
|
||||||
|
input_ids = data["input_ids"]
|
||||||
|
attention_mask = data["attention_mask"]
|
||||||
|
token_type_ids = data["token_type_ids"]
|
||||||
|
labels = data["labels"]
|
||||||
|
logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
|
||||||
|
|
||||||
|
loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
||||||
|
logits = logits.argmax(dim=-1)
|
||||||
|
pd.extend(logits.cpu().tolist())
|
||||||
|
gt.extend(labels.cpu().tolist())
|
||||||
|
|
||||||
|
bmt.print_rank(
|
||||||
|
"{} | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f}".format(
|
||||||
|
split,
|
||||||
|
epoch,
|
||||||
|
it,
|
||||||
|
len(dataloader[split]),
|
||||||
|
loss,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
pd = bmt.gather_result(torch.tensor(pd).int()).cpu().tolist()
|
||||||
|
gt = bmt.gather_result(torch.tensor(gt).int()).cpu().tolist()
|
||||||
|
|
||||||
|
bmt.print_rank(f"{split} epoch {epoch}:")
|
||||||
|
if args.dataset_name in ["BoolQ", "CB", "COPA", "RTE", "WiC", "WSC"]:
|
||||||
|
acc = accuracy_score(gt, pd)
|
||||||
|
bmt.print_rank(f"accuracy: {acc*100:.2f}")
|
||||||
|
if args.dataset_name in ["CB"]:
|
||||||
|
rcl = recall_score(gt, pd, average="macro")
|
||||||
|
f1 = f1_score(gt, pd, average="macro")
|
||||||
|
bmt.print_rank(f"recall: {rcl*100:.2f}")
|
||||||
|
bmt.print_rank(f"Average F1: {f1*100:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = initialize()
|
||||||
|
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
|
||||||
|
dataset = prepare_dataset(
|
||||||
|
args,
|
||||||
|
tokenizer,
|
||||||
|
f"{args.base_path}/down_data/superglue/",
|
||||||
|
args.dataset_name,
|
||||||
|
bmt.rank(), bmt.world_size(),
|
||||||
|
)
|
||||||
|
finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@@ -1 +1,37 @@
-python3 -m torch.distributed.launch --master_addr localhost --master_port 34123 --nproc_per_node $1 --nnodes 1 --node_rank 0 2_with_bmtrain.py
+#! /bin/bash
+
+MASTER_ADDR=localhost
+MASTER_PORT=12345
+NNODES=1
+NODE_RANK=0
+GPUS_PER_NODE=4
+
+DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
+                  --nnodes $NNODES \
+                  --node_rank $NODE_RANK \
+                  --master_addr $MASTER_ADDR \
+                  --master_port $MASTER_PORT"
+
+BASE_PATH="./"
+VERSION="bert-large-cased"
+DATASET="BoolQ" # You can try other datasets listed in https://github.com/OpenBMB/ModelCenter/tree/main/examples/bert
+
+OPTS=""
+OPTS+=" --model-config ${VERSION}"
+OPTS+=" --base-path ${BASE_PATH}"
+OPTS+=" --dataset_name ${DATASET}"
+OPTS+=" --batch-size 64"
+OPTS+=" --lr 0.001" # You can use a different learning rate to find the optimal performance
+OPTS+=" --max-encoder-length 512"
+OPTS+=" --train-iters 1400"
+OPTS+=" --lr-decay-style constant"
+OPTS+=" --weight-decay 1e-2"
+OPTS+=" --clip-grad 10.0"
+OPTS+=" --loss-scale 128"
+OPTS+=" --delta_type low_rank_adapter" # You can use a different delta type, listed in https://opendelta.readthedocs.io/en/latest/notes/acceleration.html#BMTrain
+
+CMD="python3 -m torch.distributed.launch ${DISTRIBUTED_ARGS} ${BASE_PATH}2_with_bmtrain.py ${OPTS}"
+echo ${CMD}
+
+${CMD} 2>&1 | tee ${BASE_PATH}/tmp/logs/bmt_bert_boolq_finetune-${VERSION}-${DATASET}.log
@@ -13,4 +13,23 @@ requirement:
 ```
 pip install openprompt
 ```
 
+## 2_with_bmtrain.py
+1. Install the necessary packages:
+```
+pip install git+https://github.com/OpenBMB/BMTrain.git
+pip install git+https://github.com/OpenBMB/ModelCenter.git
+```
+2. Download a dataset from https://super.gluebenchmark.com/tasks, e.g.,
+```
+mkdir down_data
+cd down_data
+wget https://dl.fbaipublicfiles.com/glue/superglue/data/v2/BoolQ.zip
+unzip BoolQ.zip
+```
+3. Run the shell script, changing `NNODES` and `GPUS_PER_NODE` according to your computational resources.
+```
+bash 2_with_bmtrain.sh
+```
|
@@ -67,6 +67,16 @@ delta2.detach()
 # say we add lora to the last four layers of the decoder of t5, with lora rank=5
 delta_config3 = AutoDeltaConfig.from_dict({"delta_type":"lora", "modified_modules":["[r]decoder.*((20)|(21)|(22)|(23)).*DenseReluDense\.wi"], "lora_r":5})
 delta3 = AutoDeltaModel.from_config(delta_config3, backbone_model=wrapped_model)
+delta3.freeze_module()
 delta3.log()
+
+# add optimizer as normal
+from transformers import AdamW
+optimizer = AdamW(wrapped_model.parameters(), lr=3e-3)
+
+# inspect the optimizer
+from opendelta.utils.inspect import inspect_optimizer_statistics
+inspect_optimizer_statistics(optimizer)
@@ -1,255 +0,0 @@
(A 255-line standalone BMTrain finetuning example is deleted in this commit. Its content, the BertModel wrapper, the tokenizer/model/optimizer/lr-scheduler setup, and the SuperGLUE finetune and evaluation loop, duplicates nearly line for line the plain ModelCenter finetuning code that the new examples/tutorial/2_with_bmtrain.py above builds on, so it is not repeated here.)
|
|
|
@@ -21,6 +21,7 @@ from .delta_models.adapter import AdapterModel
 from .delta_models.prefix import PrefixModel
 from .delta_models.soft_prompt import SoftPromptModel
 from .delta_models.low_rank_adapter import LowRankAdapterModel
+from .delta_models.parallel_adapter import ParallelAdapterModel
 from .utils.visualization import Visualization
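
With this export in place, the parallel adapter can be used like any other top-level delta model; a brief sketch, with the `modified_modules` values taken from the BMTrain tutorial in this commit (`model` stands for an already-built backbone):

```
from opendelta import ParallelAdapterModel

delta_model = ParallelAdapterModel(
    backbone_model=model,
    modified_modules=['self_att', 'self_att', 'ffn.ffn', 'ffn.ffn'],
    backend='bmt',  # or leave the default 'hf' backend for huggingface models
)
```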
|
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ from multiprocessing.sharedctypes import Value
|
||||||
import os
|
import os
|
||||||
from turtle import back
|
from turtle import back
|
||||||
from opendelta.delta_configs import BaseDeltaConfig
|
from opendelta.delta_configs import BaseDeltaConfig
|
||||||
|
from opendelta.utils.inspect import inspect_module_statistics
|
||||||
from opendelta.utils.model_md5 import gen_model_hash
|
from opendelta.utils.model_md5 import gen_model_hash
|
||||||
from opendelta.utils.signature import get_arg_names, signature
|
from opendelta.utils.signature import get_arg_names, signature
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
@ -27,6 +28,7 @@ from opendelta.utils.cuda import move_dict_to_cuda
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from opendelta.utils.data_parallel import caller_map
|
from opendelta.utils.data_parallel import caller_map
|
||||||
|
from opendelta.utils.backend import BackendMapping
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
def is_leaf_module(module):
|
def is_leaf_module(module):
|
||||||
|
@ -94,6 +96,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
||||||
config_class = BaseDeltaConfig
|
config_class = BaseDeltaConfig
|
||||||
default_unfrozen_modules = ["deltas"]
|
default_unfrozen_modules = ["deltas"]
|
||||||
_need_pseudo_data = True
|
_need_pseudo_data = True
|
||||||
|
_supported_backends = ['hf']
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
backbone_model: nn.Module,
|
backbone_model: nn.Module,
|
||||||
modified_modules: Optional[List[str]] = None,
|
modified_modules: Optional[List[str]] = None,
|
||||||
|
@ -101,7 +104,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
||||||
unfrozen_modules: Optional[List[str]] = None,
|
unfrozen_modules: Optional[List[str]] = None,
|
||||||
interactive_modify: Optional[Union[bool, int]] = False,
|
interactive_modify: Optional[Union[bool, int]] = False,
|
||||||
common_structure: Optional[bool] = False,
|
common_structure: Optional[bool] = False,
|
||||||
framework_type: Optional[str]= "hf", # select from ["hf", "bmt"]
|
backend: Optional[str]= "hf", # select from ["hf", "bmt"]
|
||||||
):
|
):
|
||||||
nn.Module.__init__(self)
|
nn.Module.__init__(self)
|
||||||
# register the backbone model after init using self.__dict__ method to avoid adding backbone_model
|
# register the backbone model after init using self.__dict__ method to avoid adding backbone_model
|
||||||
|
@ -139,7 +142,10 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
||||||
self.unfrozen_modules = self.default_unfrozen_modules
|
self.unfrozen_modules = self.default_unfrozen_modules
|
||||||
if self.common_structure and self.structure_mapping is None:
|
if self.common_structure and self.structure_mapping is None:
|
||||||
raise RuntimeError("Using common structure but the structure mapping is None")
|
raise RuntimeError("Using common structure but the structure mapping is None")
|
||||||
self.framework_type = framework_type
|
if backend not in self._supported_backends:
|
||||||
|
raise RuntimeError("Currently, backend `{}` is not supported for `{}`".format(backend, self.__class__.__name__))
|
||||||
|
self.backend = backend
|
||||||
|
self.backend_mapping = BackendMapping(backend)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs) -> RuntimeError:
|
def forward(self, *args, **kwargs) -> RuntimeError:
|
||||||
r"""
|
r"""
|
||||||
|
@ -371,10 +377,11 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
||||||
_auto_dummy_fail = False
|
_auto_dummy_fail = False
|
||||||
try:
|
try:
|
||||||
module(**dummy_inputs)
|
module(**dummy_inputs)
|
||||||
except:
|
except Exception as e:
|
||||||
_auto_dummy_fail = True
|
_auto_dummy_fail = True
|
||||||
if _auto_dummy_fail:
|
|
||||||
raise AttributeError(f"\n\tThe {self.__class__.__name__} requires a dummy_inputs to be passed through the model to understand the dimensionality of each tensor in the computation graph. \n\t The {module.__class__.__name__} Class has no dummy_inputs, and automatically created dummy_inputs failed.\n\t Refer to `https://opendelta.readthedocs.io/en/latest/notes/faq.html` for detail.")
|
if _auto_dummy_fail and _auto_dummy:
|
||||||
|
raise AttributeError(f"{str(e)}\n\tThe {self.__class__.__name__} requires a dummy_inputs to be passed through the model to understand the dimensionality of each tensor in the computation graph. \n\t The {module.__class__.__name__} Class has no dummy_inputs, and automatically created dummy_inputs failed.\n\t Refer to `https://opendelta.readthedocs.io/en/latest/notes/faq.html` for detail.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -684,65 +691,16 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
||||||
from opendelta import Visualization
|
from opendelta import Visualization
|
||||||
Visualization(module).structure_graph()
|
Visualization(module).structure_graph()
|
||||||
|
|
||||||
self.get_statistics(module)
|
self.stat = inspect_module_statistics(module, verbose=False)
|
||||||
if trainable_ratio:
|
if trainable_ratio:
|
||||||
logger.info("Trainable Ratio: {:2f}%".format(self.stat['trainable_ratio']*100))
|
logger.info("Trainable Ratio: {}/{}={:.6f}%".format(self.stat['trainable_parameters'], self.stat['total_parameters'], self.stat['trainable_ratio']*100))
|
||||||
if delta_ratio:
|
if delta_ratio:
|
||||||
logger.info("Delta Parameter Ratio: {:2f}%".format(self.stat['delta_ratio']*100))
|
logger.info("Delta Parameter Ratio: {}/{}={:.6f}%".format(self.stat['delta_parameters'], self.stat['total_parameters'],self.stat['delta_ratio']*100))
|
||||||
if cuda_memory:
|
if cuda_memory:
|
||||||
logger.info("Static Memory {:.2f} GB, Max Memory {:.2f} GB".format(self.stat['cudamem'], self.stat['maxcudamem']))
|
logger.info("Static Memory {:.2f} GB, Max Memory {:.2f} GB".format(self.stat['cudamem'], self.stat['maxcudamem']))
|
||||||
|
|
||||||
|
|
||||||
def get_statistics(self, module=None):
|
|
||||||
r"""Get the statistics of the parameters in the delta modules.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module (:obj:`nn.Module`, *optional*): The module to compute the statistics.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:obj:`dict`: The statistics of the parameters in the delta modules.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if module is None:
|
|
||||||
module = self.backbone_model
|
|
||||||
|
|
||||||
self.stat = {}
|
|
||||||
n_trainable = self.num_trainable_parameters(module)
|
|
||||||
n_total = self.num_total_parameters(module)
|
|
||||||
|
|
||||||
self.stat['trainable_ratio'] = n_trainable/n_total
|
|
||||||
|
|
||||||
n_delta = self.num_delta_parameters(module)
|
|
||||||
n_total = self.num_total_parameters(module)
|
|
||||||
self.stat['delta_ratio'] = n_delta/n_total
|
|
||||||
|
|
||||||
cudamem = 0
|
|
||||||
maxcudamem = 0
|
|
||||||
for device_id in range(torch.cuda.device_count()):
|
|
||||||
cudamem += torch.cuda.memory_allocated(f"cuda:{device_id}")/1024**3
|
|
||||||
maxcudamem += torch.cuda.max_memory_allocated(f"cuda:{device_id}")/1024**3
|
|
||||||
self.stat['cudamem'] = cudamem
|
|
||||||
self.stat['maxcudamem'] = maxcudamem
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def num_delta_parameters(self, module: Optional[nn.Module]=None):
|
|
||||||
r"""[NODOC] A small sugar function to get the number of trainable parameter in the backbone model. Often used to
|
|
||||||
compute the trainable rate.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module (:obj:`nn.Module`): of which module we want to know the number of trainable paramemters.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:obj:`List[nn.Parameter]`
|
|
||||||
"""
|
|
||||||
if module is None:
|
|
||||||
module = self.backbone_model
|
|
||||||
pnum_tot = 0
|
|
||||||
for param in module.parameters():
|
|
||||||
if hasattr(param, "_is_delta"):
|
|
||||||
pnum_tot += param.numel()
|
|
||||||
return pnum_tot
|
|
||||||
|
|
||||||
# Two functions for plug and remove the delta model.
|
# Two functions for plug and remove the delta model.
|
||||||
def attach(self, module: Optional[nn.Module]=None, reset_state_dict=True):
|
def attach(self, module: Optional[nn.Module]=None, reset_state_dict=True):
|
||||||
|
|
|
@ -60,13 +60,14 @@ class AdapterLayer(nn.Module, InterFaceMixin):
|
||||||
def get_layer_count(cls):
|
def get_layer_count(cls):
|
||||||
return cls.layer_count
|
return cls.layer_count
|
||||||
|
|
||||||
def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', device=None):
|
def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', device=None, backend="hf"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
InterFaceMixin.__init__(self)
|
InterFaceMixin.__init__(self)
|
||||||
self.bottleneck_dim = bottleneck_dim
|
self.bottleneck_dim = bottleneck_dim
|
||||||
self.init_device = device
|
self.init_device = device
|
||||||
self.instantiated = False
|
self.instantiated = False
|
||||||
self.non_linearity = non_linearity
|
self.non_linearity = non_linearity
|
||||||
|
self.backend=backend
|
||||||
|
|
||||||
self.layer_id = AdapterLayer.get_layer_count()
|
self.layer_id = AdapterLayer.get_layer_count()
|
||||||
AdapterLayer.count_layer()
|
AdapterLayer.count_layer()
|
||||||
|
@ -79,14 +80,16 @@ class AdapterLayer(nn.Module, InterFaceMixin):
|
||||||
else:
|
else:
|
||||||
return self.init_device
|
return self.init_device
|
||||||
|
|
||||||
def instantiate(self, hidden_dim):
|
def instantiate(self, hiddens):
|
||||||
|
self.hidden_dim = hiddens.shape[-1]
|
||||||
|
self.hidden_dtype = hiddens.dtype
|
||||||
self.modulelist = nn.Sequential()
|
self.modulelist = nn.Sequential()
|
||||||
self.modulelist.add_module("down_proj",nn.Linear(hidden_dim, self.bottleneck_dim, device=self.init_device))
|
self.modulelist.add_module("down_proj",nn.Linear(self.hidden_dim, self.bottleneck_dim, device=self.init_device, dtype=self.hidden_dtype))
|
||||||
|
|
||||||
# select non-linearity
|
# select non-linearity
|
||||||
self.modulelist.add_module("non_linear", Activations(self.non_linearity.lower()))
|
self.modulelist.add_module("non_linear", Activations(self.non_linearity.lower()))
|
||||||
|
|
||||||
self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.init_device))
|
self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.init_device, dtype=self.hidden_dtype))
|
||||||
|
|
||||||
# TODO:
|
# TODO:
|
||||||
# If we want to have a layer norm on output, we apply it later after a separate residual connection
|
# If we want to have a layer norm on output, we apply it later after a separate residual connection
|
||||||
|
@ -97,11 +100,9 @@ class AdapterLayer(nn.Module, InterFaceMixin):
|
||||||
self.instantiated = True
|
self.instantiated = True
|
||||||
# initialize the weight, which is important for fast convergence and better performance.
|
# initialize the weight, which is important for fast convergence and better performance.
|
||||||
self.apply(self._init_weight)
|
self.apply(self._init_weight)
|
||||||
try:
|
if self.backend == 'bmt':
|
||||||
import bmtrain as bmt
|
import bmtrain as bmt
|
||||||
self.modulelist = bmt.BMTrainModelWrapper(self.modulelist)
|
self.modulelist = bmt.BMTrainModelWrapper(self.modulelist)
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _init_weight(self, module):
|
def _init_weight(self, module):
|
||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear):
|
||||||
|
@ -123,19 +124,29 @@ class AdapterLayer(nn.Module, InterFaceMixin):
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
hiddens = self._transpose(hiddens)
|
hiddens = self._transpose(hiddens)
|
||||||
hiddens = self._convert_data_type(hiddens)
|
# if self.backend == 'hf':
|
||||||
|
# hiddens = self._convert_data_type(hiddens)
|
||||||
|
# elif self.backend == 'bmt': # if bmt, left the convertion to bmt
|
||||||
|
# pass
|
||||||
|
|
||||||
if not self.instantiated:
|
if not self.instantiated:
|
||||||
self.hidden_dim = hiddens.shape[-1]
|
# self.hidden_dim = hiddens.shape[-1]
|
||||||
logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
|
# logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
|
||||||
self.instantiate(hidden_dim=self.hidden_dim)
|
self.instantiate(hiddens=hiddens)
|
||||||
|
|
||||||
|
|
||||||
|
# from IPython import embed; embed(header="14135315")
|
||||||
adapter_output = self.modulelist(hiddens)
|
adapter_output = self.modulelist(hiddens)
|
||||||
modified_output = adapter_output + hiddens # TODO option: disable residual_connection
|
modified_output = adapter_output + hiddens # TODO option: disable residual_connection
|
||||||
|
|
||||||
modified_output = self._reverse_transpose(modified_output)
|
modified_output = self._reverse_transpose(modified_output)
|
||||||
modified_output = self._reverse_data_type(modified_output)
|
|
||||||
|
# if self.backend == 'hf':
|
||||||
|
# # print("!"*100)
|
||||||
|
# modified_output = self._reverse_data_type(modified_output)
|
||||||
|
# elif self.backend == 'bmt': # if bmt, left the convertion to bmt
|
||||||
|
# print("!"*100)
|
||||||
|
# pass
|
||||||
|
|
||||||
|
|
||||||
if isinstance(output, tuple):
|
if isinstance(output, tuple):
|
||||||
output = (modified_output,) + output[1:]
|
output = (modified_output,) + output[1:]
|
||||||
|
@ -189,20 +200,24 @@ class AdapterModel(DeltaBase):
|
||||||
modified_modules (:obj:`List[str]`): modules to add adapter after them.
|
modified_modules (:obj:`List[str]`): modules to add adapter after them.
|
||||||
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the adapter parameters.
|
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the adapter parameters.
|
||||||
common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping.
|
common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping.
|
||||||
|
backend (:obj:`str`): choose the backend of the PLM, 'hf' for Hugging Face transformers, 'bmt' for BMTrain.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
config_class = AdapterConfig
|
config_class = AdapterConfig
|
||||||
delta_type = "adapter"
|
delta_type = "adapter"
|
||||||
default_modified_modules = ["attn@.proj@", "ff@.w2@"]
|
default_modified_modules = ["attn@.proj@", "ff@.w2@"]
|
||||||
|
_supported_backends = ['hf', 'bmt']
|
||||||
_need_pseudo_data = True
|
_need_pseudo_data = True
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
backbone_model: nn.Module,
|
backbone_model: nn.Module,
|
||||||
bottleneck_dim: Optional[int]=24,
|
bottleneck_dim: Optional[int]=24,
|
||||||
non_linearity: Optional[str]='gelu_new',
|
non_linearity: Optional[str]='gelu_new',
|
||||||
modified_modules: Optional[bool] = None,
|
modified_modules: Optional[List[str]] = None,
|
||||||
|
exclude_modules: Optional[List[str]] = None,
|
||||||
unfrozen_modules: Optional[bool] = None,
|
unfrozen_modules: Optional[bool] = None,
|
||||||
common_structure: Optional[bool] = None,
|
common_structure: Optional[bool] = None,
|
||||||
interactive_modify: Optional[Union[bool, int]] = False,
|
interactive_modify: Optional[Union[bool, int]] = False,
|
||||||
|
backend: Optional[str] = 'hf',
|
||||||
):
|
):
|
||||||
DeltaBase.__init__(self,
|
DeltaBase.__init__(self,
|
||||||
backbone_model,
|
backbone_model,
|
||||||
|
@ -211,6 +226,7 @@ class AdapterModel(DeltaBase):
|
||||||
unfrozen_modules=unfrozen_modules,
|
unfrozen_modules=unfrozen_modules,
|
||||||
common_structure=common_structure,
|
common_structure=common_structure,
|
||||||
interactive_modify=interactive_modify,
|
interactive_modify=interactive_modify,
|
||||||
|
backend=backend,
|
||||||
)
|
)
|
||||||
arg_names = get_arg_names_inside_func(self.__init__)
|
arg_names = get_arg_names_inside_func(self.__init__)
|
||||||
for arg_name in arg_names:
|
for arg_name in arg_names:
|
||||||
|
@ -231,6 +247,6 @@ class AdapterModel(DeltaBase):
|
||||||
|
|
||||||
def new_module_like(self, module):
|
def new_module_like(self, module):
|
||||||
module_device = get_device(module)
|
module_device = get_device(module)
|
||||||
adapterlayer = AdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device)
|
adapterlayer = AdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device, backend=self.backend)
|
||||||
self.delta_modules.append(adapterlayer)
|
self.delta_modules.append(adapterlayer)
|
||||||
return adapterlayer
|
return adapterlayer
|
||||||
|
|
|
@@ -75,16 +75,6 @@ class BiasLayer(nn.Module):
             raise TypeError
         return output
 
-framework_map = {}
-framework_map['hf'] = {
-    "linear": nn.Linear,
-    "layer_norm": nn.LayerNorm,
-}
-
-framework_map['bmt'] = {
-    "linear": model_center.layer.Linear,
-    "layer_norm", model_center.layer.LayerNorm,
-}
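
The hard-coded `framework_map` removed above is superseded by the `BackendMapping` helper that `delta_base.py` now imports from `opendelta.utils.backend`. That module is not part of this diff, so the following is only a hedged sketch of the kind of mapping it plausibly provides, reconstructed from the removed dict and the `backend_mapping.check_type(module, 'linear')` calls introduced below:

```
import torch.nn as nn

class BackendMapping:
    """Sketch: resolve abstract module kinds ('linear', 'layer_norm') to the
    concrete classes of the chosen backend, so delta models can type-check
    backbone modules without hard-coding torch.nn or model_center classes."""
    def __init__(self, backend):
        if backend == "hf":
            self.mapping = {"linear": nn.Linear, "layer_norm": nn.LayerNorm}
        elif backend == "bmt":
            from model_center.layer import Linear, LayerNorm
            self.mapping = {"linear": Linear, "layer_norm": LayerNorm}
        else:
            raise ValueError(f"Unsupported backend: {backend}")

    def check_type(self, module, kind):
        return isinstance(module, self.mapping[kind])
```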
|
|
||||||
|
|
||||||
|
|
||||||
class BitFitModel(DeltaBase):
|
class BitFitModel(DeltaBase):
|
||||||
|
@ -124,6 +114,7 @@ class BitFitModel(DeltaBase):
|
||||||
config_class = BitFitConfig
|
config_class = BitFitConfig
|
||||||
delta_type = "bitfit"
|
delta_type = "bitfit"
|
||||||
default_modified_modules = ["attn@", "ff@", "layer_norm@","lm_head@.proj@"] # modify all the bias parameter in attention and feed-forward layer.
|
default_modified_modules = ["attn@", "ff@", "layer_norm@","lm_head@.proj@"] # modify all the bias parameter in attention and feed-forward layer.
|
||||||
|
_supported_backends = ['hf']
|
||||||
_need_pseudo_data = False
|
_need_pseudo_data = False
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
backbone_model: nn.Module,
|
backbone_model: nn.Module,
|
||||||
|
@ -132,7 +123,7 @@ class BitFitModel(DeltaBase):
|
||||||
unfrozen_modules: Optional[List[str]] = None,
|
unfrozen_modules: Optional[List[str]] = None,
|
||||||
common_structure: Optional[bool] = None,
|
common_structure: Optional[bool] = None,
|
||||||
interactive_modify: Optional[Union[bool, int]] = False,
|
interactive_modify: Optional[Union[bool, int]] = False,
|
||||||
framework_type: Optional[str] = "hf",
|
backend: Optional[str] = "hf",
|
||||||
):
|
):
|
||||||
DeltaBase.__init__(self,
|
DeltaBase.__init__(self,
|
||||||
backbone_model,
|
backbone_model,
|
||||||
|
@ -141,7 +132,7 @@ class BitFitModel(DeltaBase):
|
||||||
unfrozen_modules=unfrozen_modules,
|
unfrozen_modules=unfrozen_modules,
|
||||||
common_structure=common_structure,
|
common_structure=common_structure,
|
||||||
interactive_modify=interactive_modify,
|
interactive_modify=interactive_modify,
|
||||||
framework_type=framework_type,
|
backend=backend,
|
||||||
)
|
)
|
||||||
arg_names = get_arg_names_inside_func(self.__init__)
|
arg_names = get_arg_names_inside_func(self.__init__)
|
||||||
for arg_name in arg_names:
|
for arg_name in arg_names:
|
||||||
|
@ -153,6 +144,8 @@ class BitFitModel(DeltaBase):
|
||||||
|
|
||||||
self.add_all_delta_to_backbone(self.backbone_model,
|
self.add_all_delta_to_backbone(self.backbone_model,
|
||||||
self.modified_modules)
|
self.modified_modules)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def update_module(self, module: nn.Module, key: str):
|
def update_module(self, module: nn.Module, key: str):
|
||||||
|
@ -167,7 +160,8 @@ class BitFitModel(DeltaBase):
|
||||||
# if it is a leaf module, add bias to it regardless of its type.
|
# if it is a leaf module, add bias to it regardless of its type.
|
||||||
# if self.check_linear(module):
|
# if self.check_linear(module):
|
||||||
# self.add_bias_to_linear(module)
|
# self.add_bias_to_linear(module)
|
||||||
if self.check_linear(module) or self.check_layernorm(module, nn.LayerNorm):
|
if self.backend_mapping.check_type(module, 'linear') or \
|
||||||
|
self.backend_mapping.check_type(module, 'layer_norm'):
|
||||||
self.add_bias_to_modules_have_bias_or_known_type(module)
|
self.add_bias_to_modules_have_bias_or_known_type(module)
|
||||||
else:
|
else:
|
||||||
# for example, layer_norms, lm_heads.
|
# for example, layer_norms, lm_heads.
|
||||||
|
@ -202,48 +196,27 @@ class BitFitModel(DeltaBase):
|
||||||
c.bias.requires_grad = True
|
c.bias.requires_grad = True
|
||||||
self.delta_params.append(c.bias)
|
self.delta_params.append(c.bias)
|
||||||
else:
|
else:
|
||||||
if self.check_linear(c) or isinstance(c): # todo: bmt layerNorm
|
if self.backend_mapping.check_type(c, 'linear') or \
|
||||||
|
self.backend_mapping.check_type(c, 'layer_norm'):
|
||||||
bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
|
bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
|
||||||
|
|
||||||
self._reset_bias_parameters(c) #?
|
self._reset_bias_parameters(c)
|
||||||
try:
|
if self.backend == 'bmt':
|
||||||
import bmtrain as bmt
|
import bmtrain as bmt
|
||||||
bias = bmt.BMTrainModelWrapper(bias)
|
bias = bmt.BMTrainModelWrapper(bias)
|
||||||
except:
|
|
||||||
pass
|
|
||||||
c.register_parameter('bias', bias)
|
c.register_parameter('bias', bias)
|
||||||
self.delta_params.append(bias)
|
self.delta_params.append(bias)
|
||||||
|
|
||||||
def add_bias_to_others(self, c): # todo: bmtrain?
|
def add_bias_to_others(self, c):
|
||||||
new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c))
|
new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c)) # TODO: bmtrain?
|
||||||
|
if self.backend == 'bmt':
|
||||||
|
import bmtrain as bmt
|
||||||
|
new_bias = bmt.BMTrainModelWrapper(new_bias)
|
||||||
|
|
||||||
self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since the name `bias` is reserved for some module such as roberta's LayerNorm.
|
self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since the name `bias` is reserved for some module such as roberta's LayerNorm.
|
||||||
self.delta_modules.append(new_bias)
|
self.delta_modules.append(new_bias)
|
||||||
|
|
||||||
def check_linear(self, m):
|
|
||||||
if isinstance(m, nn.Linear):
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
from model_center.layer import Linear
|
|
||||||
if isinstance(m, Linear):
|
|
||||||
return True
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
return False
|
|
||||||
|
|
||||||
def check_layernorm(self, m):
|
|
||||||
if isinstance(m, nn.LayerNorm):
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
from model_center.layer import LayerNorm
|
|
||||||
if isinstance(m, LayerNorm):
|
|
||||||
return True
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reset_bias_parameters(linear_module):
|
def _reset_bias_parameters(linear_module):
|
||||||
fan_in, _ = init._calculate_fan_in_and_fan_out(linear_module.weight)
|
fan_in, _ = init._calculate_fan_in_and_fan_out(linear_module.weight)
|
||||||
|
|
|
@@ -36,6 +36,7 @@ class HyperComplexAdapterLayer(nn.Module):
                  device=None,
                  use_bias_up_sampler=True,
                  use_bias_down_sampler=True,
+                 backend = 'hf',
                  ):
         super().__init__()
         self.reduction_factor = reduction_factor
@@ -55,14 +56,17 @@ class HyperComplexAdapterLayer(nn.Module):
         self.use_bias_up_sampler=use_bias_up_sampler
         self.use_bias_down_sampler=use_bias_down_sampler
         self.device = device
+        self.backend = backend

         self.instantiated = False

-    def instantiate(self, hidden_dim):
-        self.down_sample_size = hidden_dim // self.reduction_factor
+    def instantiate(self, hiddens):
+        self.hidden_dim = hiddens.shape[-1]
+        self.hidden_dtype = hiddens.dtype
+        self.down_sample_size = self.hidden_dim // self.reduction_factor
         self.activation = Activations(self.non_linearity.lower()).to(self.device)
-        self.down_sampler = PHMLinear(in_features=hidden_dim,
+        self.down_sampler = PHMLinear(in_features=self.hidden_dim,
                                       out_features=self.down_sample_size,
                                       bias=self.use_bias_down_sampler,
                                       c_init=self.phm_c_init,
@@ -76,9 +80,10 @@ class HyperComplexAdapterLayer(nn.Module):
                                       factorized_phm_rule=self.factorized_phm_rule,
                                       phm_rank=self.phm_rank,
                                       phm_init_range=self.phm_init_range,
-                                      kronecker_prod=self.kronecker_prod).to(self.device)
+                                      kronecker_prod=self.kronecker_prod,
+                                      dtype = self.hidden_dtype).to(self.device)
         self.up_sampler = PHMLinear(in_features=self.down_sample_size,
-                                    out_features=hidden_dim,
+                                    out_features=self.hidden_dim,
                                     bias=self.use_bias_up_sampler,
                                     c_init=self.phm_c_init,
                                     phm_dim=self.hypercomplex_division,
@@ -91,15 +96,14 @@ class HyperComplexAdapterLayer(nn.Module):
                                     factorized_phm_rule=self.factorized_phm_rule,
                                     phm_rank=self.phm_rank,
                                     phm_init_range=self.phm_init_range,
-                                    kronecker_prod=self.kronecker_prod).to(self.device)
+                                    kronecker_prod=self.kronecker_prod,
+                                    dtype = self.hidden_dtype).to(self.device)
         self.instantiated = True
-        try:
+        if self.backend == "bmt":
             import bmtrain as bmt
             self.activation = bmt.BMTrainModelWrapper(self.activation)
             self.down_sampler = bmt.BMTrainModelWrapper(self.down_sampler)
             self.up_sampler = bmt.BMTrainModelWrapper(self.up_sampler)
-        except:
-            pass

     def post_forward(self, output):
@@ -116,9 +120,7 @@ class HyperComplexAdapterLayer(nn.Module):
             raise TypeError

         if not self.instantiated:
-            self.hidden_dim = hiddens.shape[-1]
-            logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
-            self.instantiate(hidden_dim=self.hidden_dim)
+            self.instantiate(hiddens=hiddens)

         z = self.down_sampler(hiddens)
@@ -193,6 +195,7 @@ class CompacterModel(DeltaBase):
         unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
             together with the prefix parameters.
         common_structure (:obj:`bool`, *optional*, default to :obj:`None`): whether using name-based addressing with a common structure mapping.
+        backend (:obj:`str`): choose the backend of plm, 'hf' for huggingface transformers,'bmt' for bmtrain
         reduction_factor (:obj:`int`, *optional*, default to ``16``): bottleneck_dim = hidden_dim//reduction_factor
         non_linearity (:obj:`str`, *optional*, default to ``"gelu_new"``): The non linearity activation used in between the down
             projecter and the up projecter.
@@ -218,6 +221,7 @@ class CompacterModel(DeltaBase):
     config_class = CompacterConfig
     delta_type = "compacter"
     default_modified_modules = ["attn@.proj@", "ff@.w2@"]
+    _supported_backends = ['hf', 'bmt']
     _need_pseudo_data = True
     def __init__(self,
                  backbone_model,
@@ -226,6 +230,7 @@ class CompacterModel(DeltaBase):
                  unfrozen_modules: Optional[List[str]] = None,
                  common_structure: Optional[bool] = None,
                  interactive_modify: Optional[Union[bool, int]] = False,
+                 backend: Optional[str] = 'hf',
                  reduction_factor=16,
                  non_linearity="gelu_new",
                  phm_c_init="normal",
@@ -288,22 +293,6 @@ class CompacterModel(DeltaBase):

     def new_module_like(self, module):
         module_device = get_device(module)
-        adapterlayer = HyperComplexAdapterLayer(reduction_factor=self.reduction_factor,
-                                                non_linearity=self.non_linearity,
-                                                phm_c_init=self.phm_c_init,
-                                                hypercomplex_division=self.hypercomplex_division,
-                                                learn_phm=self.learn_phm,
-                                                hypercomplex_nonlinearity=self.hypercomplex_nonlinearity,
-                                                shared_phm_rule=self.shared_phm_rule,
-                                                factorized_phm=self.factorized_phm,
-                                                shared_W_phm=self.shared_W_phm,
-                                                factorized_phm_rule=self.factorized_phm_rule,
-                                                phm_rank=self.phm_rank,
-                                                phm_init_range=self.phm_init_range,
-                                                kronecker_prod=self.kronecker_prod,
-                                                use_bias_up_sampler=self.use_bias_up_sampler,
-                                                use_bias_down_sampler=self.use_bias_down_sampler,
-                                                device=module_device
-                                                )
+        adapterlayer = HyperComplexAdapterLayer(reduction_factor=self.reduction_factor, non_linearity=self.non_linearity, phm_c_init=self.phm_c_init, hypercomplex_division=self.hypercomplex_division, learn_phm=self.learn_phm, hypercomplex_nonlinearity=self.hypercomplex_nonlinearity, shared_phm_rule=self.shared_phm_rule, factorized_phm=self.factorized_phm, shared_W_phm=self.shared_W_phm, factorized_phm_rule=self.factorized_phm_rule, phm_rank=self.phm_rank, phm_init_range=self.phm_init_range, kronecker_prod=self.kronecker_prod, use_bias_up_sampler=self.use_bias_up_sampler, use_bias_down_sampler=self.use_bias_down_sampler, device=module_device, backend=self.backend)
         self.delta_modules.append(adapterlayer)
         return adapterlayer
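The `backend` argument documented in the hunk above is threaded from CompacterModel down to every HyperComplexAdapterLayer, which then wraps its sub-modules with bmt.BMTrainModelWrapper. A minimal usage sketch, not part of this commit, assuming `plm` is a ModelCenter PLM built after bmtrain has been initialized:

# Hypothetical sketch: Compacter on a bmtrain/ModelCenter backbone.
from opendelta import CompacterModel

delta_model = CompacterModel(backbone_model=plm,   # assumed ModelCenter model
                             backend='bmt',        # 'hf' (default) keeps the plain torch behaviour
                             reduction_factor=16)
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
delta_model.log()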
@@ -84,7 +84,8 @@ class PHMLinear(torch.nn.Module):
                  factorized_phm_rule=False,
                  phm_rank = 1,
                  phm_init_range=0.0001,
-                 kronecker_prod=False) -> None:
+                 kronecker_prod=False,
+                 dtype=torch.float) -> None:
         super(PHMLinear, self).__init__()
         assert w_init in ["phm", "glorot-normal", "glorot-uniform", "normal"]
         assert c_init in ["normal", "uniform"]
@@ -104,12 +105,12 @@ class PHMLinear(torch.nn.Module):
         self.factorized_phm_rule = factorized_phm_rule
         if not self.shared_phm_rule:
             if self.factorized_phm_rule:
-                self.phm_rule_left = nn.Parameter(torch.FloatTensor(phm_dim, phm_dim, 1),
+                self.phm_rule_left = nn.Parameter(torch.empty((phm_dim, phm_dim, 1), dtype=dtype),
                                           requires_grad=learn_phm)
-                self.phm_rule_right = nn.Parameter(torch.FloatTensor(phm_dim, 1, phm_dim),
+                self.phm_rule_right = nn.Parameter(torch.empty((phm_dim, 1, phm_dim), dtype=dtype),
                                           requires_grad=learn_phm)
             else:
-                self.phm_rule = nn.Parameter(torch.FloatTensor(phm_dim, phm_dim, phm_dim),
+                self.phm_rule = nn.Parameter(torch.empty((phm_dim, phm_dim, phm_dim), dtype=dtype),
                                           requires_grad=learn_phm)
         self.bias_flag = bias
         self.w_init = w_init
@@ -118,15 +119,15 @@ class PHMLinear(torch.nn.Module):
         self.factorized_phm = factorized_phm
         if not self.shared_W_phm:
             if self.factorized_phm:
-                self.W_left = nn.Parameter(torch.Tensor(size=(phm_dim, self._in_feats_per_axis, self.phm_rank)),
+                self.W_left = nn.Parameter(torch.empty((phm_dim, self._in_feats_per_axis, self.phm_rank), dtype=dtype),
                                           requires_grad=True)
-                self.W_right = nn.Parameter(torch.Tensor(size=(phm_dim, self.phm_rank, self._out_feats_per_axis)),
+                self.W_right = nn.Parameter(torch.empty((phm_dim, self.phm_rank, self._out_feats_per_axis), dtype=dtype),
                                           requires_grad=True)
             else:
-                self.W = nn.Parameter(torch.Tensor(size=(phm_dim, self._in_feats_per_axis, self._out_feats_per_axis)),
+                self.W = nn.Parameter(torch.empty((phm_dim, self._in_feats_per_axis, self._out_feats_per_axis), dtype=dtype),
                                           requires_grad=True)
         if self.bias_flag:
-            self.b = nn.Parameter(torch.Tensor(out_features))
+            self.b = nn.Parameter(torch.empty(out_features, dtype=dtype), requires_grad=True)
         else:
             self.register_parameter("b", None)
         self.reset_parameters()
@@ -6,17 +6,17 @@ from opendelta.delta_models.layers.init import glorot_uniform, glorot_normal

 class LowRankLinear(torch.nn.Module):
     def __init__(self, input_dim: int, output_dim: int, rank: int = 1,
-                 bias: bool = True, w_init: str = "glorot-uniform"):
+                 bias: bool = True, w_init: str = "glorot-uniform", dtype=torch.float):
         super(LowRankLinear, self).__init__()
         self.input_dim = input_dim
         self.output_dim = output_dim
         self.rank = rank
         self.bias = bias
         self.w_init = w_init
-        self.W_left = nn.Parameter(torch.Tensor(size=(input_dim, rank)), requires_grad=True)
-        self.W_right = nn.Parameter(torch.Tensor(size=(rank, output_dim)), requires_grad=True)
+        self.W_left = nn.Parameter(torch.empty((input_dim, rank), dtype=dtype),requires_grad=True)
+        self.W_right = nn.Parameter(torch.empty((rank, output_dim), dtype=dtype), requires_grad=True)
         if bias:
-            self.b = nn.Parameter(torch.Tensor(output_dim))
+            self.b = nn.Parameter(torch.empty(output_dim, dtype=dtype))
         self.reset_parameters()

     def reset_parameters(self):
@@ -97,12 +97,14 @@ class LoraModel(DeltaBase):
         unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
             together with the prefix parameters.
         common_structure (:obj:`bool`): whether using name-based addressing with a common structure mapping.
+        backend (:obj:`str`): choose the backend of plm, 'hf' for huggingface transformers,'bmt' for bmtrain

     """

     config_class = LoraConfig
     delta_type = "lora"
     default_modified_modules = ['attn@.q@', 'attn@.v@']
+    _supported_backends = ['hf', 'bmt']
     _need_pseudo_data = False
     def __init__(self,
                  backbone_model: nn.Module,
@@ -114,6 +116,7 @@ class LoraModel(DeltaBase):
                  exclude_modules: Optional[List[str]] = None,
                  common_structure: Optional[bool] = None,
                  interactive_modify: Optional[Union[bool, int]] = False,
+                 backend: Optional[str] = "hf",
                  ):
         DeltaBase.__init__(self,
                            backbone_model,
@@ -121,6 +124,7 @@ class LoraModel(DeltaBase):
                            unfrozen_modules=unfrozen_modules,
                            common_structure=common_structure,
                            interactive_modify=interactive_modify,
+                           backend=backend,
                            )
         arg_names = get_arg_names_inside_func(self.__init__)
         for arg_name in arg_names:
@@ -151,10 +155,9 @@ class LoraModel(DeltaBase):
                              r=self.lora_r,
                              lora_alpha=self.lora_alpha,
                              lora_dropout=self.lora_dropout)
-        try:
+        if self.backend == "bmt":
             import bmtrain as bmt
             new_module = bmt.BMTrainModelWrapper(new_module)
-        except:
-            pass
         self.delta_modules.append(new_module)
         return new_module
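Since `backend` defaults to "hf", existing LoRA code is untouched; only bmtrain users need the new argument. A hedged sketch, again assuming a ModelCenter backbone `plm`:

# Hypothetical sketch: the new LoRA module gets wrapped by bmt.BMTrainModelWrapper.
from opendelta import LoraModel

delta_model = LoraModel(backbone_model=plm, lora_r=8, backend='bmt')
delta_model.freeze_module(exclude=["deltas"])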
@@ -47,7 +47,8 @@ class LowRankAdapter(nn.Module):
                  non_linearity="gelu_new",
                  low_rank_w_init="glorot-uniform",
                  low_rank_rank=1,
-                 device=None):
+                 device=None,
+                 backend='hf'):
         super().__init__()
         self.reduction_factor = reduction_factor
         self.non_linearity = non_linearity
@@ -55,27 +56,31 @@ class LowRankAdapter(nn.Module):
         self.low_rank_rank = low_rank_rank
         self.device = device
         self.instantiated = False
+        self.backend=backend

-    def instantiate(self, hidden_dim):
+    def instantiate(self, hiddens):
+        self.hidden_dim = hiddens.shape[-1]
+        self.hidden_dtype = hiddens.dtype

-        self.down_sample_size = hidden_dim // self.reduction_factor
+        self.down_sample_size = self.hidden_dim // self.reduction_factor
         self.activation = Activations(self.non_linearity.lower()).to(self.device)
-        self.down_sampler = LowRankLinear(hidden_dim, self.down_sample_size,
+        self.down_sampler = LowRankLinear(self.hidden_dim, self.down_sample_size,
                                           w_init=self.low_rank_w_init,
-                                          rank=self.low_rank_rank).to(self.device)
-        self.up_sampler = LowRankLinear(self.down_sample_size, hidden_dim,
+                                          rank=self.low_rank_rank,
+                                          dtype=self.hidden_dtype).to(self.device)
+        self.up_sampler = LowRankLinear(self.down_sample_size, self.hidden_dim,
                                         w_init=self.low_rank_w_init,
-                                        rank=self.low_rank_rank).to(self.device)
+                                        rank=self.low_rank_rank,
+                                        dtype=self.hidden_dtype).to(self.device)

         self.instantiated = True
-        try:
+        if self.backend == 'bmt':
             import bmtrain as bmt
             self.activation = bmt.BMTrainModelWrapper(self.activation)
             self.down_sampler = bmt.BMTrainModelWrapper(self.down_sampler)
             self.up_sampler = bmt.BMTrainModelWrapper(self.up_sampler)
-        except:
-            pass

     def post_forward(self, output):
         r""" Get the hidden_states from the PLM's layer output, pass it into the low-rank adapter,
@@ -91,9 +96,7 @@ class LowRankAdapter(nn.Module):
             raise TypeError

         if not self.instantiated:
-            self.hidden_dim = hiddens.shape[-1]
-            logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
-            self.instantiate(hidden_dim=self.hidden_dim)
+            self.instantiate(hiddens = hiddens)

         z = self.down_sampler(hiddens)
         z = self.activation(z)
@@ -154,6 +157,7 @@ class LowRankAdapterModel(DeltaBase):
     config_class = LowRankAdapterConfig
     delta_type = "low_rank_adapter"
     default_modified_modules = ["attn@.proj@", "ff@.w2@"]
+    _supported_backends = ['hf', 'bmt']
     _need_pseudo_data = True
     def __init__(self,
                  backbone_model: nn.Module,
@@ -166,6 +170,7 @@ class LowRankAdapterModel(DeltaBase):
                  unfrozen_modules: Optional[List[str]] = None,
                  common_structure: Optional[bool] = None,
                  interactive_modify: Optional[Union[bool, int]] = False,
+                 backend: Optional[str] = 'hf',
                  ):
         DeltaBase.__init__(self,
                            backbone_model,
@@ -174,6 +179,7 @@ class LowRankAdapterModel(DeltaBase):
                            unfrozen_modules=unfrozen_modules,
                            common_structure=common_structure,
                            interactive_modify=interactive_modify,
+                           backend=backend,
                            )
         arg_names = get_arg_names_inside_func(self.__init__)
         for arg_name in arg_names:
@@ -209,6 +215,6 @@ class LowRankAdapterModel(DeltaBase):
                                         non_linearity = self.non_linearity,
                                         low_rank_w_init = self.low_rank_w_init,
                                         low_rank_rank = self.low_rank_rank,
-                                        device=module_device)
+                                        device=module_device, backend=self.backend)
         self.delta_modules.append(adapterlayer)
         return adapterlayer
@@ -25,30 +25,36 @@ class ParallelAdapterLayer(nn.Module):
     def get_layer_count(cls):
         return cls.layer_count

-    def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', scaled=1, device=None):
+    def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', scaled=1, device=None, backend='hf'):
         super().__init__()
         self.bottleneck_dim = bottleneck_dim
         self.device = device
         self.instantiated = False
         self.non_linearity = non_linearity
         self.scaled = scaled
+        self.backend = backend

         self.layer_id = ParallelAdapterLayer.get_layer_count()
         ParallelAdapterLayer.count_layer()

-    def instantiate(self, hidden_dim):
+    def instantiate(self, hiddens):
+        self.hidden_dim = hiddens.shape[-1]
+        self.hidden_dtype = hiddens.dtype
         self.modulelist = nn.Sequential()
-        self.modulelist.add_module("down_proj",nn.Linear(hidden_dim, self.bottleneck_dim, device=self.device))
+        self.modulelist.add_module("down_proj",nn.Linear(self.hidden_dim, self.bottleneck_dim, device=self.device, dtype=self.hidden_dtype))

         # select non-linearity
         self.modulelist.add_module("non_linear", Activations(self.non_linearity.lower()))

-        self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device))
+        self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device, dtype=self.hidden_dtype))

         self.instantiated = True
         # initialize the weight, which is important for fast convergence and better performance.
         self.apply(self._init_weight)
+        if self.backend == 'bmt':
+            import bmtrain as bmt
+            self.modulelist = bmt.BMTrainModelWrapper(self.modulelist)

     def _init_weight(self, module):
         if isinstance(module, nn.Linear):
@@ -71,9 +77,8 @@ class ParallelAdapterLayer(nn.Module):

         if not self.instantiated:
-            self.hidden_dim = hiddens.shape[-1]
-            logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
-            self.instantiate(hidden_dim=self.hidden_dim)
+            # logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
+            self.instantiate(hiddens = hiddens)

         self.adapter_output = self.modulelist(hiddens) * self.scaled
@@ -141,12 +146,14 @@ class ParallelAdapterModel(DeltaBase):
         modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired and have the save order in layer. For examples, ["attn", "attn", "ff.w1", "ff.w2"] add one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output.
         unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the parallel adapter parameters.
         common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping.
+        backend (:obj:`str`): choose the backend of plm, 'hf' for huggingface transformers,'bmt' for bmtrain

     """
     config_class = ParallelAdapterConfig
     delta_type = "parallel_adapter"
     default_modified_modules = ["attn@", "attn@", "ff@.w1@", "ff@.w2@"]
     # default_modified_modules = ["attn", "attn", "ff.w1", "ff.w2"]
+    _supported_backends = ['hf', 'bmt']
     _need_pseudo_data = True
     def __init__(self,
                  backbone_model: nn.Module,
@@ -156,7 +163,8 @@ class ParallelAdapterModel(DeltaBase):
                  exclude_modules: Optional[List[str]] = None,
                  unfrozen_modules: Optional[bool] = None,
                  common_structure: Optional[bool] = None,
                  interactive_modify: Optional[Union[bool, int]] = False,
+                 backend: Optional[str] = "hf",
                  ):
         DeltaBase.__init__(self,
                            backbone_model,
@@ -165,6 +173,7 @@ class ParallelAdapterModel(DeltaBase):
                            unfrozen_modules=unfrozen_modules,
                            common_structure=common_structure,
                            interactive_modify=interactive_modify,
+                           backend=backend,
                            )
         arg_names = get_arg_names_inside_func(self.__init__)
         for arg_name in arg_names:
@@ -193,7 +202,7 @@ class ParallelAdapterModel(DeltaBase):

     def new_module_like(self, module):
         module_device = get_device(module)
-        adapterlayer = ParallelAdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device)
+        adapterlayer = ParallelAdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device, backend=self.backend)
         self.delta_modules.append(adapterlayer)
         return adapterlayer
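The docstring above stresses that `modified_modules` must come in pairs: each even entry marks where a parallel adapter reads its input, the following entry where its output is added back. A hedged sketch using the docstring's own example names (actual module names depend on the backbone, and `plm` is assumed as before):

# Hypothetical sketch: one parallel adapter across attention, one across the FFN.
from opendelta import ParallelAdapterModel

delta_model = ParallelAdapterModel(backbone_model=plm,
                                   modified_modules=["attn", "attn", "ff.w1", "ff.w2"],
                                   backend='bmt')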
@@ -516,6 +516,7 @@ class PrefixModel(DeltaBase):
     config_class = PrefixConfig
     delta_type = "prefix"
     default_modified_modules = ['attn@']
+    _supported_backends = ['hf']
     _need_pseudo_data = True
     def __init__(self,
                  backbone_model: nn.Module,
@@ -161,6 +161,7 @@ class SoftPromptModel(DeltaBase):
     config_class = SoftPromptConfig
     delta_type = "soft_prompt"
     default_modified_modules = ["root"] # not used
+    _supported_backends = ['hf'] #'bmt']
     _need_pseudo_data = False
     def __init__(self,
                  backbone_model: nn.Module,
@@ -0,0 +1,110 @@
+import importlib
+
+
+class BackendMapping:
+    """A lazy mapping from abstract module types (e.g. ``"linear"``, ``"layer_norm"``) to the
+    concrete classes of a given backend. The class behind each key is imported only the first
+    time it is needed.
+
+    Args:
+        backend (:obj:`str`): ``'hf'`` maps the keys to plain torch/huggingface classes,
+            ``'bmt'`` maps them to their bmtrain/ModelCenter counterparts.
+    """
+
+    def __init__(self, backend):
+        self.backend = backend
+        assert backend in ['hf', 'bmt'], "Backend should be one of 'hf', 'bmt'. "
+        if backend == 'hf':
+            self.backend_mapping = {
+                "linear": "torch.nn.Linear",
+                "layer_norm": "torch.nn.LayerNorm",
+                "module": "torch.nn.Module",
+                "parameter": "torch.nn.Parameter"
+            }
+        elif backend == 'bmt':
+            self.backend_mapping = {
+                "linear": "model_center.layer.Linear",
+                "layer_norm": "model_center.layer.LayerNorm",
+                "module": "bmtrain.layer.DistributedModule",
+                "parameter": "bmtrain.nn.DistributedParameter"
+            }
+        self.registered = {}
+
+    def load(self, model_type):
+        if model_type not in self.registered:
+            splited = self.backend_mapping[model_type].split(".")
+            module_name, class_name = ".".join(splited[:-1]), splited[-1]
+            module = importlib.import_module(module_name)
+            the_class = getattr(module, class_name)
+            self.registered[model_type] = the_class
+        return self.registered[model_type]
+
+    def check_type(self, module, expect_type):
+        the_class = self.load(expect_type)
+        if isinstance(module, the_class):
+            return True
+        else:
+            return False
+
+    # def keys(self):
+    #     mapping_keys = [
+    #         self._load_attr_from_module(key, name)
+    #         for key, name in self._config_mapping.items()
+    #         if key in self._model_mapping.keys()
+    #     ]
+    #     return mapping_keys + list(self._extra_content.keys())
+
+    # def get(self, key, default):
+    #     try:
+    #         return self.__getitem__(key)
+    #     except KeyError:
+    #         return default
+
+    # def __bool__(self):
+    #     return bool(self.keys())
+
+    # def values(self):
+    #     mapping_values = [
+    #         self._load_attr_from_module(key, name)
+    #         for key, name in self._model_mapping.items()
+    #         if key in self._config_mapping.keys()
+    #     ]
+    #     return mapping_values + list(self._extra_content.values())
+
+    # def items(self):
+    #     mapping_items = [
+    #         (
+    #             self._load_attr_from_module(key, self._config_mapping[key]),
+    #             self._load_attr_from_module(key, self._model_mapping[key]),
+    #         )
+    #         for key in self._model_mapping.keys()
+    #         if key in self._config_mapping.keys()
+    #     ]
+    #     return mapping_items + list(self._extra_content.items())
+
+    # def __iter__(self):
+    #     return iter(self.keys())
+
+    # def __contains__(self, item):
+    #     if item in self._extra_content:
+    #         return True
+    #     if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
+    #         return False
+    #     model_type = self._reverse_config_mapping[item.__name__]
+    #     return model_type in self._model_mapping
+
+    # def register(self, key, value):
+    #     """
+    #     Register a new model in this mapping.
+    #     """
+    #     if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
+    #         model_type = self._reverse_config_mapping[key.__name__]
+    #         if model_type in self._model_mapping.keys():
+    #             raise ValueError(f"'{key}' is already used by a Transformers model.")
+
+    #     self._extra_content[key] = value
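A short sketch of how the delta models are expected to query this mapping; the import path is an assumption, since only the file content is shown here:

# Hypothetical sketch of BackendMapping usage.
import torch.nn as nn
from opendelta.utils.backend import BackendMapping  # import path assumed

mapping = BackendMapping('hf')                      # or 'bmt' for model_center/bmtrain classes
layer = nn.Linear(768, 768)
print(mapping.check_type(layer, 'linear'))          # True: lazily imports torch.nn.Linear, then isinstance-checks
print(mapping.check_type(layer, 'layer_norm'))      # False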
@@ -0,0 +1,112 @@
+import torch
+import torch.nn as nn
+from typing import Optional
+import opendelta.utils.logging as logging
+
+logger = logging.get_logger(__name__)
+
+
+def inspect_module_statistics(module: Optional[nn.Module]=None, verbose=True):
+    r"""Get the statistics of the parameters in the delta modules.
+
+    Args:
+        module (:obj:`nn.Module`, *optional*): The module to compute the statistics.
+
+    Returns:
+        :obj:`dict`: The statistics of the parameters in the delta modules.
+
+    """
+    stat = {}
+    n_trainable = num_trainable_parameters(module)
+    n_total = num_total_parameters(module)
+
+    stat['total_parameters'] = n_total
+    stat['trainable_parameters'] = n_trainable
+    stat['trainable_ratio'] = n_trainable/n_total
+
+    n_delta = num_delta_parameters(module)
+    stat['delta_parameters'] = n_delta
+    stat['delta_ratio'] = n_delta/n_total
+
+    cudamem = 0
+    maxcudamem = 0
+    for device_id in range(torch.cuda.device_count()):
+        cudamem += torch.cuda.memory_allocated(f"cuda:{device_id}")/1024**3
+        maxcudamem += torch.cuda.max_memory_allocated(f"cuda:{device_id}")/1024**3
+    stat['cudamem'] = cudamem
+    stat['maxcudamem'] = maxcudamem
+
+    if verbose:
+        logger.info(stat)
+
+    return stat
+
+
+def num_trainable_parameters(module: Optional[nn.Module]=None):
+    r"""[NODOC] A small sugar function to get the number of trainable parameters in the module. Often used to
+    compute the trainable rate.
+
+    Args:
+        module (:obj:`nn.Module`): the module of which we want to know the number of trainable parameters.
+
+    Returns:
+        :obj:`int`
+    """
+    pnum_tot = 0
+    for param in module.parameters():
+        if param.requires_grad:
+            pnum_tot += param.numel()
+    return pnum_tot
+
+
+def num_total_parameters(module: Optional[nn.Module]=None):
+    r"""[NODOC] A small sugar function to get the total number of parameters in the module. Often used to
+    compute the trainable rate.
+
+    Args:
+        module (:obj:`nn.Module`): the module of which we want to know the total number of parameters.
+
+    Returns:
+        :obj:`int`
+    """
+    pnum_tot = 0
+    for param in module.parameters():
+        pnum_tot += param.numel()
+    return pnum_tot
+
+
+def num_delta_parameters(module: Optional[nn.Module]=None):
+    r"""[NODOC] A small sugar function to get the number of delta parameters in the module. Often used to
+    compute the delta rate.
+
+    Args:
+        module (:obj:`nn.Module`): the module of which we want to know the number of delta parameters.
+
+    Returns:
+        :obj:`int`
+    """
+    pnum_tot = 0
+    for param in module.parameters():
+        if hasattr(param, "_is_delta"):
+            pnum_tot += param.numel()
+    return pnum_tot
+
+
+def inspect_optimizer_statistics(optimizer, verbose=True):
+    stats = {}
+    for id, param_group in enumerate(optimizer.param_groups):
+        stat = {}
+        fine_grain_info = [(p.numel(), p.requires_grad) for p in param_group['params']]
+        stat['total_parameters'] = sum(n for n, r in fine_grain_info)
+        stat['trainable_parameters'] = sum(n for n, r in fine_grain_info if r)
+        stat['trainable_ratio'] = "{:.6f}%".format(stat['trainable_parameters']/stat['total_parameters']*100)
+        for key in param_group:
+            if key != 'params':
+                stat[key] = param_group[key]
+        stats[f'param_group_{id}'] = stat
+
+    if verbose:
+        logger.info(f"optimizer info: {stats}")
+
+    return stats
setup.py
@@ -31,7 +31,7 @@ def get_requirements():
 with open('README.md', 'r') as f:
     setuptools.setup(
         name = 'opendelta',
-        version = "0.3.1",
+        version = "0.3.2",
         description = "An open source framework for delta learning (parameter efficient learning).",
         long_description=open("README.md", "r", encoding="utf-8").read(),
         long_description_content_type="text/markdown",