From 3f24337a8a995b145b1e8075bc23878eaa363844 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Tue, 11 Jun 2024 01:04:16 +0800
Subject: [PATCH] tiny fix

---
 src/llamafactory/model/model_utils/longlora.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py
index cd468979..4a8c562a 100644
--- a/src/llamafactory/model/model_utils/longlora.py
+++ b/src/llamafactory/model/model_utils/longlora.py
@@ -182,11 +182,9 @@ def llama_flash_attention_2_forward(
         query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
-    else:
-        groupsz = q_len
 
     attn_output: torch.Tensor = self._flash_attention_forward(
-        query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
+        query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
     )
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
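
For context on why this change works: after LongLoRA's group shift the query tensor is reshaped so that its second dimension equals the group size, while on the unshifted path it stays equal to the full sequence length, so passing query_states.size(1) as the query-length argument covers both branches and makes the removed "else: groupsz = q_len" fallback redundant. The following is a minimal shape sketch, not code from the repository, assuming the (bsz, q_len, num_heads, head_dim) layout used on the FlashAttention-2 path; all dimension values are illustrative.

import torch

# Illustrative dimensions (assumptions, not taken from the patch).
bsz, q_len, num_heads, head_dim = 2, 8, 4, 16
group_size_ratio = 0.25
groupsz = int(q_len * group_size_ratio)   # 2
num_groups = q_len // groupsz             # 4

# FlashAttention-2 layout: (batch, seq_len, num_heads, head_dim).
query_states = torch.randn(bsz, q_len, num_heads, head_dim)

def shift(state: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in for the LongLoRA shift: roll half of the heads by
    # half a group along the sequence axis, then fold the groups into the
    # batch dimension.
    state = torch.cat(
        (state[:, :, : num_heads // 2], state[:, :, num_heads // 2 :].roll(-groupsz // 2, dims=1)),
        dim=2,
    )
    return state.reshape(bsz * num_groups, groupsz, num_heads, head_dim)

shifted = shift(query_states)
assert shifted.size(1) == groupsz        # shifted (training) path: dim 1 is the group size
assert query_states.size(1) == q_len     # unshifted path: dim 1 is the full sequence length

The two assertions mirror the two cases that the deleted else branch used to distinguish explicitly.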