hiyouga 2024-05-13 20:09:09 +08:00
parent c27afa296b
commit 93a0245474
1 changed file with 15 additions and 14 deletions


@@ -41,9 +41,9 @@ def llama_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
+    query_states: "torch.Tensor" = self.q_proj(hidden_states)
+    key_states: "torch.Tensor" = self.k_proj(hidden_states)
+    value_states: "torch.Tensor" = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
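
Note: the only change in this hunk (and in the matching hunks of llama_flash_attention_2_forward and llama_sdpa_attention_forward below) is the explicit "torch.Tensor" annotation on the projection outputs, which helps type checkers without changing runtime behavior. As a minimal, self-contained sketch of the projection-and-reshape flow in the context lines (all sizes here are illustrative assumptions, not values from the commit):

import torch
import torch.nn as nn

bsz, q_len, hidden_size = 2, 8, 32
num_heads, num_key_value_heads, head_dim = 4, 2, 8

q_proj = nn.Linear(hidden_size, num_heads * head_dim)            # 32 -> 32
k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim)  # 32 -> 16

hidden_states = torch.randn(bsz, q_len, hidden_size)
query_states: "torch.Tensor" = q_proj(hidden_states)  # annotation as in the patch
key_states: "torch.Tensor" = k_proj(hidden_states)

# (bsz, q_len, n_heads * head_dim) -> (bsz, n_heads, q_len, head_dim)
query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
print(query_states.shape, key_states.shape)
# torch.Size([2, 4, 8, 8]) torch.Size([2, 2, 8, 8])
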
@@ -87,7 +87,7 @@ def llama_attention_forward(
     # upcast attention to fp32
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-    attn_output = torch.matmul(attn_weights, value_states)  # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
+    attn_output = torch.matmul(attn_weights, value_states)  # (bsz, :, seq_len, :) or (bsz * n_group, :, groupsz, :)
     attn_output = attn_output.transpose(1, 2).contiguous()
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
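
Note: the shape comment refers to LongLoRA-style shift short attention: when group_size_ratio is set during training, the sequence is split into groups of groupsz tokens that are folded into the batch dimension, so attention outputs come out as (bsz * n_group, num_heads, groupsz, head_dim). A hedged sketch of that folding only (the ratio and sizes are illustrative assumptions; the actual patch additionally shifts half of the heads before grouping, which is not shown here):

import torch

bsz, num_heads, q_len, head_dim = 2, 4, 16, 8
group_size_ratio = 0.25
groupsz = int(q_len * group_size_ratio)  # 4
n_group = q_len // groupsz               # 4

states = torch.randn(bsz, num_heads, q_len, head_dim)
# Fold groups into the batch dim; attention is then computed per group.
grouped = (
    states.transpose(1, 2)                               # (bsz, q_len, n_heads, head_dim)
    .reshape(bsz * n_group, groupsz, num_heads, head_dim)
    .transpose(1, 2)                                     # (bsz * n_group, n_heads, groupsz, head_dim)
)
print(grouped.shape)  # torch.Size([8, 4, 4, 8])
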
@@ -125,9 +125,9 @@ def llama_flash_attention_2_forward(
 
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
+    query_states: "torch.Tensor" = self.q_proj(hidden_states)
+    key_states: "torch.Tensor" = self.k_proj(hidden_states)
+    value_states: "torch.Tensor" = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -233,9 +233,9 @@ def llama_sdpa_attention_forward(
 
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
+    query_states: "torch.Tensor" = self.q_proj(hidden_states)
+    key_states: "torch.Tensor" = self.k_proj(hidden_states)
+    value_states: "torch.Tensor" = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -270,11 +270,12 @@ def llama_sdpa_attention_forward(
 
     causal_mask = attention_mask
     if attention_mask is not None:
-        causal_mask = causal_mask[:, :, :, :groupsz]
+        causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
 
-    query_states = query_states.contiguous()
-    key_states = key_states.contiguous()
-    value_states = value_states.contiguous()
+    if query_states.device.type == "cuda" and causal_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
 
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query_states,
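
Note: this hunk carries the two behavioral fixes of the commit. First, the attention mask is sliced to the actual key length (key_states.shape[-2]) rather than to groupsz, so the mask stays aligned with the keys whether or not grouping is active. Second, the .contiguous() calls are now applied only on CUDA and only when a mask is passed, mirroring the known torch SDPA limitation that non-contiguous q/k/v with a custom attn_mask can fail on CUDA. A hedged, self-contained sketch of the fixed call path (shapes are illustrative assumptions):

import torch

bsz, num_heads, q_len, head_dim = 2, 4, 8, 8
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

# A (bsz, 1, q_len, kv_len_padded) additive mask, possibly wider than the
# actual key length; slicing to key_states.shape[-2] keeps it aligned.
attention_mask = torch.zeros(bsz, 1, q_len, q_len + 4)
causal_mask = attention_mask
if attention_mask is not None:
    causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

# torch SDPA on CUDA can fail with non-contiguous inputs when a custom
# mask is passed, hence the guarded .contiguous() calls in the patch.
if query_states.device.type == "cuda" and causal_mask is not None:
    query_states = query_states.contiguous()
    key_states = key_states.contiguous()
    value_states = value_states.contiguous()

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=causal_mask
)
print(attn_output.shape)  # torch.Size([2, 4, 8, 8])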