hiyouga 2024-06-11 01:04:16 +08:00
parent 91e62a098f
commit 3f24337a8a
1 changed file with 1 addition and 3 deletions


@@ -182,11 +182,9 @@ def llama_flash_attention_2_forward(
         query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
-    else:
-        groupsz = q_len
     attn_output: torch.Tensor = self._flash_attention_forward(
-        query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
+        query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
     )
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
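The change drops the else: groupsz = q_len fallback and passes query_states.size(1) directly as the query-length argument to self._flash_attention_forward: after the training-time shift the query tensor's second dimension is already the group size, and on the unshifted path it is simply q_len, so the bookkeeping variable is redundant. The sketch below illustrates that invariant; it assumes a LongLoRA-style S^2-Attn shift (half the heads rolled by half a group, then the sequence split into groups), and the standalone shift helper and toy shapes are illustrative, not the repository's code.

# Illustrative sketch (not the repository's code) of the grouped "shift" this hunk relies on,
# assuming a LongLoRA-style S^2-Attn pattern.
import torch

def shift(state: torch.Tensor, groupsz: int, num_groups: int) -> torch.Tensor:
    bsz, q_len, num_heads, head_dim = state.size()
    # Roll the second half of the heads by half a group so information crosses group borders.
    state = torch.cat(
        (state[:, :, : num_heads // 2], state[:, :, num_heads // 2 :].roll(-groupsz // 2, dims=1)),
        dim=2,
    )
    # (bsz, q_len, heads, dim) -> (bsz * num_groups, groupsz, heads, dim)
    return state.reshape(bsz * num_groups, groupsz, num_heads, head_dim)

bsz, q_len, num_heads, head_dim = 2, 16, 8, 4  # toy shapes for illustration
groupsz = int(q_len * 0.25)                    # e.g. group_size_ratio = 0.25
num_groups = q_len // groupsz

query_states = torch.randn(bsz, q_len, num_heads, head_dim)

# Training path: after the shift, size(1) is already the group size ...
assert shift(query_states, groupsz, num_groups).size(1) == groupsz
# ... and on the unshifted path, size(1) is simply q_len, so passing
# query_states.size(1) makes the old "else: groupsz = q_len" branch unnecessary.
assert query_states.size(1) == q_len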