fix #3724
parent c27afa296b
commit 93a0245474
@@ -41,9 +41,9 @@ def llama_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
+    query_states: "torch.Tensor" = self.q_proj(hidden_states)
+    key_states: "torch.Tensor" = self.k_proj(hidden_states)
+    value_states: "torch.Tensor" = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
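For context, a minimal standalone sketch of the projection-and-reshape pattern these hunks annotate. The sizes and the single q_proj layer are illustrative stand-ins, not values taken from the patched module:

import torch
import torch.nn as nn

# Illustrative sizes only; the real values come from the model config.
bsz, q_len, num_heads, head_dim = 2, 16, 8, 64
hidden_size = num_heads * head_dim

hidden_states = torch.randn(bsz, q_len, hidden_size)
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)

# Same shape dance as the patched forward: [bsz, q_len, hidden_size]
# -> [bsz, q_len, num_heads, head_dim] -> [bsz, num_heads, q_len, head_dim]
query_states: "torch.Tensor" = q_proj(hidden_states)
query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
assert query_states.shape == (bsz, num_heads, q_len, head_dim)

The quoted "torch.Tensor" annotations added in these hunks only document the type of the projection outputs for readers and type checkers; they do not change runtime behavior.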
@@ -125,9 +125,9 @@ def llama_flash_attention_2_forward(
 
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
+    query_states: "torch.Tensor" = self.q_proj(hidden_states)
+    key_states: "torch.Tensor" = self.k_proj(hidden_states)
+    value_states: "torch.Tensor" = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -233,9 +233,9 @@ def llama_sdpa_attention_forward(
 
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
+    query_states: "torch.Tensor" = self.q_proj(hidden_states)
+    key_states: "torch.Tensor" = self.k_proj(hidden_states)
+    value_states: "torch.Tensor" = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -270,8 +270,9 @@ def llama_sdpa_attention_forward(
 
     causal_mask = attention_mask
     if attention_mask is not None:
-        causal_mask = causal_mask[:, :, :, :groupsz]
+        causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
     if query_states.device.type == "cuda" and causal_mask is not None:
         query_states = query_states.contiguous()
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
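The last hunk replaces the groupsz bound with the actual key length when trimming the prepared 4D mask. Below is a minimal sketch, with made-up shapes, of the constraint that slice respects: torch.nn.functional.scaled_dot_product_attention needs the mask's last dimension to broadcast against the key length, so a mask wider than key_states.shape[-2] must be cut down first.

import torch
import torch.nn.functional as F

# Made-up shapes: 1 sequence, 2 heads, 4 queries attending to 10 keys.
bsz, num_heads, q_len, kv_len, head_dim = 1, 2, 4, 10, 8
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, kv_len, head_dim)
value_states = torch.randn(bsz, num_heads, kv_len, head_dim)

# Hypothetical prepared 4D mask that is wider than the real key length
# (e.g. padded out to some maximum); zeros mean "no masking" here.
attention_mask = torch.zeros(bsz, 1, q_len, 16)

# Same slice as the patched line: trim the mask to the key dimension.
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=causal_mask
)
print(attn_output.shape)  # torch.Size([1, 2, 4, 8])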