From 8b1172b91085125a83a4150943873141c8bbd8bc Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Wed, 3 Jul 2024 02:31:50 +0800
Subject: [PATCH] tiny fix

---
 .../model/model_utils/longlora.py            | 23 +++++++++++--------
 src/llamafactory/model/model_utils/visual.py |  4 +++-
 src/llamafactory/train/ppo/workflow.py       |  2 +-
 3 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py
index af30bd50..5d52c475 100644
--- a/src/llamafactory/model/model_utils/longlora.py
+++ b/src/llamafactory/model/model_utils/longlora.py
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments
 
 
-logger = logging.get_logger(__name__)
+transformers_logger = logging.get_logger(__name__)
 
 
 # Modified from:
@@ -85,7 +85,7 @@ def llama_attention_forward(
         assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
         num_groups = q_len // groupsz
 
-        def shift(state: torch.Tensor) -> torch.Tensor:
+        def shift(state: "torch.Tensor") -> "torch.Tensor":
             state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
             state = torch.cat(
                 (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
@@ -181,7 +181,7 @@ def llama_flash_attention_2_forward(
         else:
             target_dtype = self.q_proj.weight.dtype
 
-        logger.warning_once("The input hidden states seems to be silently casted in float32.")
+        transformers_logger.warning_once("The input hidden states seems to be silently casted in float32.")
         query_states = query_states.to(target_dtype)
         key_states = key_states.to(target_dtype)
         value_states = value_states.to(target_dtype)
@@ -191,7 +191,7 @@ def llama_flash_attention_2_forward(
         assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
         num_groups = q_len // groupsz
 
-        def shift(state: torch.Tensor) -> torch.Tensor:
+        def shift(state: "torch.Tensor") -> "torch.Tensor":
             state = torch.cat(
                 (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
                 dim=2,
@@ -202,7 +202,7 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
 
-    attn_output: torch.Tensor = self._flash_attention_forward(
+    attn_output: "torch.Tensor" = self._flash_attention_forward(
         query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
     )
 
@@ -238,7 +238,9 @@ def llama_sdpa_attention_forward(
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if output_attentions:
-        logger.warning_once("SDPA does not support `output_attentions=True`. Falling back to the vanilla attention")
+        transformers_logger.warning_once(
+            "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
+        )
         return llama_attention_forward(
             self,
             hidden_states=hidden_states,
@@ -275,7 +277,7 @@ def llama_sdpa_attention_forward(
         assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
         num_groups = q_len // groupsz
 
-        def shift(state: torch.Tensor) -> torch.Tensor:
+        def shift(state: "torch.Tensor") -> "torch.Tensor":
             state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
             state = torch.cat(
                 (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
@@ -291,18 +293,19 @@ def llama_sdpa_attention_forward(
     if attention_mask is not None:
         causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
 
-    if query_states.device.type == "cuda" and causal_mask is not None:
+    if query_states.device.type == "cuda" and causal_mask is not None:  # avoid pytorch bug
         query_states = query_states.contiguous()
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
+    is_causal = True if causal_mask is None and q_len > 1 else False
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query_states,
         key_states,
         value_states,
         attn_mask=causal_mask,
         dropout_p=self.attention_dropout if self.training else 0.0,
-        is_causal=causal_mask is None and q_len > 1,
+        is_causal=is_causal,
     )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
@@ -323,7 +326,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
+    require_version("transformers>=4.41.2,<=4.42.3", "To fix: pip install transformers>=4.41.2,<=4.42.3")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 700bf470..828a5e6d 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Tuple
 import torch
 import transformers.models
 from transformers.activations import ACT2FN
+from transformers.utils import logging
 
 from ...extras.logging import get_logger
 
@@ -31,6 +32,7 @@ if TYPE_CHECKING:
 
 
 logger = get_logger(__name__)
+transformers_logger = logging.get_logger(__name__)
 
 
 class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
@@ -61,7 +63,7 @@ class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
             else:
                 target_dtype = self.linear_1.weight.dtype
 
-            logger.warning_once("The hidden states seems to be silently casted in float32.")
+            transformers_logger.warning_once("The hidden states seems to be silently casted in float32.")
             hidden_states = hidden_states.to(target_dtype)
 
         return hidden_states
diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py
index df22dae5..f52b80d6 100644
--- a/src/llamafactory/train/ppo/workflow.py
+++ b/src/llamafactory/train/ppo/workflow.py
@@ -54,7 +54,7 @@ def run_ppo(
     reward_model = create_reward_model(model, model_args, finetuning_args)
 
     # Initialize our Trainer
-    ppo_trainer = CustomPPOTrainer(
+    ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer(
         model_args=model_args,
         training_args=training_args,
         finetuning_args=finetuning_args,