hiyouga 2024-07-03 02:31:50 +08:00
parent 71cdf8956e
commit 8b1172b910
3 changed files with 17 additions and 12 deletions

View File

@@ -43,7 +43,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = logging.get_logger(__name__)
+transformers_logger = logging.get_logger(__name__)


 # Modified from:
@@ -85,7 +85,7 @@ def llama_attention_forward(
         assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
         num_groups = q_len // groupsz

-        def shift(state: torch.Tensor) -> torch.Tensor:
+        def shift(state: "torch.Tensor") -> "torch.Tensor":
             state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
             state = torch.cat(
                 (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
@@ -181,7 +181,7 @@ def llama_flash_attention_2_forward(
         else:
             target_dtype = self.q_proj.weight.dtype

-        logger.warning_once("The input hidden states seems to be silently casted in float32.")
+        transformers_logger.warning_once("The input hidden states seems to be silently casted in float32.")
         query_states = query_states.to(target_dtype)
         key_states = key_states.to(target_dtype)
         value_states = value_states.to(target_dtype)
@@ -191,7 +191,7 @@ def llama_flash_attention_2_forward(
         assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
         num_groups = q_len // groupsz

-        def shift(state: torch.Tensor) -> torch.Tensor:
+        def shift(state: "torch.Tensor") -> "torch.Tensor":
             state = torch.cat(
                 (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
                 dim=2,
@@ -202,7 +202,7 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)

-    attn_output: torch.Tensor = self._flash_attention_forward(
+    attn_output: "torch.Tensor" = self._flash_attention_forward(
         query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
     )
@@ -238,7 +238,9 @@ def llama_sdpa_attention_forward(
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if output_attentions:
-        logger.warning_once("SDPA does not support `output_attentions=True`. Falling back to the vanilla attention")
+        transformers_logger.warning_once(
+            "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
+        )
         return llama_attention_forward(
             self,
             hidden_states=hidden_states,
@@ -275,7 +277,7 @@ def llama_sdpa_attention_forward(
         assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
         num_groups = q_len // groupsz

-        def shift(state: torch.Tensor) -> torch.Tensor:
+        def shift(state: "torch.Tensor") -> "torch.Tensor":
             state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
             state = torch.cat(
                 (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
@@ -291,18 +293,19 @@ def llama_sdpa_attention_forward(
     if attention_mask is not None:
         causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

-    if query_states.device.type == "cuda" and causal_mask is not None:
+    if query_states.device.type == "cuda" and causal_mask is not None:  # avoid pytorch bug
         query_states = query_states.contiguous()
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()

+    is_causal = True if causal_mask is None and q_len > 1 else False
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query_states,
         key_states,
         value_states,
         attn_mask=causal_mask,
         dropout_p=self.attention_dropout if self.training else 0.0,
-        is_causal=causal_mask is None and q_len > 1,
+        is_causal=is_causal,
     )

     attn_output = attn_output.transpose(1, 2).contiguous()
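Aside: the hunk above lifts the causal flag out of the call and passes it to scaled_dot_product_attention as a plain Python bool. A minimal standalone sketch of that call pattern follows; the shapes and variable names are illustrative and not taken from the patched module.

import torch
import torch.nn.functional as F

# Toy shapes: (batch, n_heads, seq_len, head_dim); purely illustrative.
bsz, n_heads, q_len, head_dim = 1, 4, 8, 16
query_states = torch.randn(bsz, n_heads, q_len, head_dim)
key_states = torch.randn(bsz, n_heads, q_len, head_dim)
value_states = torch.randn(bsz, n_heads, q_len, head_dim)
causal_mask = None  # no explicit attention mask in this toy case

# Compute the flag up front as a plain bool, mirroring the change above.
is_causal = True if causal_mask is None and q_len > 1 else False

attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=causal_mask,
    dropout_p=0.0,
    is_causal=is_causal,
)
print(attn_output.shape)  # torch.Size([1, 4, 8, 16])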
@@ -323,7 +326,7 @@ def llama_sdpa_attention_forward(


 def _apply_llama_patch() -> None:
-    require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
+    require_version("transformers>=4.41.2,<=4.42.3", "To fix: pip install transformers>=4.41.2,<=4.42.3")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
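For context, the shift helper that recurs in the hunks above is the group-shift trick from LongLoRA-style shifted sparse attention: the second half of the attention heads is rolled by half a group along the sequence axis so that tokens near a group boundary can still attend across it. A toy sketch of just that tensor operation, with made-up shapes, a hypothetical group size, and the (bsz, seq_len, n_heads, head_dim) layout the code uses after its transpose:

import torch

bsz, seq_len, n_heads, head_dim = 1, 8, 4, 2  # made-up shapes
groupsz = 4  # hypothetical group size

state = torch.randn(bsz, seq_len, n_heads, head_dim)

# Keep the first half of the heads as-is and roll the second half by half a
# group along the sequence axis, as in the shift helper shown above.
shifted = torch.cat(
    (state[:, :, : n_heads // 2], state[:, :, n_heads // 2 :].roll(-groupsz // 2, dims=1)),
    dim=2,
)
print(shifted.shape)  # torch.Size([1, 8, 4, 2])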

View File

@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Tuple
 import torch
 import transformers.models
 from transformers.activations import ACT2FN
+from transformers.utils import logging

 from ...extras.logging import get_logger
@@ -31,6 +32,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
+transformers_logger = logging.get_logger(__name__)


 class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
@@ -61,7 +63,7 @@ class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
             else:
                 target_dtype = self.linear_1.weight.dtype

-            logger.warning_once("The hidden states seems to be silently casted in float32.")
+            transformers_logger.warning_once("The hidden states seems to be silently casted in float32.")
             hidden_states = hidden_states.to(target_dtype)

         return hidden_states
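This file now keeps two loggers side by side: the project's get_logger for ordinary messages and a transformers logger whose warning_once deduplicates repeated warnings, which is presumably why the call sites were renamed to transformers_logger. A minimal sketch of that pattern, with the message copied from the diff:

from transformers.utils import logging

transformers_logger = logging.get_logger(__name__)

for _ in range(3):
    # warning_once emits a given message only once per process, so repeated
    # forward passes do not flood the log.
    transformers_logger.warning_once("The hidden states seems to be silently casted in float32.")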

View File

@@ -54,7 +54,7 @@ def run_ppo(
     reward_model = create_reward_model(model, model_args, finetuning_args)

     # Initialize our Trainer
-    ppo_trainer = CustomPPOTrainer(
+    ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer(
         model_args=model_args,
         training_args=training_args,
         finetuning_args=finetuning_args,
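The annotation tweaks across the three files ("torch.Tensor", "CustomPPOTrainer") all quote the type name, so the annotation is a string that type checkers resolve but the interpreter never evaluates; that keeps modules importable even when the annotated type is only brought in under TYPE_CHECKING, and otherwise just enforces a consistent style. A small sketch of the convention, using a hypothetical function rather than the project's code:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for static type checking only; not executed at runtime.
    import torch


def shift(state: "torch.Tensor") -> "torch.Tensor":
    # The quoted annotation is never evaluated, so torch is not required at
    # import time for this module to load.
    return state.roll(-1, dims=1)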