tiny fix
parent 71cdf8956e
commit 8b1172b910
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = logging.get_logger(__name__)
+transformers_logger = logging.get_logger(__name__)


 # Modified from:
@@ -85,7 +85,7 @@ def llama_attention_forward(
         assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
         num_groups = q_len // groupsz

-        def shift(state: torch.Tensor) -> torch.Tensor:
+        def shift(state: "torch.Tensor") -> "torch.Tensor":
             state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
             state = torch.cat(
                 (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
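Side note on the shift() helper touched above (and in the two similar hunks below): it appears to implement a LongLoRA-style group shift, rolling half of the attention heads back by half a group along the sequence axis so that neighbouring groups of size groupsz can exchange information. A standalone toy sketch, with shapes and values chosen purely for illustration and not taken from the patch:

import torch

# toy dimensions for illustration only
bsz, seq_len, n_heads, head_dim = 1, 8, 4, 2
groupsz = 4

# each position's value equals its sequence index, so the shift is easy to see
state = torch.arange(seq_len).view(1, seq_len, 1, 1).expand(bsz, seq_len, n_heads, head_dim).clone()

# first half of the heads is untouched; second half is rolled back by half a group
shifted = torch.cat(
    (state[:, :, : n_heads // 2], state[:, :, n_heads // 2 :].roll(-groupsz // 2, dims=1)),
    dim=2,
)

print(shifted[0, :, 0, 0].tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]  (unshifted head)
print(shifted[0, :, 3, 0].tolist())  # [2, 3, 4, 5, 6, 7, 0, 1]  (shifted head)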
@@ -181,7 +181,7 @@ def llama_flash_attention_2_forward(
         else:
             target_dtype = self.q_proj.weight.dtype

-        logger.warning_once("The input hidden states seems to be silently casted in float32.")
+        transformers_logger.warning_once("The input hidden states seems to be silently casted in float32.")
         query_states = query_states.to(target_dtype)
         key_states = key_states.to(target_dtype)
         value_states = value_states.to(target_dtype)
@@ -191,7 +191,7 @@ def llama_flash_attention_2_forward(
         assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
         num_groups = q_len // groupsz

-        def shift(state: torch.Tensor) -> torch.Tensor:
+        def shift(state: "torch.Tensor") -> "torch.Tensor":
             state = torch.cat(
                 (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
                 dim=2,
@@ -202,7 +202,7 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)

-    attn_output: torch.Tensor = self._flash_attention_forward(
+    attn_output: "torch.Tensor" = self._flash_attention_forward(
         query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
     )

@@ -238,7 +238,9 @@ def llama_sdpa_attention_forward(
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if output_attentions:
-        logger.warning_once("SDPA does not support `output_attentions=True`. Falling back to the vanilla attention")
+        transformers_logger.warning_once(
+            "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
+        )
         return llama_attention_forward(
             self,
             hidden_states=hidden_states,
@@ -275,7 +277,7 @@ def llama_sdpa_attention_forward(
         assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
         num_groups = q_len // groupsz

-        def shift(state: torch.Tensor) -> torch.Tensor:
+        def shift(state: "torch.Tensor") -> "torch.Tensor":
             state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
             state = torch.cat(
                 (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
@@ -291,18 +293,19 @@ def llama_sdpa_attention_forward(
     if attention_mask is not None:
         causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

-    if query_states.device.type == "cuda" and causal_mask is not None:
+    if query_states.device.type == "cuda" and causal_mask is not None:  # avoid pytorch bug
         query_states = query_states.contiguous()
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()

+    is_causal = True if causal_mask is None and q_len > 1 else False
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query_states,
         key_states,
         value_states,
         attn_mask=causal_mask,
         dropout_p=self.attention_dropout if self.training else 0.0,
-        is_causal=causal_mask is None and q_len > 1,
+        is_causal=is_causal,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()

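For context on the new is_causal local above: torch.nn.functional.scaled_dot_product_attention takes either an explicit attn_mask or is_causal=True (the two are documented as mutually exclusive), which is presumably why the flag is only enabled when no mask tensor is present. A minimal standalone sketch of that call pattern, with random tensors standing in for the real query/key/value states:

import torch
import torch.nn.functional as F

# (bsz, n_heads, q_len, head_dim), values are placeholders
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

causal_mask = None  # no explicit mask in this toy case
q_len = q.size(2)
is_causal = True if causal_mask is None and q_len > 1 else False

# causal masking is applied by the kernel itself when is_causal=True
out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=is_causal)
print(out.shape)  # torch.Size([1, 4, 8, 16])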
@@ -323,7 +326,7 @@ def llama_sdpa_attention_forward(


 def _apply_llama_patch() -> None:
-    require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
+    require_version("transformers>=4.41.2,<=4.42.3", "To fix: pip install transformers>=4.41.2,<=4.42.3")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
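The loosened pin above is easier to read with the behaviour of the version guard in mind. Assuming require_version here is transformers.utils.versions.require_version (the usual source of this helper, though its import is not shown in the diff), it checks each comma-separated specifier against the installed package and raises ImportError with the hint when one fails. A minimal sketch:

from transformers.utils.versions import require_version

try:
    # passes silently when the installed transformers is inside the range
    require_version("transformers>=4.41.2,<=4.42.3", "To fix: pip install transformers>=4.41.2,<=4.42.3")
except ImportError as err:
    # raised (with the hint appended) when the installed version is outside the range
    print(err)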
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Tuple
 import torch
 import transformers.models
 from transformers.activations import ACT2FN
+from transformers.utils import logging

 from ...extras.logging import get_logger

@@ -31,6 +32,7 @@ if TYPE_CHECKING:


 logger = get_logger(__name__)
+transformers_logger = logging.get_logger(__name__)


 class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
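On keeping two loggers in this file: the project logger from ...extras.logging handles ordinary messages, while the added transformers_logger is presumably there for warning_once, which transformers' logging utilities provide to deduplicate a repeated warning inside forward(). A standalone sketch of that behaviour:

from transformers.utils import logging

transformers_logger = logging.get_logger(__name__)

for _ in range(3):
    # identical messages are emitted once, not three times
    transformers_logger.warning_once("The hidden states seems to be silently casted in float32.")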
@@ -61,7 +63,7 @@ class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
             else:
                 target_dtype = self.linear_1.weight.dtype

-            logger.warning_once("The hidden states seems to be silently casted in float32.")
+            transformers_logger.warning_once("The hidden states seems to be silently casted in float32.")
             hidden_states = hidden_states.to(target_dtype)

         return hidden_states

@@ -54,7 +54,7 @@ def run_ppo(
     reward_model = create_reward_model(model, model_args, finetuning_args)

     # Initialize our Trainer
-    ppo_trainer = CustomPPOTrainer(
+    ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer(
         model_args=model_args,
         training_args=training_args,
         finetuning_args=finetuning_args,
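The quoted annotation on ppo_trainer matches the string (forward-reference) annotations introduced elsewhere in this commit; the usual motivation is that the annotated type may only be imported under TYPE_CHECKING, so the name is not guaranteed to exist at runtime. A minimal stdlib sketch of the pattern, with Decimal standing in for the project's own types such as CustomPPOTrainer:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # visible to type checkers only; skipped at runtime
    from decimal import Decimal


def build(value: str) -> "Decimal":
    # runtime import keeps the function usable even though the top-level import is skipped
    from decimal import Decimal

    return Decimal(value)


result: "Decimal" = build("3.14")  # string annotation, never evaluated at runtime
print(result)  # 3.14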