Merge remote-tracking branch 'origin/refactor-lora' into main
This commit is contained in:
commit
7eea0cb94e
|
@ -0,0 +1,47 @@
|
||||||
|
{
|
||||||
|
"dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"delta_type": "lora",
|
||||||
|
"do_eval": true,
|
||||||
|
"do_test": true,
|
||||||
|
"do_train": true,
|
||||||
|
"eval_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"eval_dataset_name": "cola",
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"greater_is_better": true,
|
||||||
|
"metric_for_best_model": "eval_matthews_correlation",
|
||||||
|
"learning_rate": 0.0004,
|
||||||
|
"load_best_model_at_end": true,
|
||||||
|
"lora_alpha": 8,
|
||||||
|
"lora_rank": 8,
|
||||||
|
"max_source_length": 512,
|
||||||
|
"model_name": "roberta",
|
||||||
|
"model_name_or_path": "roberta-base",
|
||||||
|
"non_linearity": "gelu_new",
|
||||||
|
"num_train_epochs": 80,
|
||||||
|
"output_dir": "outputs/lora/roberta-base/v2/cola",
|
||||||
|
"per_device_eval_batch_size": 100,
|
||||||
|
"per_device_train_batch_size": 32,
|
||||||
|
"predict_with_generate": true,
|
||||||
|
"save_strategy": "epoch",
|
||||||
|
"save_total_limit": 1,
|
||||||
|
"split_validation_test": true,
|
||||||
|
"task_name": "cola",
|
||||||
|
"test_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"test_dataset_name": "cola",
|
||||||
|
"tokenizer_name": "roberta-base",
|
||||||
|
"unfrozen_modules": [
|
||||||
|
"classifier",
|
||||||
|
"deltas"
|
||||||
|
],
|
||||||
|
"warmup_ratio": 0.06,
|
||||||
|
"warmup_steps": 0,
|
||||||
|
"weight_decay": 0.1,
|
||||||
|
"overwrite_output_dir": true,
|
||||||
|
"push_to_hub": false
|
||||||
|
}
|
|
@ -0,0 +1,46 @@
|
||||||
|
{
|
||||||
|
"dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"delta_lr": 0.0005,
|
||||||
|
"delta_type": "lora",
|
||||||
|
"do_eval": true,
|
||||||
|
"do_test": true,
|
||||||
|
"do_train": true,
|
||||||
|
"eval_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"eval_dataset_name": "mnli",
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"greater_is_better": true,
|
||||||
|
"metric_for_best_model": "eval_accuracy",
|
||||||
|
"learning_rate": 0.0005,
|
||||||
|
"load_best_model_at_end": true,
|
||||||
|
"lora_alpha": 8,
|
||||||
|
"lora_rank": 8,
|
||||||
|
"max_source_length": 512,
|
||||||
|
"model_name": "roberta",
|
||||||
|
"model_name_or_path": "roberta-base",
|
||||||
|
"non_linearity": "gelu_new",
|
||||||
|
"num_train_epochs": 30,
|
||||||
|
"output_dir": "outputs/lora/roberta-base/v2/mnli",
|
||||||
|
"per_device_eval_batch_size": 100,
|
||||||
|
"per_device_train_batch_size": 16,
|
||||||
|
"save_strategy": "epoch",
|
||||||
|
"save_total_limit": 1,
|
||||||
|
"split_validation_test": true,
|
||||||
|
"task_name": "mnli",
|
||||||
|
"test_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"test_dataset_name": "mnli",
|
||||||
|
"tokenizer_name": "roberta-base",
|
||||||
|
"unfrozen_modules": [
|
||||||
|
"classifier",
|
||||||
|
"deltas"
|
||||||
|
],
|
||||||
|
"warmup_ratio": 0.06,
|
||||||
|
"weight_decay": 0.1,
|
||||||
|
"overwrite_output_dir": true,
|
||||||
|
"push_to_hub": false
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
{
|
||||||
|
"dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"delta_lr": 0.0004,
|
||||||
|
"delta_type": "lora",
|
||||||
|
"do_eval": true,
|
||||||
|
"do_test": true,
|
||||||
|
"do_train": true,
|
||||||
|
"eval_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"eval_dataset_name": "mrpc",
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"greater_is_better": true,
|
||||||
|
"metric_for_best_model": "eval_accuracy",
|
||||||
|
"learning_rate": 0.0004,
|
||||||
|
"load_best_model_at_end": true,
|
||||||
|
"lora_alpha": 8,
|
||||||
|
"lora_rank": 8,
|
||||||
|
"max_source_length": 512,
|
||||||
|
"model_name": "roberta",
|
||||||
|
"model_name_or_path": "roberta-base",
|
||||||
|
"non_linearity": "gelu_new",
|
||||||
|
"num_train_epochs": 30,
|
||||||
|
"output_dir": "outputs/lora/roberta-base/v2/mrpc",
|
||||||
|
"per_device_eval_batch_size": 100,
|
||||||
|
"per_device_train_batch_size": 16,
|
||||||
|
"predict_with_generate": true,
|
||||||
|
"save_strategy": "epoch",
|
||||||
|
"save_total_limit": 1,
|
||||||
|
"split_validation_test": true,
|
||||||
|
"task_name": "mrpc",
|
||||||
|
"test_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"test_dataset_name": "mrpc",
|
||||||
|
"tokenizer_name": "roberta-base",
|
||||||
|
"unfrozen_modules": [
|
||||||
|
"classifier",
|
||||||
|
"deltas"
|
||||||
|
],
|
||||||
|
"warmup_ratio": 0.06,
|
||||||
|
"weight_decay": 0.1,
|
||||||
|
"overwrite_output_dir": true,
|
||||||
|
"push_to_hub": false
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
{
|
||||||
|
"dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"delta_lr": 0.0004,
|
||||||
|
"delta_type": "lora",
|
||||||
|
"do_eval": true,
|
||||||
|
"do_test": true,
|
||||||
|
"do_train": true,
|
||||||
|
"eval_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"eval_dataset_name": "qnli",
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"greater_is_better": true,
|
||||||
|
"metric_for_best_model": "eval_accuracy",
|
||||||
|
"learning_rate": 0.0004,
|
||||||
|
"load_best_model_at_end": true,
|
||||||
|
"lora_alpha": 8,
|
||||||
|
"lora_rank": 8,
|
||||||
|
"max_source_length": 512,
|
||||||
|
"model_name": "roberta",
|
||||||
|
"model_name_or_path": "roberta-base",
|
||||||
|
"non_linearity": "gelu_new",
|
||||||
|
"num_train_epochs": 25,
|
||||||
|
"output_dir": "outputs/lora/roberta-base/v2/qnli",
|
||||||
|
"per_device_eval_batch_size": 100,
|
||||||
|
"per_device_train_batch_size": 32,
|
||||||
|
"predict_with_generate": true,
|
||||||
|
"save_strategy": "epoch",
|
||||||
|
"save_total_limit": 1,
|
||||||
|
"split_validation_test": true,
|
||||||
|
"task_name": "qnli",
|
||||||
|
"test_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"test_dataset_name": "qnli",
|
||||||
|
"tokenizer_name": "roberta-base",
|
||||||
|
"unfrozen_modules": [
|
||||||
|
"classifier",
|
||||||
|
"deltas"
|
||||||
|
],
|
||||||
|
"warmup_ratio": 0.06,
|
||||||
|
"weight_decay": 0.1,
|
||||||
|
"overwrite_output_dir": true,
|
||||||
|
"push_to_hub": false
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
{
|
||||||
|
"dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"delta_lr": 0.0005,
|
||||||
|
"delta_type": "lora",
|
||||||
|
"do_eval": true,
|
||||||
|
"do_test": true,
|
||||||
|
"do_train": true,
|
||||||
|
"eval_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"eval_dataset_name": "qqp",
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"greater_is_better": true,
|
||||||
|
"metric_for_best_model": "eval_accuracy",
|
||||||
|
"learning_rate": 0.0005,
|
||||||
|
"load_best_model_at_end": true,
|
||||||
|
"lora_alpha": 8,
|
||||||
|
"lora_rank": 8,
|
||||||
|
"max_source_length": 512,
|
||||||
|
"model_name": "roberta",
|
||||||
|
"model_name_or_path": "roberta-base",
|
||||||
|
"non_linearity": "gelu_new",
|
||||||
|
"num_train_epochs": 25,
|
||||||
|
"output_dir": "outputs/lora/roberta-base/v2/qqp",
|
||||||
|
"per_device_eval_batch_size": 100,
|
||||||
|
"per_device_train_batch_size": 16,
|
||||||
|
"predict_with_generate": true,
|
||||||
|
"save_strategy": "epoch",
|
||||||
|
"save_total_limit": 1,
|
||||||
|
"split_validation_test": true,
|
||||||
|
"task_name": "qqp",
|
||||||
|
"test_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"test_dataset_name": "qqp",
|
||||||
|
"tokenizer_name": "roberta-base",
|
||||||
|
"unfrozen_modules": [
|
||||||
|
"classifier",
|
||||||
|
"deltas"
|
||||||
|
],
|
||||||
|
"warmup_ratio": 0.06,
|
||||||
|
"weight_decay": 0.1,
|
||||||
|
"overwrite_output_dir": true,
|
||||||
|
"push_to_hub": false
|
||||||
|
}
|
|
@ -0,0 +1,46 @@
|
||||||
|
{
|
||||||
|
"dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"delta_type": "lora",
|
||||||
|
"do_eval": true,
|
||||||
|
"do_test": true,
|
||||||
|
"do_train": true,
|
||||||
|
"eval_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"eval_dataset_name": "rte",
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"greater_is_better": true,
|
||||||
|
"metric_for_best_model": "eval_accuracy",
|
||||||
|
"learning_rate": 0.0005,
|
||||||
|
"load_best_model_at_end": true,
|
||||||
|
"lora_alpha": 8,
|
||||||
|
"lora_rank": 8,
|
||||||
|
"max_source_length": 512,
|
||||||
|
"model_name": "roberta",
|
||||||
|
"model_name_or_path": "roberta-base",
|
||||||
|
"non_linearity": "gelu_new",
|
||||||
|
"num_train_epochs": 80,
|
||||||
|
"output_dir": "outputs/lora/roberta-base/rte",
|
||||||
|
"per_device_eval_batch_size": 100,
|
||||||
|
"per_device_train_batch_size": 32,
|
||||||
|
"predict_with_generate": true,
|
||||||
|
"save_strategy": "epoch",
|
||||||
|
"save_total_limit": 1,
|
||||||
|
"split_validation_test": true,
|
||||||
|
"task_name": "rte",
|
||||||
|
"test_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"test_dataset_name": "rte",
|
||||||
|
"tokenizer_name": "roberta-base",
|
||||||
|
"unfrozen_modules": [
|
||||||
|
"classifier",
|
||||||
|
"deltas"
|
||||||
|
],
|
||||||
|
"warmup_ratio": 0.06,
|
||||||
|
"weight_decay": 0.1,
|
||||||
|
"overwrite_output_dir": true,
|
||||||
|
"push_to_hub": false
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
{
|
||||||
|
"dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"delta_lr": 0.0005,
|
||||||
|
"delta_type": "lora",
|
||||||
|
"do_eval": true,
|
||||||
|
"do_test": true,
|
||||||
|
"do_train": true,
|
||||||
|
"eval_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"eval_dataset_name": "sst2",
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"metric_for_best_model": "eval_accuracy",
|
||||||
|
"greater_is_better": true,
|
||||||
|
"learning_rate": 0.0005,
|
||||||
|
"load_best_model_at_end": true,
|
||||||
|
"lora_alpha": 8,
|
||||||
|
"lora_rank": 8,
|
||||||
|
"max_source_length": 512,
|
||||||
|
"model_name": "roberta",
|
||||||
|
"model_name_or_path": "roberta-base",
|
||||||
|
"non_linearity": "gelu_new",
|
||||||
|
"num_train_epochs": 60,
|
||||||
|
"output_dir": "outputs/lora/roberta-base/v2/sst2",
|
||||||
|
"per_device_eval_batch_size": 100,
|
||||||
|
"per_device_train_batch_size": 16,
|
||||||
|
"predict_with_generate": true,
|
||||||
|
"save_strategy": "epoch",
|
||||||
|
"save_total_limit": 1,
|
||||||
|
"split_validation_test": true,
|
||||||
|
"task_name": "sst2",
|
||||||
|
"test_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"test_dataset_name": "sst2",
|
||||||
|
"tokenizer_name": "roberta-base",
|
||||||
|
"unfrozen_modules": [
|
||||||
|
"classifier",
|
||||||
|
"deltas"
|
||||||
|
],
|
||||||
|
"warmup_ratio": 0.06,
|
||||||
|
"weight_decay": 0.1,
|
||||||
|
"overwrite_output_dir": true,
|
||||||
|
"push_to_hub": false
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
{
|
||||||
|
"dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"delta_lr": 0.0004,
|
||||||
|
"delta_type": "lora",
|
||||||
|
"do_eval": true,
|
||||||
|
"do_test": true,
|
||||||
|
"do_train": true,
|
||||||
|
"eval_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"eval_dataset_name": "stsb",
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"greater_is_better": true,
|
||||||
|
"metric_for_best_model": "eval_pearson",
|
||||||
|
"learning_rate": 0.0004,
|
||||||
|
"load_best_model_at_end": true,
|
||||||
|
"lora_alpha": 8,
|
||||||
|
"lora_rank": 8,
|
||||||
|
"max_source_length": 512,
|
||||||
|
"model_name": "roberta",
|
||||||
|
"model_name_or_path": "roberta-base",
|
||||||
|
"non_linearity": "gelu_new",
|
||||||
|
"num_train_epochs": 40,
|
||||||
|
"output_dir": "outputs/lora/roberta-base/v2/stsb",
|
||||||
|
"per_device_eval_batch_size": 100,
|
||||||
|
"per_device_train_batch_size": 16,
|
||||||
|
"predict_with_generate": true,
|
||||||
|
"save_strategy": "epoch",
|
||||||
|
"save_total_limit": 1,
|
||||||
|
"split_validation_test": true,
|
||||||
|
"task_name": "stsb",
|
||||||
|
"test_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"test_dataset_name": "stsb",
|
||||||
|
"tokenizer_name": "roberta-base",
|
||||||
|
"unfrozen_modules": [
|
||||||
|
"classifier",
|
||||||
|
"deltas"
|
||||||
|
],
|
||||||
|
"warmup_ratio": 0.06,
|
||||||
|
"weight_decay": 0.1,
|
||||||
|
"overwrite_output_dir": true,
|
||||||
|
"push_to_hub": false
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
{
|
||||||
|
"dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"delta_lr": 0.0005,
|
||||||
|
"delta_type": "lora",
|
||||||
|
"do_eval": true,
|
||||||
|
"do_test": true,
|
||||||
|
"do_train": true,
|
||||||
|
"eval_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"eval_dataset_name": "wnli",
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"greater_is_better": true,
|
||||||
|
"metric_for_best_model": "eval_pearson",
|
||||||
|
"learning_rate": 0.0003,
|
||||||
|
"load_best_model_at_end": true,
|
||||||
|
"lora_alpha": 8,
|
||||||
|
"lora_rank": 8,
|
||||||
|
"max_source_length": 512,
|
||||||
|
"model_name": "roberta",
|
||||||
|
"model_name_or_path": "roberta-base",
|
||||||
|
"non_linearity": "gelu_new",
|
||||||
|
"num_train_epochs": 30,
|
||||||
|
"output_dir": "outputs/lora/roberta-base/v2/wnli",
|
||||||
|
"per_device_eval_batch_size": 100,
|
||||||
|
"per_device_train_batch_size": 32,
|
||||||
|
"predict_with_generate": true,
|
||||||
|
"save_strategy": "epoch",
|
||||||
|
"save_total_limit": 1,
|
||||||
|
"split_validation_test": true,
|
||||||
|
"task_name": "wnli",
|
||||||
|
"test_dataset_config_name": [
|
||||||
|
"en"
|
||||||
|
],
|
||||||
|
"test_dataset_name": "wnli",
|
||||||
|
"tokenizer_name": "roberta-base",
|
||||||
|
"unfrozen_modules": [
|
||||||
|
"classifier",
|
||||||
|
"deltas"
|
||||||
|
],
|
||||||
|
"warmup_ratio": 0.06,
|
||||||
|
"warmup_steps": 0,
|
||||||
|
"weight_decay": 0.1,
|
||||||
|
"overwrite_output_dir": true,
|
||||||
|
"push_to_hub": false
|
||||||
|
}
|
|
@ -603,6 +603,7 @@ def main():
|
||||||
item = label_list[item]
|
item = label_list[item]
|
||||||
writer.write(f"{index}\t{item}\n")
|
writer.write(f"{index}\t{item}\n")
|
||||||
|
|
||||||
|
# from IPython import embed; embed()
|
||||||
|
|
||||||
# kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
|
# kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
|
||||||
# if data_args.task_name is not None:
|
# if data_args.task_name is not None:
|
||||||
|
|
|
@ -9,6 +9,7 @@ Visualization(model).structure_graph()
|
||||||
from opendelta import LoraModel
|
from opendelta import LoraModel
|
||||||
import re
|
import re
|
||||||
delta_model = LoraModel(backbone_model=model, modified_modules=['[r](\d)+\.output.dense', 'attention.output.dense'])
|
delta_model = LoraModel(backbone_model=model, modified_modules=['[r](\d)+\.output.dense', 'attention.output.dense'])
|
||||||
|
# delta_model = LoraModel(backbone_model=model, modified_modules=['[r][0-5]\.output.dense'])
|
||||||
print("after modify")
|
print("after modify")
|
||||||
delta_model.log()
|
delta_model.log()
|
||||||
# This will visualize the backbone after modification and other information.
|
# This will visualize the backbone after modification and other information.
|
|
@ -479,7 +479,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def insert_sequential_module(self, module, delta_module=None, name='delta', strict=False, _delta_info=None):
|
def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
|
||||||
r"""insert a module (previous not exists in the code base) before/after a module. Specifically, it modifies the forward
|
r"""insert a module (previous not exists in the code base) before/after a module. Specifically, it modifies the forward
|
||||||
function of the original module to firstly pass the arguments into the new module's forward function and then pass
|
function of the original module to firstly pass the arguments into the new module's forward function and then pass
|
||||||
it into the original ones. The new module can also be inserted after the original module with similar mechanism.
|
it into the original ones. The new module can also be inserted after the original module with similar mechanism.
|
||||||
|
@ -519,14 +519,14 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
||||||
|
|
||||||
_delta_info = {"method": "insert_sequential",
|
_delta_info = {"method": "insert_sequential",
|
||||||
"delta_module": delta_module,
|
"delta_module": delta_module,
|
||||||
"delta_name": name,
|
"delta_name": delta_name,
|
||||||
"delta_belong": self,
|
"delta_belong": self,
|
||||||
"state": "on"}
|
"state": "on"}
|
||||||
self._register_delta_infos(parent_module=module,
|
self._register_delta_infos(parent_module=module,
|
||||||
_delta_info = _delta_info)
|
_delta_info = _delta_info)
|
||||||
else:
|
else:
|
||||||
delta_module = _delta_info["delta_module"]
|
delta_module = _delta_info["delta_module"]
|
||||||
name = _delta_info["delta_name"]
|
delta_name = _delta_info["delta_name"]
|
||||||
|
|
||||||
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
|
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
|
||||||
|
|
||||||
|
@ -536,20 +536,59 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
||||||
# may have bugs when module.forward is nestedly wrapped.
|
# may have bugs when module.forward is nestedly wrapped.
|
||||||
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
|
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def insert_parrellel_module(self, module, pre_caller=None, post_caller=None, delta_module=None, name='delta'):
|
def insert_parallel_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
|
||||||
"""insert a module (previous not exists in the code base) across a module. Specifically, it modifies the forward
|
"""insert a module (previous not exists in the code base) across a module. Specifically, it modifies the forward
|
||||||
function of the original module to firstly pass the arguments into the delta model's forward function and set
|
function of the original module to firstly pass the arguments into the delta model's forward function and set
|
||||||
aside the calculation result. Then combine it with the calculation result output from the backbone module.
|
aside the calculation result. Then combine it with the calculation result output from the backbone module.
|
||||||
|
|
||||||
When implementing the new module , researchers should be aware of the arguments and keywards of the original module's forward function.
|
When implementing the new module , researchers should be aware of the arguments and keywards of the original module's forward function.
|
||||||
|
|
||||||
# TODO: currently not in use.
|
Args:
|
||||||
|
module: (:obj:`nn.Module`): The (sub)module to inserted a delta module.
|
||||||
|
delta_module: (:obj:`DeltaBase`): The delta module to be inserted.
|
||||||
|
name: (:obj:`str`, *optional*): The name of the delta in the backbone module.
|
||||||
|
strict: (:obj:`bool`, *optional*): Whether to prohibit modify a modified module.
|
||||||
|
_delta_info (:obj:`Dict`, *optional*): Used in attach(), reattach a delta module to backbone. The info of
|
||||||
|
original delta is passed through ``_delta_info``.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
|
||||||
|
def _caller(_org_func, org_module, delta_name, *args, **kwargs):
|
||||||
|
args = args[1:] # the first argument here is ``self``
|
||||||
|
delta_module = getattr(org_module, delta_name)
|
||||||
|
ret_1 = _org_func(*args, **kwargs)
|
||||||
|
ret_2 = delta_module.forward(*args, **kwargs)
|
||||||
|
return ret_1 + ret_2
|
||||||
|
|
||||||
|
if strict:
|
||||||
|
if hasattr(module.forward, "__wrapped__"):
|
||||||
|
raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
|
||||||
|
|
||||||
|
# record info for plug and unplug and nested wrap
|
||||||
|
if _delta_info is None:
|
||||||
|
if delta_module is None:
|
||||||
|
raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
|
||||||
|
|
||||||
|
_delta_info = {"method": "insert_parallel",
|
||||||
|
"delta_module": delta_module,
|
||||||
|
"delta_name": delta_name,
|
||||||
|
"delta_belong": self,
|
||||||
|
"state": "on"}
|
||||||
|
self._register_delta_infos(parent_module=module,
|
||||||
|
_delta_info = _delta_info)
|
||||||
|
else:
|
||||||
|
delta_module = _delta_info["delta_module"]
|
||||||
|
delta_name = _delta_info["delta_name"]
|
||||||
|
|
||||||
|
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
|
||||||
|
|
||||||
|
new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
|
||||||
|
module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
|
||||||
|
# for DataParallel's copy behavior. Experimental:
|
||||||
|
# may have bugs when module.forward is nestedly wrapped.
|
||||||
|
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
|
||||||
|
|
||||||
|
|
||||||
def set_active_state_dict(self, module: nn.Module):
|
def set_active_state_dict(self, module: nn.Module):
|
||||||
r"""modify the state_dict function of the model (by default, the backbone model) to return only the tunable part.
|
r"""modify the state_dict function of the model (by default, the backbone model) to return only the tunable part.
|
||||||
|
|
|
@ -192,7 +192,7 @@ class AdapterModel(DeltaBase):
|
||||||
def update_module(self, module: nn.Module, key: str):
|
def update_module(self, module: nn.Module, key: str):
|
||||||
_, _, ref = self.find_module(module, key)
|
_, _, ref = self.find_module(module, key)
|
||||||
adapterlayer = self.new_module_like(ref)
|
adapterlayer = self.new_module_like(ref)
|
||||||
self.insert_sequential_module(ref, delta_module=adapterlayer, name="adapter")
|
self.insert_sequential_module(ref, delta_module=adapterlayer, delta_name="adapter")
|
||||||
|
|
||||||
def new_module_like(self, module):
|
def new_module_like(self, module):
|
||||||
module_device = get_device(module)
|
module_device = get_device(module)
|
||||||
|
|
|
@ -179,7 +179,7 @@ class BitFitModel(DeltaBase):
|
||||||
|
|
||||||
def add_bias_to_others(self, c):
|
def add_bias_to_others(self, c):
|
||||||
new_bias = BiasLayer()
|
new_bias = BiasLayer()
|
||||||
self.insert_sequential_module(c, delta_module=new_bias, name="bitfit") # name shouldn't be `bias` here, since
|
self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since
|
||||||
# the name `bias` is reserved for some module such as roberta's LayerNorm.
|
# the name `bias` is reserved for some module such as roberta's LayerNorm.
|
||||||
self.delta_modules.append(new_bias)
|
self.delta_modules.append(new_bias)
|
||||||
|
|
||||||
|
|
|
@ -277,7 +277,7 @@ class CompacterModel(DeltaBase):
|
||||||
adapterlayer = self.new_module_like(ref)
|
adapterlayer = self.new_module_like(ref)
|
||||||
self.insert_sequential_module(ref,
|
self.insert_sequential_module(ref,
|
||||||
delta_module=adapterlayer,
|
delta_module=adapterlayer,
|
||||||
name="compactor")
|
delta_name="compactor")
|
||||||
|
|
||||||
def new_module_like(self, module):
|
def new_module_like(self, module):
|
||||||
module_device = get_device(module)
|
module_device = get_device(module)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from turtle import forward
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
|
from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
|
||||||
|
@ -7,6 +8,42 @@ from transformers.models.t5 import T5ForConditionalGeneration
|
||||||
import loralib as lora
|
import loralib as lora
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from opendelta import BaseDeltaConfig
|
from opendelta import BaseDeltaConfig
|
||||||
|
import math
|
||||||
|
|
||||||
|
class LowRankLinear(nn.Module):
|
||||||
|
# ------------------------------------------------------------------------------------------
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||||
|
# ------------------------------------------------------------------------------------------
|
||||||
|
# copy from loralib and do some refactor
|
||||||
|
def __init__(self,
|
||||||
|
in_features,
|
||||||
|
out_features,
|
||||||
|
weight,
|
||||||
|
r=8,
|
||||||
|
lora_alpha=16,
|
||||||
|
lora_dropout=0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.r = r
|
||||||
|
self.lora_alpha = lora_alpha
|
||||||
|
self.lora_dropout = lora_dropout
|
||||||
|
self.lin = nn.Linear(in_features, out_features) #
|
||||||
|
if lora_dropout > 0.:
|
||||||
|
self.lora_dropout = nn.Dropout(p=lora_dropout)
|
||||||
|
else:
|
||||||
|
self.lora_dropout = lambda x: x
|
||||||
|
if r > 0:
|
||||||
|
self.lora_A = nn.Parameter(weight.new_zeros((r, in_features)))
|
||||||
|
self.lora_B = nn.Parameter(weight.new_zeros((out_features, r)))
|
||||||
|
self.scaling = self.lora_alpha / self.r
|
||||||
|
self.lin.reset_parameters() #
|
||||||
|
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||||
|
nn.init.zeros_(self.lora_B)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
|
||||||
|
|
||||||
|
|
||||||
class LoraConfig(BaseDeltaConfig):
|
class LoraConfig(BaseDeltaConfig):
|
||||||
r"""
|
r"""
|
||||||
|
@ -27,7 +64,6 @@ class LoraConfig(BaseDeltaConfig):
|
||||||
setattr(self, arg_name, locals()[arg_name])
|
setattr(self, arg_name, locals()[arg_name])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LoraModel(DeltaBase):
|
class LoraModel(DeltaBase):
|
||||||
r""" The implementation of `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`_ .
|
r""" The implementation of `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`_ .
|
||||||
Thanks for their `loralib <https://github.com/microsoft/LoRA/tree/main/loralib>`_, we use loralib.linear
|
Thanks for their `loralib <https://github.com/microsoft/LoRA/tree/main/loralib>`_, we use loralib.linear
|
||||||
|
@ -89,11 +125,10 @@ class LoraModel(DeltaBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def update_module(self, module: nn.Module, key: str):
|
def update_module(self, module: nn.Module, key: str):
|
||||||
parent_ref, child_name, child_ref = self.find_module(module, key)
|
parent_ref, child_name, child_ref = self.find_module(module, key)
|
||||||
new_module = self.new_module_like(child_module=child_ref)
|
parallel_module = self.new_module_like(child_module=child_ref)
|
||||||
self.replace_module(parent_ref, child_name, child_ref, new_module, delta_name="lora")
|
self.insert_parallel_module(child_ref, delta_module=parallel_module, delta_name="lora")
|
||||||
|
|
||||||
def _pseudo_data_to_instantiate(self, module):
|
def _pseudo_data_to_instantiate(self, module):
|
||||||
# no need to pass pseudo input, so overwrite it
|
# no need to pass pseudo input, so overwrite it
|
||||||
|
@ -102,26 +137,13 @@ class LoraModel(DeltaBase):
|
||||||
def new_module_like(self, child_module):
|
def new_module_like(self, child_module):
|
||||||
if isinstance(child_module, nn.Linear):
|
if isinstance(child_module, nn.Linear):
|
||||||
in_features, out_features = child_module.in_features, child_module.out_features
|
in_features, out_features = child_module.in_features, child_module.out_features
|
||||||
new_module = lora.Linear(in_features=in_features,
|
new_module = LowRankLinear(in_features = in_features,
|
||||||
out_features=out_features,
|
out_features = out_features,
|
||||||
|
weight = child_module.weight,
|
||||||
r=self.lora_r,
|
r=self.lora_r,
|
||||||
lora_alpha=self.lora_alpha,
|
lora_alpha=self.lora_alpha,
|
||||||
lora_dropout=self.lora_dropout)
|
lora_dropout=self.lora_dropout)
|
||||||
new_module.weight = child_module.weight
|
self.delta_modules.append(new_module)
|
||||||
new_module.bias = child_module.bias # if bias is None, also copy
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
return new_module
|
return new_module
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def mark_as_delta(self, module: nn.Module = None):
|
|
||||||
if module is None:
|
|
||||||
module=self
|
|
||||||
for n, p in module.named_parameters():
|
|
||||||
param_name = n.split(".")[-1]
|
|
||||||
if "lora_A" in param_name or "lora_B" in param_name: # only lora_A, lora_B is the delta parameter.
|
|
||||||
setattr(p, "_is_delta", True)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,126 @@
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
|
||||||
|
from opendelta.utils.name_based_addressing import *
|
||||||
|
from opendelta.basemodel import DeltaBase
|
||||||
|
from transformers.models.t5 import T5ForConditionalGeneration
|
||||||
|
import loralib as lora
|
||||||
|
import torch.nn as nn
|
||||||
|
from opendelta import BaseDeltaConfig
|
||||||
|
|
||||||
|
class LoraConfig(BaseDeltaConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a :py:class:`~LoraModel`
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
lora_r=8,
|
||||||
|
lora_alpha=16,
|
||||||
|
lora_dropout=0.0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
arg_names = get_arg_names_inside_func(self.__init__)
|
||||||
|
for arg_name in arg_names:
|
||||||
|
if not hasattr(self, arg_name): # the arg has not been registered in parent config
|
||||||
|
setattr(self, arg_name, locals()[arg_name])
|
||||||
|
|
||||||
|
|
||||||
|
class LoraModel(DeltaBase):
|
||||||
|
r""" The implementation of `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`_ .
|
||||||
|
Thanks for their `loralib <https://github.com/microsoft/LoRA/tree/main/loralib>`_, we use loralib.linear
|
||||||
|
to replace the linear layer of the backbone model.
|
||||||
|
|
||||||
|
class attributes:
|
||||||
|
- default_modified_modules = ['attn.q', 'attn.v'] According to the paper, they modify q and v matrix in the
|
||||||
|
attention layer. However, other linears can also be modified, and may lead to better performance.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
modified_modules should point to linear layer. We currently don't support broadcast to all linears in
|
||||||
|
a module's child modules.
|
||||||
|
|
||||||
|
- delta_type = "lora"
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
|
||||||
|
lora_r (:obj:`int`, *optional*): the rank of the lora parameters. The smaller lora_r is , the fewer parameters lora has.
|
||||||
|
lora_alpha (:obj:`bool`, *optional*): A hyper-parameter to control the init scale of loralib.linear .
|
||||||
|
lora_dropout (:obj:`bool`, *optional*): The dropout rate in lora.linear.
|
||||||
|
modified_modules (:obj:`List[str]`): For prefix tuning, the it must refer to an attention layer (Currently, only
|
||||||
|
the implemented ones)
|
||||||
|
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
|
||||||
|
together with the prefix parameters.
|
||||||
|
common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = LoraConfig
|
||||||
|
delta_type = "lora"
|
||||||
|
default_modified_modules = ['attn.q', 'attn.v']
|
||||||
|
def __init__(self,
|
||||||
|
backbone_model: nn.Module,
|
||||||
|
lora_r=8,
|
||||||
|
lora_alpha=16,
|
||||||
|
lora_dropout=0.0,
|
||||||
|
modified_modules: Optional[bool] = None,
|
||||||
|
unfrozen_modules: Optional[bool] = None,
|
||||||
|
common_structure: Optional[bool] = None,
|
||||||
|
interactive_modify: Optional[Union[bool, int]] = False,
|
||||||
|
):
|
||||||
|
DeltaBase.__init__(self,
|
||||||
|
backbone_model,
|
||||||
|
modified_modules=modified_modules,
|
||||||
|
unfrozen_modules=unfrozen_modules,
|
||||||
|
common_structure=common_structure,
|
||||||
|
interactive_modify=interactive_modify,
|
||||||
|
)
|
||||||
|
arg_names = get_arg_names_inside_func(self.__init__)
|
||||||
|
for arg_name in arg_names:
|
||||||
|
if not hasattr(self, arg_name): # not registered in parent class
|
||||||
|
setattr(self, arg_name, locals()[arg_name])
|
||||||
|
|
||||||
|
self.delta_modules = nn.ModuleList()
|
||||||
|
|
||||||
|
self.add_all_delta_to_backbone(self.backbone_model,
|
||||||
|
self.modified_modules,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def update_module(self, module: nn.Module, key: str):
|
||||||
|
parent_ref, child_name, child_ref = self.find_module(module, key)
|
||||||
|
new_module = self.new_module_like(child_module=child_ref)
|
||||||
|
self.replace_module(parent_ref, child_name, child_ref, new_module, delta_name="lora")
|
||||||
|
|
||||||
|
def _pseudo_data_to_instantiate(self, module):
|
||||||
|
# no need to pass pseudo input, so overwrite it
|
||||||
|
pass
|
||||||
|
|
||||||
|
def new_module_like(self, child_module):
|
||||||
|
if isinstance(child_module, nn.Linear):
|
||||||
|
in_features, out_features = child_module.in_features, child_module.out_features
|
||||||
|
new_module = lora.Linear(in_features=in_features,
|
||||||
|
out_features=out_features,
|
||||||
|
r=self.lora_r,
|
||||||
|
lora_alpha=self.lora_alpha,
|
||||||
|
lora_dropout=self.lora_dropout)
|
||||||
|
new_module.weight = child_module.weight
|
||||||
|
new_module.bias = child_module.bias # if bias is None, also copy
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return new_module
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def mark_as_delta(self, module: nn.Module = None):
|
||||||
|
if module is None:
|
||||||
|
module=self
|
||||||
|
for n, p in module.named_parameters():
|
||||||
|
param_name = n.split(".")[-1]
|
||||||
|
if "lora_A" in param_name or "lora_B" in param_name: # only lora_A, lora_B is the delta parameter.
|
||||||
|
setattr(p, "_is_delta", True)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -194,7 +194,7 @@ class LowRankAdapterModel(DeltaBase):
|
||||||
def update_module(self, module: nn.Module, key: str):
|
def update_module(self, module: nn.Module, key: str):
|
||||||
_, _, ref = self.find_module(module, key)
|
_, _, ref = self.find_module(module, key)
|
||||||
adapterlayer = self.new_module_like(ref)
|
adapterlayer = self.new_module_like(ref)
|
||||||
self.insert_sequential_module(ref, delta_module=adapterlayer, name="low_rank_adapter")
|
self.insert_sequential_module(ref, delta_module=adapterlayer, delta_name="low_rank_adapter")
|
||||||
|
|
||||||
def new_module_like(self, module):
|
def new_module_like(self, module):
|
||||||
module_device = get_device(module)
|
module_device = get_device(module)
|
||||||
|
|
|
@ -512,7 +512,7 @@ class PrefixModel(DeltaBase):
|
||||||
module_list=self.delta_modules)
|
module_list=self.delta_modules)
|
||||||
self.delta_modules = None
|
self.delta_modules = None
|
||||||
self.reparams = reparams
|
self.reparams = reparams
|
||||||
self.insert_sequential_module(first_modified_module, delta_module=reparams, name="reparams", strict=False)
|
self.insert_sequential_module(first_modified_module, delta_module=reparams, delta_name="reparams", strict=False)
|
||||||
self.mark_as_delta()
|
self.mark_as_delta()
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
@ -522,7 +522,7 @@ class PrefixModel(DeltaBase):
|
||||||
_, _, ref = self.find_module(module, key)
|
_, _, ref = self.find_module(module, key)
|
||||||
|
|
||||||
prefixlayer, ref = self.new_module_like(ref)
|
prefixlayer, ref = self.new_module_like(ref)
|
||||||
self.insert_sequential_module(ref, delta_module=prefixlayer, name="prefix")
|
self.insert_sequential_module(ref, delta_module=prefixlayer, delta_name="prefix")
|
||||||
self.delta_modules.append(prefixlayer)
|
self.delta_modules.append(prefixlayer)
|
||||||
|
|
||||||
def new_module_like(self, module):
|
def new_module_like(self, module):
|
||||||
|
|
|
@ -193,11 +193,11 @@ class SoftPromptModel(DeltaBase):
|
||||||
soft_prompt_layer = self.new_module_like(self.raw_embedding)
|
soft_prompt_layer = self.new_module_like(self.raw_embedding)
|
||||||
self.insert_sequential_module(self.backbone_model.get_encoder() if self.backbone_model.config.is_encoder_decoder else self.backbone_model,
|
self.insert_sequential_module(self.backbone_model.get_encoder() if self.backbone_model.config.is_encoder_decoder else self.backbone_model,
|
||||||
delta_module=soft_prompt_layer,
|
delta_module=soft_prompt_layer,
|
||||||
name="soft_prompt_layer" )
|
delta_name="soft_prompt_layer" )
|
||||||
|
|
||||||
def new_module_like(self, module):
|
def new_module_like(self, module):
|
||||||
module_device = get_device(module)
|
module_device = get_device(module)
|
||||||
soft_prompt_layer = SoftPromptLayer(
|
soft_prompt_layer = SoftPromptLayer(
|
||||||
soft_token_num = self.soft_token_num,
|
soft_token_num = self.soft_token_num,
|
||||||
raw_embedding = self.raw_embedding,
|
raw_embedding = self.raw_embedding,
|
||||||
token_init = self.token_init,
|
token_init = self.token_init,
|
||||||
|
|
|
@ -8,7 +8,7 @@ def new_replicate_for_data_parallel(self):
|
||||||
r""" self is the parent module.
|
r""" self is the parent module.
|
||||||
"""
|
"""
|
||||||
# rewrite the replicate in DataParallel.
|
# rewrite the replicate in DataParallel.
|
||||||
def _caller(_org_func, org_module, delta_name, *args, **kwargs):
|
def _sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
|
||||||
args = args[1:] # the first argument here is ``self``
|
args = args[1:] # the first argument here is ``self``
|
||||||
delta_module = getattr(org_module, delta_name)
|
delta_module = getattr(org_module, delta_name)
|
||||||
if hasattr(delta_module, "pre_forward"):
|
if hasattr(delta_module, "pre_forward"):
|
||||||
|
@ -17,6 +17,13 @@ def new_replicate_for_data_parallel(self):
|
||||||
if hasattr(delta_module, "post_forward"):
|
if hasattr(delta_module, "post_forward"):
|
||||||
ret = delta_module.post_forward(ret)
|
ret = delta_module.post_forward(ret)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
def _parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
|
||||||
|
args = args[1:] # the first argument here is ``self``
|
||||||
|
delta_module = getattr(org_module, delta_name)
|
||||||
|
ret_1 = _org_func(*args, **kwargs)
|
||||||
|
ret_2 = delta_module.forward(*args, **kwargs)
|
||||||
|
return ret_1 + ret_2
|
||||||
replica = self.__new__(type(self))
|
replica = self.__new__(type(self))
|
||||||
org_forward = replica.forward
|
org_forward = replica.forward
|
||||||
replica.__dict__ = self.__dict__.copy()
|
replica.__dict__ = self.__dict__.copy()
|
||||||
|
@ -25,8 +32,13 @@ def new_replicate_for_data_parallel(self):
|
||||||
|
|
||||||
|
|
||||||
for _delta_info in self._delta_infos:
|
for _delta_info in self._delta_infos:
|
||||||
if _delta_info['method'] == "insert_sequential" and _delta_info['state'] == "on":
|
if _delta_info['state'] == 'on':
|
||||||
new_forward = decorate(replica.forward, _caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
|
if _delta_info['method'] == "insert_sequential":
|
||||||
|
new_forward = decorate(replica.forward, _sequential_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
|
||||||
|
elif _delta_info['method'] == "insert_parallel":
|
||||||
|
new_forward = decorate(replica.forward, _parallel_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"data_parallel for _delta_info['method']=='{_delta_info['method']}' is not supported")
|
||||||
replica.__dict__['forward'] = new_forward.__get__(replica, type(replica))
|
replica.__dict__['forward'] = new_forward.__get__(replica, type(replica))
|
||||||
|
|
||||||
# replicas do not have parameters themselves, the replicas reference the original
|
# replicas do not have parameters themselves, the replicas reference the original
|
||||||
|
|
Loading…
Reference in New Issue