tiny fix
This commit is contained in:
parent
91e62a098f
commit
3f24337a8a
|
@ -182,11 +182,9 @@ def llama_flash_attention_2_forward(
|
||||||
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
|
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
|
||||||
else:
|
|
||||||
groupsz = q_len
|
|
||||||
|
|
||||||
attn_output: torch.Tensor = self._flash_attention_forward(
|
attn_output: torch.Tensor = self._flash_attention_forward(
|
||||||
query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
|
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
|
||||||
)
|
)
|
||||||
|
|
||||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||||
|
|
Loading…
Reference in New Issue