fix #3724
This commit is contained in:
parent c27afa296b
commit 93a0245474
@@ -41,9 +41,9 @@ def llama_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
+    query_states: "torch.Tensor" = self.q_proj(hidden_states)
+    key_states: "torch.Tensor" = self.k_proj(hidden_states)
+    value_states: "torch.Tensor" = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
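The only change in this first hunk is the quoted "torch.Tensor" annotation on the three projection outputs; the reshape into per-head layout below it is untouched. For orientation, a minimal, self-contained sketch of the same projection-and-reshape pattern follows. The sizes are made up for illustration, and the quoted annotation is an ordinary type hint with no runtime effect.

import torch
import torch.nn as nn

# Hypothetical sizes, not taken from the patched file.
bsz, q_len, hidden_size = 2, 5, 64
num_heads, num_key_value_heads = 8, 4
head_dim = hidden_size // num_heads

q_proj = nn.Linear(hidden_size, num_heads * head_dim)
k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim)
hidden_states = torch.randn(bsz, q_len, hidden_size)

# Same pattern as the diff: project, annotate, then split heads and move the head axis forward.
query_states: "torch.Tensor" = q_proj(hidden_states)
key_states: "torch.Tensor" = k_proj(hidden_states)
query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)

print(query_states.shape)  # torch.Size([2, 8, 5, 8])
print(key_states.shape)    # torch.Size([2, 4, 5, 8])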
@@ -87,7 +87,7 @@ def llama_attention_forward(
     # upcast attention to fp32
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-    attn_output = torch.matmul(attn_weights, value_states)  # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
+    attn_output = torch.matmul(attn_weights, value_states)  # (bsz, :, seq_len, :) or (bsz * n_group, :, groupsz, :)
     attn_output = attn_output.transpose(1, 2).contiguous()
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
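This hunk only adjusts spacing inside the shape comment; the surrounding context is the eager attention path. As a reference for what those context lines do, here is a minimal sketch of the upcast-softmax / dropout / matmul sequence under assumed shapes. The dtype and dropout probability are placeholders; in the real module the working dtype is typically fp16 or bf16, which is why the softmax is upcast to fp32.

import torch
import torch.nn as nn

# Assumed shapes: attention scores (bsz, num_heads, q_len, kv_len), values (bsz, num_heads, kv_len, head_dim).
attn_weights = torch.randn(2, 8, 5, 5)
value_states = torch.randn(2, 8, 5, 16)

# Softmax in fp32 for numerical stability, then cast back to the working dtype.
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=0.1, training=True)

# Weighted sum over values. With LongLoRA's shifted groups the leading axis becomes
# bsz * n_group and the sequence axis becomes groupsz, which is what the edited comment records.
attn_output = torch.matmul(attn_weights, value_states)
print(attn_output.shape)  # torch.Size([2, 8, 5, 16])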
@@ -125,9 +125,9 @@ def llama_flash_attention_2_forward(
 
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
+    query_states: "torch.Tensor" = self.q_proj(hidden_states)
+    key_states: "torch.Tensor" = self.k_proj(hidden_states)
+    value_states: "torch.Tensor" = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
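The same annotation-only change is applied to the FlashAttention-2 path. One detail the unchanged context makes visible is that keys and values are reshaped with num_key_value_heads rather than num_heads, i.e. grouped-query attention. The diff does not show how the key/value heads are later expanded; the sketch below reproduces the usual repeat_kv pattern from HF Transformers' LLaMA code purely for orientation, so treat it as an assumption about the surrounding file rather than part of this commit.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Assumed helper, mirroring HF Transformers: expand (bsz, num_kv_heads, seq_len, head_dim)
    # to (bsz, num_kv_heads * n_rep, seq_len, head_dim) so keys/values match the query head count.
    bsz, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)

num_heads, num_key_value_heads = 8, 4
key_states = torch.randn(2, num_key_value_heads, 5, 16)
key_states = repeat_kv(key_states, num_heads // num_key_value_heads)
print(key_states.shape)  # torch.Size([2, 8, 5, 16])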
@@ -233,9 +233,9 @@ def llama_sdpa_attention_forward(
 
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
+    query_states: "torch.Tensor" = self.q_proj(hidden_states)
+    key_states: "torch.Tensor" = self.k_proj(hidden_states)
+    value_states: "torch.Tensor" = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -270,11 +270,12 @@ def llama_sdpa_attention_forward(
 
     causal_mask = attention_mask
     if attention_mask is not None:
-        causal_mask = causal_mask[:, :, :, :groupsz]
+        causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
 
-    query_states = query_states.contiguous()
-    key_states = key_states.contiguous()
-    value_states = value_states.contiguous()
+    if query_states.device.type == "cuda" and causal_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
 
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query_states,
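This last hunk carries the substance of the fix for #3724 in the SDPA path: the mask is now sliced to the actual key length (key_states.shape[-2]) instead of the LongLoRA group size, and the .contiguous() copies are made only in the case that needs them, CUDA tensors combined with a non-None mask, apparently the same workaround HF Transformers applies for torch SDPA with custom masks. Below is a minimal standalone sketch of the corrected call, with made-up sizes; it is an illustration of the pattern, not the patched function itself.

import torch

bsz, num_heads, q_len, kv_len, head_dim = 2, 8, 5, 7, 16  # assumed sizes
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, kv_len, head_dim)
value_states = torch.randn(bsz, num_heads, kv_len, head_dim)

# A 4D additive mask prepared for a longer (padded) key length than the tensors actually have.
attention_mask = torch.zeros(bsz, 1, q_len, kv_len + 3)

causal_mask = attention_mask
if attention_mask is not None:
    # Slice to the real key length so the mask always matches key_states, group-shifted or not.
    causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

# SDPA can misbehave with non-contiguous CUDA inputs plus a custom mask, hence the guarded copies.
if query_states.device.type == "cuda" and causal_mask is not None:
    query_states = query_states.contiguous()
    key_states = key_states.contiguous()
    value_states = value_states.contiguous()

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=causal_mask
)
print(attn_output.shape)  # torch.Size([2, 8, 5, 16])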