hiyouga 2024-05-07 19:03:35 +08:00
parent 3c560119ca
commit 0f8f7d3b90
1 changed file with 8 additions and 3 deletions


@@ -15,6 +15,7 @@ from transformers.utils import logging
 from transformers.utils.versions import require_version
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
+from ...extras.logging import get_logger
 if TYPE_CHECKING:
@@ -179,10 +180,12 @@ def llama_flash_attention_2_forward(
         query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
         if attention_mask is not None:
-            attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
+            attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
+    else:
+        groupsz = q_len
     attn_output: torch.Tensor = self._flash_attention_forward(
-        query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
     )
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
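
Why this hunk changes the mask handling: on the FlashAttention-2 path, `_flash_attention_forward` consumes a 2-D padding mask of shape (batch, seq_len), not the 4-D attention bias the old slice assumed, and once S2-Attn folds sequence groups into the batch dimension the effective per-sample length becomes `groupsz` (hence the new argument and the `groupsz = q_len` fallback when no shift is applied). Below is a minimal runnable sketch of the shape bookkeeping, with toy sizes and a simplified `shift` that only stands in for the one defined earlier in this file; everything in it is illustrative, not the repository's code.

```python
import torch

bsz, q_len, n_heads, head_dim = 2, 8, 4, 16
group_size_ratio = 0.25                    # value that configure_longlora sets on the config
groupsz = int(q_len * group_size_ratio)    # tokens per group; the patch falls back to q_len when not shifting
num_groups = q_len // groupsz              # groups folded into the batch dimension

# FlashAttention-2 sees hidden states as (bsz, q_len, n_heads, head_dim).
query_states = torch.randn(bsz, q_len, n_heads, head_dim)

def shift(state: torch.Tensor) -> torch.Tensor:
    # Roll half of the heads by half a group so neighbouring groups overlap,
    # then fold the groups into the batch dimension.
    state = torch.cat(
        (state[:, :, : n_heads // 2], state[:, :, n_heads // 2 :].roll(-groupsz // 2, dims=1)),
        dim=2,
    )
    return state.reshape(bsz * num_groups, groupsz, n_heads, head_dim)

grouped_query = shift(query_states)                               # (8, 2, 4, 16)

# The padding mask is 2-D (batch, seq_len), so after grouping it is cropped to
# one group and tiled along the batch axis, matching the new batch size.
attention_mask = torch.ones(bsz, q_len, dtype=torch.long)
grouped_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)  # (8, 2)
assert grouped_mask.shape[0] == grouped_query.shape[0]
```

The assertion is the invariant the fix restores: the grouped mask's batch dimension matches that of the grouped hidden states passed to flash attention.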
@@ -299,7 +302,7 @@ def llama_sdpa_attention_forward(
 def _apply_llama_patch() -> None:
-    require_version("transformers==4.40.0", "To fix: pip install transformers==4.40.0")
+    require_version("transformers==4.40.1", "To fix: pip install transformers==4.40.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
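
Aside on the pattern in `_apply_llama_patch`: the exact `transformers` pin is bumped together with the release whose internals the replacement forwards mirror, and assigning to the classes' `forward` attribute swaps the implementation for every instance, including ones already constructed. A tiny self-contained illustration of that rebinding (toy class, nothing from `transformers`):

```python
# Toy illustration of the monkey-patch above: rebinding `forward` on the class
# changes behaviour for existing and future instances alike.
class Attention:
    def forward(self, x):
        return x

def patched_forward(self, x):  # stands in for llama_flash_attention_2_forward
    return 2 * x

attn = Attention()
Attention.forward = patched_forward   # same move as LlamaFlashAttention2.forward = ...
assert attn.forward(3) == 6           # already-created instances pick up the patch
```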
@@ -309,6 +312,8 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
     if not is_trainable or not model_args.shift_attn:
         return
+    logger = get_logger(__name__)
     if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
         setattr(config, "group_size_ratio", 0.25)
         _apply_llama_patch()
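
For completeness, a hedged usage sketch of the patched entry point: `SimpleNamespace` stands in for the project's `ModelArguments` (only `shift_attn` is mimicked), the import of `configure_longlora` is assumed since the module path is project-internal, and it is assumed that `"llama"` appears in `SUPPORTED_CLASS_FOR_S2ATTN`.

```python
# Hedged usage sketch: SimpleNamespace mimics ModelArguments.shift_attn only,
# and configure_longlora is assumed to be importable from the patched module.
from types import SimpleNamespace

from transformers import AutoConfig

config = AutoConfig.for_model("llama")          # model_type == "llama" (assumed to be supported for S2-Attn)
model_args = SimpleNamespace(shift_attn=True)   # enable shift short attention
configure_longlora(config, model_args, is_trainable=True)

# When the model type is supported, the config now carries the grouping ratio
# and the Llama attention forwards have been replaced by the patched versions.
print(getattr(config, "group_size_ratio", None))  # expected: 0.25
```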