From c2e086c6ed2572c16178afee91e9d993b196e827 Mon Sep 17 00:00:00 2001 From: Achazwl <323163497@qq.com> Date: Sun, 20 Feb 2022 17:23:31 +0800 Subject: [PATCH 1/2] lora --- .../configs/lora_roberta-base/lora_cola.json | 47 +++++++ .../configs/lora_roberta-base/lora_mnli.json | 46 +++++++ .../configs/lora_roberta-base/lora_mrpc.json | 47 +++++++ .../configs/lora_roberta-base/lora_qnli.json | 47 +++++++ .../configs/lora_roberta-base/lora_qqp.json | 47 +++++++ .../configs/lora_roberta-base/lora_rte.json | 46 +++++++ .../configs/lora_roberta-base/lora_sst2.json | 47 +++++++ .../configs/lora_roberta-base/lora_stsb.json | 47 +++++++ .../configs/lora_roberta-base/lora_wnli.json | 48 +++++++ .../examples_text-classification/run_glue.py | 1 + examples/tutorial/0_regex.py | 1 + opendelta/basemodel.py | 57 ++++++-- opendelta/delta_models/adapter.py | 2 +- opendelta/delta_models/bitfit.py | 2 +- opendelta/delta_models/compacter.py | 2 +- opendelta/delta_models/lora.py | 66 ++++++--- opendelta/delta_models/lora_old.py | 126 ++++++++++++++++++ opendelta/delta_models/low_rank_adapter.py | 2 +- opendelta/delta_models/prefix.py | 4 +- opendelta/delta_models/soft_prompt.py | 4 +- opendelta/utils/data_parallel.py | 18 ++- 21 files changed, 665 insertions(+), 42 deletions(-) create mode 100644 examples/examples_text-classification/configs/lora_roberta-base/lora_cola.json create mode 100644 examples/examples_text-classification/configs/lora_roberta-base/lora_mnli.json create mode 100644 examples/examples_text-classification/configs/lora_roberta-base/lora_mrpc.json create mode 100644 examples/examples_text-classification/configs/lora_roberta-base/lora_qnli.json create mode 100644 examples/examples_text-classification/configs/lora_roberta-base/lora_qqp.json create mode 100644 examples/examples_text-classification/configs/lora_roberta-base/lora_rte.json create mode 100644 examples/examples_text-classification/configs/lora_roberta-base/lora_sst2.json create mode 100644 examples/examples_text-classification/configs/lora_roberta-base/lora_stsb.json create mode 100644 examples/examples_text-classification/configs/lora_roberta-base/lora_wnli.json create mode 100644 opendelta/delta_models/lora_old.py diff --git a/examples/examples_text-classification/configs/lora_roberta-base/lora_cola.json b/examples/examples_text-classification/configs/lora_roberta-base/lora_cola.json new file mode 100644 index 0000000..aa05f0a --- /dev/null +++ b/examples/examples_text-classification/configs/lora_roberta-base/lora_cola.json @@ -0,0 +1,47 @@ +{ + "dataset_config_name": [ + "en" + ], + "delta_type": "lora", + "do_eval": true, + "do_test": true, + "do_train": true, + "eval_dataset_config_name": [ + "en" + ], + "eval_dataset_name": "cola", + "evaluation_strategy": "epoch", + "greater_is_better": true, + "metric_for_best_model": "eval_matthews_correlation", + "learning_rate": 0.0004, + "load_best_model_at_end": true, + "lora_alpha": 8, + "lora_rank": 8, + "max_source_length": 512, + "model_name": "roberta", + "model_name_or_path": "roberta-base", + "non_linearity": "gelu_new", + "num_train_epochs": 80, + "output_dir": "outputs/lora/roberta-base/v2/cola", + "per_device_eval_batch_size": 100, + "per_device_train_batch_size": 32, + "predict_with_generate": true, + "save_strategy": "epoch", + "save_total_limit": 1, + "split_validation_test": true, + "task_name": "cola", + "test_dataset_config_name": [ + "en" + ], + "test_dataset_name": "cola", + "tokenizer_name": "roberta-base", + "unfrozen_modules": [ + "classifier", + "deltas" + ], + 
"warmup_ratio": 0.06, + "warmup_steps": 0, + "weight_decay": 0.1, + "overwrite_output_dir": true, + "push_to_hub": false +} \ No newline at end of file diff --git a/examples/examples_text-classification/configs/lora_roberta-base/lora_mnli.json b/examples/examples_text-classification/configs/lora_roberta-base/lora_mnli.json new file mode 100644 index 0000000..06d4428 --- /dev/null +++ b/examples/examples_text-classification/configs/lora_roberta-base/lora_mnli.json @@ -0,0 +1,46 @@ +{ + "dataset_config_name": [ + "en" + ], + "delta_lr": 0.0005, + "delta_type": "lora", + "do_eval": true, + "do_test": true, + "do_train": true, + "eval_dataset_config_name": [ + "en" + ], + "eval_dataset_name": "mnli", + "evaluation_strategy": "epoch", + "greater_is_better": true, + "metric_for_best_model": "eval_accuracy", + "learning_rate": 0.0005, + "load_best_model_at_end": true, + "lora_alpha": 8, + "lora_rank": 8, + "max_source_length": 512, + "model_name": "roberta", + "model_name_or_path": "roberta-base", + "non_linearity": "gelu_new", + "num_train_epochs": 30, + "output_dir": "outputs/lora/roberta-base/v2/mnli", + "per_device_eval_batch_size": 100, + "per_device_train_batch_size": 16, + "save_strategy": "epoch", + "save_total_limit": 1, + "split_validation_test": true, + "task_name": "mnli", + "test_dataset_config_name": [ + "en" + ], + "test_dataset_name": "mnli", + "tokenizer_name": "roberta-base", + "unfrozen_modules": [ + "classifier", + "deltas" + ], + "warmup_ratio": 0.06, + "weight_decay": 0.1, + "overwrite_output_dir": true, + "push_to_hub": false +} \ No newline at end of file diff --git a/examples/examples_text-classification/configs/lora_roberta-base/lora_mrpc.json b/examples/examples_text-classification/configs/lora_roberta-base/lora_mrpc.json new file mode 100644 index 0000000..46afef5 --- /dev/null +++ b/examples/examples_text-classification/configs/lora_roberta-base/lora_mrpc.json @@ -0,0 +1,47 @@ +{ + "dataset_config_name": [ + "en" + ], + "delta_lr": 0.0004, + "delta_type": "lora", + "do_eval": true, + "do_test": true, + "do_train": true, + "eval_dataset_config_name": [ + "en" + ], + "eval_dataset_name": "mrpc", + "evaluation_strategy": "epoch", + "greater_is_better": true, + "metric_for_best_model": "eval_accuracy", + "learning_rate": 0.0004, + "load_best_model_at_end": true, + "lora_alpha": 8, + "lora_rank": 8, + "max_source_length": 512, + "model_name": "roberta", + "model_name_or_path": "roberta-base", + "non_linearity": "gelu_new", + "num_train_epochs": 30, + "output_dir": "outputs/lora/roberta-base/v2/mrpc", + "per_device_eval_batch_size": 100, + "per_device_train_batch_size": 16, + "predict_with_generate": true, + "save_strategy": "epoch", + "save_total_limit": 1, + "split_validation_test": true, + "task_name": "mrpc", + "test_dataset_config_name": [ + "en" + ], + "test_dataset_name": "mrpc", + "tokenizer_name": "roberta-base", + "unfrozen_modules": [ + "classifier", + "deltas" + ], + "warmup_ratio": 0.06, + "weight_decay": 0.1, + "overwrite_output_dir": true, + "push_to_hub": false +} \ No newline at end of file diff --git a/examples/examples_text-classification/configs/lora_roberta-base/lora_qnli.json b/examples/examples_text-classification/configs/lora_roberta-base/lora_qnli.json new file mode 100644 index 0000000..05d28ce --- /dev/null +++ b/examples/examples_text-classification/configs/lora_roberta-base/lora_qnli.json @@ -0,0 +1,47 @@ +{ + "dataset_config_name": [ + "en" + ], + "delta_lr": 0.0004, + "delta_type": "lora", + "do_eval": true, + "do_test": true, + "do_train": 
true, + "eval_dataset_config_name": [ + "en" + ], + "eval_dataset_name": "qnli", + "evaluation_strategy": "epoch", + "greater_is_better": true, + "metric_for_best_model": "eval_accuracy", + "learning_rate": 0.0004, + "load_best_model_at_end": true, + "lora_alpha": 8, + "lora_rank": 8, + "max_source_length": 512, + "model_name": "roberta", + "model_name_or_path": "roberta-base", + "non_linearity": "gelu_new", + "num_train_epochs": 25, + "output_dir": "outputs/lora/roberta-base/v2/qnli", + "per_device_eval_batch_size": 100, + "per_device_train_batch_size": 32, + "predict_with_generate": true, + "save_strategy": "epoch", + "save_total_limit": 1, + "split_validation_test": true, + "task_name": "qnli", + "test_dataset_config_name": [ + "en" + ], + "test_dataset_name": "qnli", + "tokenizer_name": "roberta-base", + "unfrozen_modules": [ + "classifier", + "deltas" + ], + "warmup_ratio": 0.06, + "weight_decay": 0.1, + "overwrite_output_dir": true, + "push_to_hub": false +} \ No newline at end of file diff --git a/examples/examples_text-classification/configs/lora_roberta-base/lora_qqp.json b/examples/examples_text-classification/configs/lora_roberta-base/lora_qqp.json new file mode 100644 index 0000000..0ca93ec --- /dev/null +++ b/examples/examples_text-classification/configs/lora_roberta-base/lora_qqp.json @@ -0,0 +1,47 @@ +{ + "dataset_config_name": [ + "en" + ], + "delta_lr": 0.0005, + "delta_type": "lora", + "do_eval": true, + "do_test": true, + "do_train": true, + "eval_dataset_config_name": [ + "en" + ], + "eval_dataset_name": "qqp", + "evaluation_strategy": "epoch", + "greater_is_better": true, + "metric_for_best_model": "eval_accuracy", + "learning_rate": 0.0005, + "load_best_model_at_end": true, + "lora_alpha": 8, + "lora_rank": 8, + "max_source_length": 512, + "model_name": "roberta", + "model_name_or_path": "roberta-base", + "non_linearity": "gelu_new", + "num_train_epochs": 25, + "output_dir": "outputs/lora/roberta-base/v2/qqp", + "per_device_eval_batch_size": 100, + "per_device_train_batch_size": 16, + "predict_with_generate": true, + "save_strategy": "epoch", + "save_total_limit": 1, + "split_validation_test": true, + "task_name": "qqp", + "test_dataset_config_name": [ + "en" + ], + "test_dataset_name": "qqp", + "tokenizer_name": "roberta-base", + "unfrozen_modules": [ + "classifier", + "deltas" + ], + "warmup_ratio": 0.06, + "weight_decay": 0.1, + "overwrite_output_dir": true, + "push_to_hub": false +} \ No newline at end of file diff --git a/examples/examples_text-classification/configs/lora_roberta-base/lora_rte.json b/examples/examples_text-classification/configs/lora_roberta-base/lora_rte.json new file mode 100644 index 0000000..20f98d0 --- /dev/null +++ b/examples/examples_text-classification/configs/lora_roberta-base/lora_rte.json @@ -0,0 +1,46 @@ +{ + "dataset_config_name": [ + "en" + ], + "delta_type": "lora", + "do_eval": true, + "do_test": true, + "do_train": true, + "eval_dataset_config_name": [ + "en" + ], + "eval_dataset_name": "rte", + "evaluation_strategy": "epoch", + "greater_is_better": true, + "metric_for_best_model": "eval_accuracy", + "learning_rate": 0.0005, + "load_best_model_at_end": true, + "lora_alpha": 8, + "lora_rank": 8, + "max_source_length": 512, + "model_name": "roberta", + "model_name_or_path": "roberta-base", + "non_linearity": "gelu_new", + "num_train_epochs": 80, + "output_dir": "outputs/lora/roberta-base/rte", + "per_device_eval_batch_size": 100, + "per_device_train_batch_size": 32, + "predict_with_generate": true, + "save_strategy": "epoch", + 
"save_total_limit": 1, + "split_validation_test": true, + "task_name": "rte", + "test_dataset_config_name": [ + "en" + ], + "test_dataset_name": "rte", + "tokenizer_name": "roberta-base", + "unfrozen_modules": [ + "classifier", + "deltas" + ], + "warmup_ratio": 0.06, + "weight_decay": 0.1, + "overwrite_output_dir": true, + "push_to_hub": false +} \ No newline at end of file diff --git a/examples/examples_text-classification/configs/lora_roberta-base/lora_sst2.json b/examples/examples_text-classification/configs/lora_roberta-base/lora_sst2.json new file mode 100644 index 0000000..767d501 --- /dev/null +++ b/examples/examples_text-classification/configs/lora_roberta-base/lora_sst2.json @@ -0,0 +1,47 @@ +{ + "dataset_config_name": [ + "en" + ], + "delta_lr": 0.0005, + "delta_type": "lora", + "do_eval": true, + "do_test": true, + "do_train": true, + "eval_dataset_config_name": [ + "en" + ], + "eval_dataset_name": "sst2", + "evaluation_strategy": "epoch", + "metric_for_best_model": "eval_accuracy", + "greater_is_better": true, + "learning_rate": 0.0005, + "load_best_model_at_end": true, + "lora_alpha": 8, + "lora_rank": 8, + "max_source_length": 512, + "model_name": "roberta", + "model_name_or_path": "roberta-base", + "non_linearity": "gelu_new", + "num_train_epochs": 60, + "output_dir": "outputs/lora/roberta-base/v2/sst2", + "per_device_eval_batch_size": 100, + "per_device_train_batch_size": 16, + "predict_with_generate": true, + "save_strategy": "epoch", + "save_total_limit": 1, + "split_validation_test": true, + "task_name": "sst2", + "test_dataset_config_name": [ + "en" + ], + "test_dataset_name": "sst2", + "tokenizer_name": "roberta-base", + "unfrozen_modules": [ + "classifier", + "deltas" + ], + "warmup_ratio": 0.06, + "weight_decay": 0.1, + "overwrite_output_dir": true, + "push_to_hub": false +} \ No newline at end of file diff --git a/examples/examples_text-classification/configs/lora_roberta-base/lora_stsb.json b/examples/examples_text-classification/configs/lora_roberta-base/lora_stsb.json new file mode 100644 index 0000000..827b139 --- /dev/null +++ b/examples/examples_text-classification/configs/lora_roberta-base/lora_stsb.json @@ -0,0 +1,47 @@ +{ + "dataset_config_name": [ + "en" + ], + "delta_lr": 0.0004, + "delta_type": "lora", + "do_eval": true, + "do_test": true, + "do_train": true, + "eval_dataset_config_name": [ + "en" + ], + "eval_dataset_name": "stsb", + "evaluation_strategy": "epoch", + "greater_is_better": true, + "metric_for_best_model": "eval_pearson", + "learning_rate": 0.0004, + "load_best_model_at_end": true, + "lora_alpha": 8, + "lora_rank": 8, + "max_source_length": 512, + "model_name": "roberta", + "model_name_or_path": "roberta-base", + "non_linearity": "gelu_new", + "num_train_epochs": 40, + "output_dir": "outputs/lora/roberta-base/v2/stsb", + "per_device_eval_batch_size": 100, + "per_device_train_batch_size": 16, + "predict_with_generate": true, + "save_strategy": "epoch", + "save_total_limit": 1, + "split_validation_test": true, + "task_name": "stsb", + "test_dataset_config_name": [ + "en" + ], + "test_dataset_name": "stsb", + "tokenizer_name": "roberta-base", + "unfrozen_modules": [ + "classifier", + "deltas" + ], + "warmup_ratio": 0.06, + "weight_decay": 0.1, + "overwrite_output_dir": true, + "push_to_hub": false +} \ No newline at end of file diff --git a/examples/examples_text-classification/configs/lora_roberta-base/lora_wnli.json b/examples/examples_text-classification/configs/lora_roberta-base/lora_wnli.json new file mode 100644 index 0000000..941cddb 
--- /dev/null +++ b/examples/examples_text-classification/configs/lora_roberta-base/lora_wnli.json @@ -0,0 +1,48 @@ +{ + "dataset_config_name": [ + "en" + ], + "delta_lr": 0.0005, + "delta_type": "lora", + "do_eval": true, + "do_test": true, + "do_train": true, + "eval_dataset_config_name": [ + "en" + ], + "eval_dataset_name": "wnli", + "evaluation_strategy": "epoch", + "greater_is_better": true, + "metric_for_best_model": "eval_pearson", + "learning_rate": 0.0003, + "load_best_model_at_end": true, + "lora_alpha": 8, + "lora_rank": 8, + "max_source_length": 512, + "model_name": "roberta", + "model_name_or_path": "roberta-base", + "non_linearity": "gelu_new", + "num_train_epochs": 30, + "output_dir": "outputs/lora/roberta-base/v2/wnli", + "per_device_eval_batch_size": 100, + "per_device_train_batch_size": 32, + "predict_with_generate": true, + "save_strategy": "epoch", + "save_total_limit": 1, + "split_validation_test": true, + "task_name": "wnli", + "test_dataset_config_name": [ + "en" + ], + "test_dataset_name": "wnli", + "tokenizer_name": "roberta-base", + "unfrozen_modules": [ + "classifier", + "deltas" + ], + "warmup_ratio": 0.06, + "warmup_steps": 0, + "weight_decay": 0.1, + "overwrite_output_dir": true, + "push_to_hub": false +} \ No newline at end of file diff --git a/examples/examples_text-classification/run_glue.py b/examples/examples_text-classification/run_glue.py index 9ca0477..4c5d95e 100755 --- a/examples/examples_text-classification/run_glue.py +++ b/examples/examples_text-classification/run_glue.py @@ -603,6 +603,7 @@ def main(): item = label_list[item] writer.write(f"{index}\t{item}\n") + # from IPython import embed; embed() # kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"} # if data_args.task_name is not None: diff --git a/examples/tutorial/0_regex.py b/examples/tutorial/0_regex.py index 642b920..6d884d7 100644 --- a/examples/tutorial/0_regex.py +++ b/examples/tutorial/0_regex.py @@ -9,6 +9,7 @@ Visualization(model).structure_graph() from opendelta import LoraModel import re delta_model = LoraModel(backbone_model=model, modified_modules=['[r](\d)+\.output.dense', 'attention.output.dense']) +# delta_model = LoraModel(backbone_model=model, modified_modules=['[r][0-5]\.output.dense']) print("after modify") delta_model.log() # This will visualize the backbone after modification and other information. \ No newline at end of file diff --git a/opendelta/basemodel.py b/opendelta/basemodel.py index a10ca16..ac21334 100644 --- a/opendelta/basemodel.py +++ b/opendelta/basemodel.py @@ -480,7 +480,7 @@ class DeltaBase(nn.Module, SaveLoadMixin): """ raise NotImplementedError - def insert_sequential_module(self, module, delta_module=None, name='delta', strict=False, _delta_info=None): + def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None): r"""insert a module (previous not exists in the code base) before/after a module. Specifically, it modifies the forward function of the original module to firstly pass the arguments into the new module's forward function and then pass it into the original ones. The new module can also be inserted after the original module with similar mechanism. 
@@ -520,14 +520,14 @@ class DeltaBase(nn.Module, SaveLoadMixin): _delta_info = {"method": "insert_sequential", "delta_module": delta_module, - "delta_name": name, + "delta_name": delta_name, "delta_belong": self, "state": "on"} self._register_delta_infos(parent_module=module, _delta_info = _delta_info) else: delta_module = _delta_info["delta_module"] - name = _delta_info["delta_name"] + delta_name = _delta_info["delta_name"] setattr(module, _delta_info['delta_name'], _delta_info["delta_module"]) @@ -537,20 +537,59 @@ class DeltaBase(nn.Module, SaveLoadMixin): # may have bugs when module.forward is nestedly wrapped. module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module)) - - - - def insert_parrellel_module(self, module, pre_caller=None, post_caller=None, delta_module=None, name='delta'): + def insert_parallel_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None): """insert a module (previous not exists in the code base) across a module. Specifically, it modifies the forward function of the original module to firstly pass the arguments into the delta model's forward function and set aside the calculation result. Then combine it with the calculation result output from the backbone module. When implementing the new module , researchers should be aware of the arguments and keywards of the original module's forward function. - # TODO: currently not in use. + Args: + module: (:obj:`nn.Module`): The (sub)module to insert a delta module into. + delta_module: (:obj:`DeltaBase`): The delta module to be inserted. + delta_name: (:obj:`str`, *optional*): The name of the delta module in the backbone module. + strict: (:obj:`bool`, *optional*): Whether to prohibit modifying an already modified module. + _delta_info (:obj:`Dict`, *optional*): Used in attach() to reattach a delta module to the backbone. The info of + the original delta is passed through ``_delta_info``. + + """ - raise NotImplementedError + + def _caller(_org_func, org_module, delta_name, *args, **kwargs): + args = args[1:] # the first argument here is ``self`` + delta_module = getattr(org_module, delta_name) + ret_1 = _org_func(*args, **kwargs) + ret_2 = delta_module.forward(*args, **kwargs) + return ret_1 + ret_2 + + if strict: + if hasattr(module.forward, "__wrapped__"): + raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?") + + # record info for plug and unplug and nested wrap + if _delta_info is None: + if delta_module is None: + raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.") + + _delta_info = {"method": "insert_parallel", + "delta_module": delta_module, + "delta_name": delta_name, + "delta_belong": self, + "state": "on"} + self._register_delta_infos(parent_module=module, + _delta_info = _delta_info) + else: + delta_module = _delta_info["delta_module"] + delta_name = _delta_info["delta_name"] + + setattr(module, _delta_info['delta_name'], _delta_info["delta_module"]) + + new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserve the function's metadata (signature, etc.). + module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) registers a function as an object's method + # for DataParallel's copy behavior. Experimental: + # may have bugs when module.forward is nestedly wrapped. 
+ module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module)) + def set_active_state_dict(self, module: nn.Module): r"""modify the state_dict function of the model (by default, the backbone model) to return only the tunable part. diff --git a/opendelta/delta_models/adapter.py b/opendelta/delta_models/adapter.py index 3017f01..cf7822e 100644 --- a/opendelta/delta_models/adapter.py +++ b/opendelta/delta_models/adapter.py @@ -192,7 +192,7 @@ class AdapterModel(DeltaBase): def update_module(self, module: nn.Module, key: str): _, _, ref = self.find_module(module, key) adapterlayer = self.new_module_like(ref) - self.insert_sequential_module(ref, delta_module=adapterlayer, name="adapter") + self.insert_sequential_module(ref, delta_module=adapterlayer, delta_name="adapter") def new_module_like(self, module): module_device = get_device(module) diff --git a/opendelta/delta_models/bitfit.py b/opendelta/delta_models/bitfit.py index 9bdff02..9bce262 100644 --- a/opendelta/delta_models/bitfit.py +++ b/opendelta/delta_models/bitfit.py @@ -179,7 +179,7 @@ class BitFitModel(DeltaBase): def add_bias_to_others(self, c): new_bias = BiasLayer() - self.insert_sequential_module(c, delta_module=new_bias, name="bitfit") # name shouldn't be `bias` here, since + self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since # the name `bias` is reserved for some module such as roberta's LayerNorm. self.delta_modules.append(new_bias) diff --git a/opendelta/delta_models/compacter.py b/opendelta/delta_models/compacter.py index 72c287a..c0c57b5 100644 --- a/opendelta/delta_models/compacter.py +++ b/opendelta/delta_models/compacter.py @@ -277,7 +277,7 @@ class CompacterModel(DeltaBase): adapterlayer = self.new_module_like(ref) self.insert_sequential_module(ref, delta_module=adapterlayer, - name="compactor") + delta_name="compactor") def new_module_like(self, module): module_device = get_device(module) diff --git a/opendelta/delta_models/lora.py b/opendelta/delta_models/lora.py index 05af87e..492fea6 100644 --- a/opendelta/delta_models/lora.py +++ b/opendelta/delta_models/lora.py @@ -1,3 +1,4 @@ +from turtle import forward from typing import Optional, Union from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func @@ -7,6 +8,42 @@ from transformers.models.t5 import T5ForConditionalGeneration import loralib as lora import torch.nn as nn from opendelta import BaseDeltaConfig +import math + +class LowRankLinear(nn.Module): + # ------------------------------------------------------------------------------------------ + # Copyright (c) Microsoft Corporation. All rights reserved. + # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
+ # ------------------------------------------------------------------------------------------ + # copy from loralib and do some refactor + def __init__(self, + in_features, + out_features, + weight, + r=8, + lora_alpha=16, + lora_dropout=0.0, + ): + super().__init__() + self.r = r + self.lora_alpha = lora_alpha + self.lora_dropout = lora_dropout + self.lin = nn.Linear(in_features, out_features) # + if lora_dropout > 0.: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + if r > 0: + self.lora_A = nn.Parameter(weight.new_zeros((r, in_features))) + self.lora_B = nn.Parameter(weight.new_zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r + self.lin.reset_parameters() # + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def forward(self, x): + return (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling + class LoraConfig(BaseDeltaConfig): r""" @@ -27,7 +64,6 @@ class LoraConfig(BaseDeltaConfig): setattr(self, arg_name, locals()[arg_name]) - class LoraModel(DeltaBase): r""" The implementation of `LoRA: Low-Rank Adaptation of Large Language Models `_ . Thanks for their `loralib `_, we use loralib.linear @@ -89,11 +125,10 @@ class LoraModel(DeltaBase): ) - def update_module(self, module: nn.Module, key: str): parent_ref, child_name, child_ref = self.find_module(module, key) - new_module = self.new_module_like(child_module=child_ref) - self.replace_module(parent_ref, child_name, child_ref, new_module, delta_name="lora") + parallel_module = self.new_module_like(child_module=child_ref) + self.insert_parallel_module(child_ref, delta_module=parallel_module, delta_name="lora") def _pseudo_data_to_instantiate(self, module): # no need to pass pseudo input, so overwrite it @@ -102,26 +137,13 @@ class LoraModel(DeltaBase): def new_module_like(self, child_module): if isinstance(child_module, nn.Linear): in_features, out_features = child_module.in_features, child_module.out_features - new_module = lora.Linear(in_features=in_features, - out_features=out_features, + new_module = LowRankLinear(in_features = in_features, + out_features = out_features, + weight = child_module.weight, r=self.lora_r, lora_alpha=self.lora_alpha, lora_dropout=self.lora_dropout) - new_module.weight = child_module.weight - new_module.bias = child_module.bias # if bias is None, also copy + self.delta_modules.append(new_module) else: raise NotImplementedError - return new_module - - - - def mark_as_delta(self, module: nn.Module = None): - if module is None: - module=self - for n, p in module.named_parameters(): - param_name = n.split(".")[-1] - if "lora_A" in param_name or "lora_B" in param_name: # only lora_A, lora_B is the delta parameter. 
- setattr(p, "_is_delta", True) - - - \ No newline at end of file + return new_module \ No newline at end of file diff --git a/opendelta/delta_models/lora_old.py b/opendelta/delta_models/lora_old.py new file mode 100644 index 0000000..d4954dc --- /dev/null +++ b/opendelta/delta_models/lora_old.py @@ -0,0 +1,126 @@ +from typing import Optional, Union + +from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func +from opendelta.utils.name_based_addressing import * +from opendelta.basemodel import DeltaBase +from transformers.models.t5 import T5ForConditionalGeneration +import loralib as lora +import torch.nn as nn +from opendelta import BaseDeltaConfig + +class LoraConfig(BaseDeltaConfig): + r""" + This is the configuration class to store the configuration of a :py:class:`~LoraModel` + + """ + def __init__( + self, + lora_r=8, + lora_alpha=16, + lora_dropout=0.0, + **kwargs + ): + super().__init__(**kwargs) + arg_names = get_arg_names_inside_func(self.__init__) + for arg_name in arg_names: + if not hasattr(self, arg_name): # the arg has not been registered in parent config + setattr(self, arg_name, locals()[arg_name]) + + +class LoraModel(DeltaBase): + r""" The implementation of `LoRA: Low-Rank Adaptation of Large Language Models `_ . + Thanks for their `loralib `_, we use loralib.linear + to replace the linear layer of the backbone model. + + class attributes: + - default_modified_modules = ['attn.q', 'attn.v'] According to the paper, they modify q and v matrix in the + attention layer. However, other linears can also be modified, and may lead to better performance. + + .. note:: + modified_modules should point to linear layer. We currently don't support broadcast to all linears in + a module's child modules. + + - delta_type = "lora" + + + Args: + backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified. + lora_r (:obj:`int`, *optional*): the rank of the lora parameters. The smaller lora_r is , the fewer parameters lora has. + lora_alpha (:obj:`bool`, *optional*): A hyper-parameter to control the init scale of loralib.linear . + lora_dropout (:obj:`bool`, *optional*): The dropout rate in lora.linear. + modified_modules (:obj:`List[str]`): For prefix tuning, the it must refer to an attention layer (Currently, only + the implemented ones) + unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen + together with the prefix parameters. + common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping. 
+ + """ + + config_class = LoraConfig + delta_type = "lora" + default_modified_modules = ['attn.q', 'attn.v'] + def __init__(self, + backbone_model: nn.Module, + lora_r=8, + lora_alpha=16, + lora_dropout=0.0, + modified_modules: Optional[bool] = None, + unfrozen_modules: Optional[bool] = None, + common_structure: Optional[bool] = None, + interactive_modify: Optional[Union[bool, int]] = False, + ): + DeltaBase.__init__(self, + backbone_model, + modified_modules=modified_modules, + unfrozen_modules=unfrozen_modules, + common_structure=common_structure, + interactive_modify=interactive_modify, + ) + arg_names = get_arg_names_inside_func(self.__init__) + for arg_name in arg_names: + if not hasattr(self, arg_name): # not registered in parent class + setattr(self, arg_name, locals()[arg_name]) + + self.delta_modules = nn.ModuleList() + + self.add_all_delta_to_backbone(self.backbone_model, + self.modified_modules, + ) + + + + def update_module(self, module: nn.Module, key: str): + parent_ref, child_name, child_ref = self.find_module(module, key) + new_module = self.new_module_like(child_module=child_ref) + self.replace_module(parent_ref, child_name, child_ref, new_module, delta_name="lora") + + def _pseudo_data_to_instantiate(self, module): + # no need to pass pseudo input, so overwrite it + pass + + def new_module_like(self, child_module): + if isinstance(child_module, nn.Linear): + in_features, out_features = child_module.in_features, child_module.out_features + new_module = lora.Linear(in_features=in_features, + out_features=out_features, + r=self.lora_r, + lora_alpha=self.lora_alpha, + lora_dropout=self.lora_dropout) + new_module.weight = child_module.weight + new_module.bias = child_module.bias # if bias is None, also copy + else: + raise NotImplementedError + return new_module + + + + def mark_as_delta(self, module: nn.Module = None): + if module is None: + module=self + for n, p in module.named_parameters(): + param_name = n.split(".")[-1] + if "lora_A" in param_name or "lora_B" in param_name: # only lora_A, lora_B is the delta parameter. 
+ setattr(p, "_is_delta", True) + + + \ No newline at end of file diff --git a/opendelta/delta_models/low_rank_adapter.py b/opendelta/delta_models/low_rank_adapter.py index b02fdb9..2d378e3 100644 --- a/opendelta/delta_models/low_rank_adapter.py +++ b/opendelta/delta_models/low_rank_adapter.py @@ -194,7 +194,7 @@ class LowRankAdapterModel(DeltaBase): def update_module(self, module: nn.Module, key: str): _, _, ref = self.find_module(module, key) adapterlayer = self.new_module_like(ref) - self.insert_sequential_module(ref, delta_module=adapterlayer, name="low_rank_adapter") + self.insert_sequential_module(ref, delta_module=adapterlayer, delta_name="low_rank_adapter") def new_module_like(self, module): module_device = get_device(module) diff --git a/opendelta/delta_models/prefix.py b/opendelta/delta_models/prefix.py index debfd8e..e7b3ebf 100644 --- a/opendelta/delta_models/prefix.py +++ b/opendelta/delta_models/prefix.py @@ -512,7 +512,7 @@ class PrefixModel(DeltaBase): module_list=self.delta_modules) self.delta_modules = None self.reparams = reparams - self.insert_sequential_module(first_modified_module, delta_module=reparams, name="reparams", strict=False) + self.insert_sequential_module(first_modified_module, delta_module=reparams, delta_name="reparams", strict=False) self.mark_as_delta() return module @@ -522,7 +522,7 @@ class PrefixModel(DeltaBase): _, _, ref = self.find_module(module, key) prefixlayer, ref = self.new_module_like(ref) - self.insert_sequential_module(ref, delta_module=prefixlayer, name="prefix") + self.insert_sequential_module(ref, delta_module=prefixlayer, delta_name="prefix") self.delta_modules.append(prefixlayer) def new_module_like(self, module): diff --git a/opendelta/delta_models/soft_prompt.py b/opendelta/delta_models/soft_prompt.py index 0d2fd21..25c2a9c 100644 --- a/opendelta/delta_models/soft_prompt.py +++ b/opendelta/delta_models/soft_prompt.py @@ -193,11 +193,11 @@ class SoftPromptModel(DeltaBase): soft_prompt_layer = self.new_module_like(self.raw_embedding) self.insert_sequential_module(self.backbone_model.get_encoder() if self.backbone_model.config.is_encoder_decoder else self.backbone_model, delta_module=soft_prompt_layer, - name="soft_prompt_layer" ) + delta_name="soft_prompt_layer" ) def new_module_like(self, module): module_device = get_device(module) - soft_prompt_layer = SoftPromptLayer( + soft_prompt_layer = SoftPromptLayer( soft_token_num = self.soft_token_num, raw_embedding = self.raw_embedding, token_init = self.token_init, diff --git a/opendelta/utils/data_parallel.py b/opendelta/utils/data_parallel.py index 973d21f..ca0c4c0 100644 --- a/opendelta/utils/data_parallel.py +++ b/opendelta/utils/data_parallel.py @@ -8,7 +8,7 @@ def new_replicate_for_data_parallel(self): r""" self is the parent module. """ # rewrite the replicate in DataParallel. 
- def _caller(_org_func, org_module, delta_name, *args, **kwargs): + def _sequential_caller(_org_func, org_module, delta_name, *args, **kwargs): args = args[1:] # the first argument here is ``self`` delta_module = getattr(org_module, delta_name) if hasattr(delta_module, "pre_forward"): @@ -17,6 +17,13 @@ def new_replicate_for_data_parallel(self): if hasattr(delta_module, "post_forward"): ret = delta_module.post_forward(ret) return ret + + def _parallel_caller(_org_func, org_module, delta_name, *args, **kwargs): + args = args[1:] # the first argument here is ``self`` + delta_module = getattr(org_module, delta_name) + ret_1 = _org_func(*args, **kwargs) + ret_2 = delta_module.forward(*args, **kwargs) + return ret_1 + ret_2 replica = self.__new__(type(self)) org_forward = replica.forward replica.__dict__ = self.__dict__.copy() @@ -25,8 +32,13 @@ def new_replicate_for_data_parallel(self): for _delta_info in self._delta_infos: - if _delta_info['method'] == "insert_sequential" and _delta_info['state'] == "on": - new_forward = decorate(replica.forward, _caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True) + if _delta_info['state'] == 'on': + if _delta_info['method'] == "insert_sequential": + new_forward = decorate(replica.forward, _sequential_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True) + elif _delta_info['method'] == "insert_parallel": + new_forward = decorate(replica.forward, _parallel_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True) + else: + raise NotImplementedError(f"data_parallel for _delta_info['method']=='{_delta_info['method']}' is not supported") replica.__dict__['forward'] = new_forward.__get__(replica, type(replica)) # replicas do not have parameters themselves, the replicas reference the original From 266a00e3909b9e82e76be160216c5a3949c7e6e8 Mon Sep 17 00:00:00 2001 From: shengdinghu Date: Sun, 13 Mar 2022 22:04:38 +0800 Subject: [PATCH 2/2] merge parallel --- docs/source/conf.py | 1 + .../configs/lora_roberta-base/lora_mrpc.json | 3 ++- opendelta/delta_models/lora.py | 9 +++++++-- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 8408041..e2f01a3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -52,6 +52,7 @@ extensions = [ 'sphinx.ext.autosummary', 'sphinx.ext.doctest', 'sphinx.ext.intersphinx', + # 'sphinx.ext.mathbase', 'sphinx.ext.mathjax', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', diff --git a/examples/examples_text-classification/configs/lora_roberta-base/lora_mrpc.json b/examples/examples_text-classification/configs/lora_roberta-base/lora_mrpc.json index 46afef5..10eeb38 100644 --- a/examples/examples_text-classification/configs/lora_roberta-base/lora_mrpc.json +++ b/examples/examples_text-classification/configs/lora_roberta-base/lora_mrpc.json @@ -38,7 +38,8 @@ "tokenizer_name": "roberta-base", "unfrozen_modules": [ "classifier", - "deltas" + "deltas", + "layer_norm" ], "warmup_ratio": 0.06, "weight_decay": 0.1, diff --git a/opendelta/delta_models/lora.py b/opendelta/delta_models/lora.py index 492fea6..02ebe85 100644 --- a/opendelta/delta_models/lora.py +++ b/opendelta/delta_models/lora.py @@ -66,8 +66,13 @@ class LoraConfig(BaseDeltaConfig): class LoraModel(DeltaBase): r""" The implementation of `LoRA: Low-Rank Adaptation of Large Language Models `_ . - Thanks for their `loralib `_, we use loralib.linear - to replace the linear layer of the backbone model. + Thanks for their `loralib `_. + + .. 
note:: + In our implementation, we did not use loralib.linear to replace the linear layer of the backbone model. + Instead, we insert a parallel module into the backbone. + In other words, we treat :math:`(W + BA) X` as :math:`WX + BAX`, and insert the :math:`BAX` term as a parallel insertion module. + If you want to use the original implementation, please refer to `lora_old.py`. class attributes: - default_modified_modules = ['attn.q', 'attn.v'] According to the paper, they modify q and v matrix in the