fix #3560
parent 3c560119ca
commit 0f8f7d3b90
@@ -15,6 +15,7 @@ from transformers.utils import logging
 from transformers.utils.versions import require_version
 
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
+from ...extras.logging import get_logger
 
 
 if TYPE_CHECKING:
@@ -179,10 +180,12 @@ def llama_flash_attention_2_forward(
 
         query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
         if attention_mask is not None:
-            attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
+            attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
+    else:
+        groupsz = q_len
 
     attn_output: torch.Tensor = self._flash_attention_forward(
-        query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
     )
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
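Note (not part of the commit): in transformers 4.40 the flash-attention path receives a 2-D padding mask of shape (batch_size, seq_len) rather than a 4-D causal mask, so after the shifted states are reshaped into groups the mask has to be cropped and tiled to match, and the query length handed to _flash_attention_forward becomes groupsz instead of q_len. A minimal shape sketch, assuming toy sizes bsz=2, q_len=8, group_size_ratio=0.25:

import torch

bsz, q_len = 2, 8
groupsz = int(q_len * 0.25)    # 2 tokens per group
num_groups = q_len // groupsz  # 4 groups

# 2-D padding mask as handed to the flash-attention forward: (bsz, q_len)
attention_mask = torch.ones(bsz, q_len, dtype=torch.long)

# States are reshaped to (bsz * num_groups, groupsz, n_heads, head_dim),
# so the mask is cropped to one group and tiled across groups, as in the patched line:
grouped_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
assert grouped_mask.shape == (bsz * num_groups, groupsz)  # (8, 2)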
@@ -299,7 +302,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers==4.40.0", "To fix: pip install transformers==4.40.0")
+    require_version("transformers==4.40.1", "To fix: pip install transformers==4.40.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
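Note (not part of the commit): the pin moves to transformers 4.40.1 because the monkey patch replaces attention forwards that must line up with the installed transformers internals; require_version raises if the pin is not satisfied. A small sketch of that guard in isolation:

from transformers.utils.versions import require_version

try:
    require_version("transformers==4.40.1", "To fix: pip install transformers==4.40.1")
except ImportError as err:  # assumption: the version checker signals a mismatch via ImportError
    print(err)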
@@ -309,6 +312,8 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
     if not is_trainable or not model_args.shift_attn:
         return
 
+    logger = get_logger(__name__)
+
     if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
         setattr(config, "group_size_ratio", 0.25)
         _apply_llama_patch()
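Note (not part of the commit): configure_longlora only applies the patch when training with shift_attn enabled on a model type listed in SUPPORTED_CLASS_FOR_S2ATTN, and it fixes group_size_ratio at 0.25. A minimal sketch of what that ratio means for the grouped attention, assuming a 4096-token training sequence:

q_len, group_size_ratio = 4096, 0.25
groupsz = int(q_len * group_size_ratio)  # 1024 tokens attended per group
num_groups = q_len // groupsz            # 4 groups per sequence
# The reshape to (bsz * num_groups, groupsz, ...) requires q_len to be divisible by groupsz.
assert q_len % groupsz == 0
print(groupsz, num_groups)               # 1024 4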