diff --git a/opendelta/delta_models/prefix.py b/opendelta/delta_models/prefix.py index 777a8e1..d1d9e6b 100644 --- a/opendelta/delta_models/prefix.py +++ b/opendelta/delta_models/prefix.py @@ -17,6 +17,7 @@ import torch import opendelta.utils.logging as logging logger = logging.get_logger(__name__) +# We are going to refactor the code of Prefix Tuning. class PrefixLayerT5(nn.Module): r"""A layer of prefix tuning module. The layer's forward function pass (or concatenate) the additional past_key_value