fix #3560
This commit is contained in:
parent 3c560119ca
commit 0f8f7d3b90
@@ -15,6 +15,7 @@ from transformers.utils import logging
 from transformers.utils.versions import require_version
 
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
+from ...extras.logging import get_logger
 
 
 if TYPE_CHECKING:
@@ -179,10 +180,12 @@ def llama_flash_attention_2_forward(
 
         query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
         if attention_mask is not None:
-            attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
+            attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
+    else:
+        groupsz = q_len
 
     attn_output: torch.Tensor = self._flash_attention_forward(
-        query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
    )
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
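The hunk above is the substance of the fix: the Flash Attention code path takes a 2D padding mask of shape (batch, seq_len), not the 4D mask used by the eager path, and after the S^2-Attn grouping the effective batch becomes bsz * num_groups with sequence length groupsz, which is also why the call now passes groupsz rather than q_len. A minimal sketch of the reshaping, with assumed example sizes and a shift() modeled on the patched function:

# Sketch of the S^2-Attn grouping (assumed sizes; shift() is modeled on
# the patched llama_flash_attention_2_forward, not copied from it).
import torch

bsz, q_len, num_heads, head_dim = 2, 8, 4, 16
groupsz = int(q_len * 0.25)    # 2 tokens per group
num_groups = q_len // groupsz  # 4 groups

def shift(state: torch.Tensor) -> torch.Tensor:
    # Roll half of the heads by half a group, then fold the groups
    # into the batch dimension.
    state = torch.cat(
        (state[:, :, : num_heads // 2], state[:, :, num_heads // 2 :].roll(-groupsz // 2, dims=1)),
        dim=2,
    )
    return state.reshape(bsz * num_groups, groupsz, num_heads, head_dim)

query_states = torch.randn(bsz, q_len, num_heads, head_dim)
print(shift(query_states).shape)  # torch.Size([8, 2, 4, 16])

# The 2D padding mask must match the folded batch, so it is sliced to one
# group and repeated per group; the old 4D slice had the wrong shape here.
attention_mask = torch.ones(bsz, q_len, dtype=torch.long)
print(attention_mask[:, :groupsz].repeat(num_groups, 1).shape)  # torch.Size([8, 2])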
@@ -299,7 +302,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers==4.40.0", "To fix: pip install transformers==4.40.0")
+    require_version("transformers==4.40.1", "To fix: pip install transformers==4.40.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
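Because the patched forwards mirror the internals of one specific transformers release, the version is pinned exactly and bumped here to 4.40.1. require_version checks the installed package at patch time and raises with the supplied hint on a mismatch; a small usage illustration (the except branch only fires under a mismatched install):

# Illustration of the version gate above; transformers' require_version
# raises ImportError when the installed version does not satisfy the pin.
from transformers.utils.versions import require_version

try:
    require_version("transformers==4.40.1", "To fix: pip install transformers==4.40.1")
except ImportError as err:  # only raised if a different version is installed
    print(err)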
@@ -309,6 +312,8 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
     if not is_trainable or not model_args.shift_attn:
         return
 
+    logger = get_logger(__name__)
+
     if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
         setattr(config, "group_size_ratio", 0.25)
         _apply_llama_patch()
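Setting group_size_ratio to 0.25 on the config means each token attends within a window covering a quarter of the training sequence; the patched forwards read the ratio back off the config, and _apply_llama_patch swaps the forward methods on the three LlamaAttention classes in place, so every module instance picks up shifted short attention. Rough arithmetic with an assumed 8k context:

# What group_size_ratio = 0.25 implies at training time
# (q_len is an assumed example; names mirror the patched forwards).
group_size_ratio = 0.25
q_len = 8192
groupsz = int(q_len * group_size_ratio)  # 2048 tokens per group
num_groups = q_len // groupsz            # 4 groups folded into the batch
assert q_len % groupsz == 0, "q_len should be divisible by groupsz"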