diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py
index cd468979..4a8c562a 100644
--- a/src/llamafactory/model/model_utils/longlora.py
+++ b/src/llamafactory/model/model_utils/longlora.py
@@ -182,11 +182,9 @@ def llama_flash_attention_2_forward(
         query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
-    else:
-        groupsz = q_len
 
     attn_output: torch.Tensor = self._flash_attention_forward(
-        query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
+        query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
     )
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back