From 0f8f7d3b90b87a118cf63a9c6e534a14ec188335 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 7 May 2024 19:03:35 +0800
Subject: [PATCH] fix #3560

---
 src/llmtuner/model/utils/longlora.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/llmtuner/model/utils/longlora.py b/src/llmtuner/model/utils/longlora.py
index c3740a73..c91febdd 100644
--- a/src/llmtuner/model/utils/longlora.py
+++ b/src/llmtuner/model/utils/longlora.py
@@ -15,6 +15,7 @@ from transformers.utils import logging
 from transformers.utils.versions import require_version
 
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
+from ...extras.logging import get_logger
 
 
 if TYPE_CHECKING:
@@ -179,10 +180,12 @@ def llama_flash_attention_2_forward(
 
         query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
         if attention_mask is not None:
-            attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
+            attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
+    else:
+        groupsz = q_len
 
     attn_output: torch.Tensor = self._flash_attention_forward(
-        query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
     )
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
@@ -299,7 +302,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers==4.40.0", "To fix: pip install transformers==4.40.0")
+    require_version("transformers==4.40.1", "To fix: pip install transformers==4.40.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
@@ -309,6 +312,8 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
     if not is_trainable or not model_args.shift_attn:
         return
 
+    logger = get_logger(__name__)
+
     if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
         setattr(config, "group_size_ratio", 0.25)
         _apply_llama_patch()
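
Illustration (not part of the patch): the Flash Attention 2 path receives a 2D padding mask of shape (batch, seq_len), so the old 4D-style slice `[:, :, :groupsz, :groupsz]` indexed dimensions the mask does not have. The sketch below walks through the shape arithmetic of the new indexing only; `bsz`, `q_len`, and `group_size_ratio` are illustrative values, not taken from the patch.

# Minimal sketch of the mask reshaping for shifted short attention (S^2-Attn),
# assuming a 2D padding mask as passed to the Flash Attention 2 code path.
import torch

bsz, q_len = 2, 8
group_size_ratio = 0.25
groupsz = int(q_len * group_size_ratio)   # 2
num_groups = q_len // groupsz             # 4

# 2D padding mask: one row per sequence, one entry per position.
padding_mask_2d = torch.ones(bsz, q_len, dtype=torch.long)

# Keep the first `groupsz` positions of each sequence, then tile once per
# group along the batch dimension so the mask lines up with the shifted
# q/k/v states of shape (bsz * num_groups, groupsz, ...).
shifted_mask = padding_mask_2d[:, :groupsz].repeat(num_groups, 1)
print(shifted_mask.shape)  # torch.Size([8, 2]) == (bsz * num_groups, groupsz)

The second change in the same hunk follows from this grouping: when the shift is applied, the effective query length seen by the attention kernel is `groupsz` rather than `q_len`, and when it is not applied the patch sets `groupsz = q_len` so a single value can be passed through in both cases.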