forked from p04798526/LLaMA-Factory-Mirror

tiny fix

commit 8b1172b910
parent 71cdf8956e
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = logging.get_logger(__name__)
+transformers_logger = logging.get_logger(__name__)


 # Modified from:
@@ -85,7 +85,7 @@ def llama_attention_forward(
         assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
         num_groups = q_len // groupsz

-        def shift(state: torch.Tensor) -> torch.Tensor:
+        def shift(state: "torch.Tensor") -> "torch.Tensor":
             state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
             state = torch.cat(
                 (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
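The `shift` helper above is the core of the shifted group attention used for long-context training: the second half of the attention heads is rolled by half a group along the sequence axis, so that once the batch is folded into groups, neighboring groups still exchange information. A minimal, self-contained sketch of the same operation (names are illustrative, not part of the patch):

import torch

def shift_half_heads(state: "torch.Tensor", groupsz: int, n_heads: int) -> "torch.Tensor":
    # state: (bsz, seq_len, n_heads, head_dim); roll the second half of the heads
    # by half a group along the sequence axis so adjacent groups overlap.
    return torch.cat(
        (state[:, :, : n_heads // 2], state[:, :, n_heads // 2 :].roll(-groupsz // 2, dims=1)),
        dim=2,
    )

bsz, seq_len, n_heads, head_dim = 1, 8, 4, 2
state = torch.randn(bsz, seq_len, n_heads, head_dim)
print(shift_half_heads(state, groupsz=4, n_heads=n_heads).shape)  # torch.Size([1, 8, 4, 2])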
@@ -181,7 +181,7 @@ def llama_flash_attention_2_forward(
         else:
             target_dtype = self.q_proj.weight.dtype

-        logger.warning_once("The input hidden states seems to be silently casted in float32.")
+        transformers_logger.warning_once("The input hidden states seems to be silently casted in float32.")
         query_states = query_states.to(target_dtype)
         key_states = key_states.to(target_dtype)
         value_states = value_states.to(target_dtype)
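For context, the cast above exists because layers kept in float32 (e.g. layer norms) can silently upcast half-precision hidden states, while the flash-attention kernel only accepts fp16/bf16 inputs. A rough sketch of the same guard (illustrative names, not the patched function itself):

import torch

def cast_for_flash_attn(states: "torch.Tensor", proj_weight: "torch.Tensor") -> "torch.Tensor":
    # If the activations were upcast to fp32, fall back to the projection weight's dtype.
    if states.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        else:
            target_dtype = proj_weight.dtype  # mirrors self.q_proj.weight.dtype above
        states = states.to(target_dtype)
    return states

weight = torch.randn(4, 4, dtype=torch.float16)
print(cast_for_flash_attn(torch.randn(2, 4), weight).dtype)  # torch.float16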
@@ -191,7 +191,7 @@ def llama_flash_attention_2_forward(
         assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
         num_groups = q_len // groupsz

-        def shift(state: torch.Tensor) -> torch.Tensor:
+        def shift(state: "torch.Tensor") -> "torch.Tensor":
             state = torch.cat(
                 (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
                 dim=2,
@@ -202,7 +202,7 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)

-    attn_output: torch.Tensor = self._flash_attention_forward(
+    attn_output: "torch.Tensor" = self._flash_attention_forward(
         query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
     )

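A note on the mask handling above: once query/key/value are folded from (bsz, q_len, ...) into (bsz * num_groups, groupsz, ...), the padding mask has to follow the same batch layout, which is what the slice-and-repeat achieves. A small illustration with toy shapes (assumed values, for demonstration only):

import torch

bsz, q_len, groupsz = 2, 8, 4
num_groups = q_len // groupsz
attention_mask = torch.ones(bsz, q_len, dtype=torch.long)  # (bsz, q_len)

# Matches the regrouped batch layout: (bsz * num_groups, groupsz)
grouped_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
print(grouped_mask.shape)  # torch.Size([4, 4])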
@@ -238,7 +238,9 @@ def llama_sdpa_attention_forward(
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if output_attentions:
-        logger.warning_once("SDPA does not support `output_attentions=True`. Falling back to the vanilla attention")
+        transformers_logger.warning_once(
+            "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
+        )
         return llama_attention_forward(
             self,
             hidden_states=hidden_states,
@@ -275,7 +277,7 @@ def llama_sdpa_attention_forward(
         assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
         num_groups = q_len // groupsz

-        def shift(state: torch.Tensor) -> torch.Tensor:
+        def shift(state: "torch.Tensor") -> "torch.Tensor":
             state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
             state = torch.cat(
                 (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
@@ -291,18 +293,19 @@ def llama_sdpa_attention_forward(
     if attention_mask is not None:
         causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

-    if query_states.device.type == "cuda" and causal_mask is not None:
+    if query_states.device.type == "cuda" and causal_mask is not None:  # avoid pytorch bug
         query_states = query_states.contiguous()
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()

+    is_causal = True if causal_mask is None and q_len > 1 else False
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query_states,
         key_states,
         value_states,
         attn_mask=causal_mask,
         dropout_p=self.attention_dropout if self.training else 0.0,
-        is_causal=causal_mask is None and q_len > 1,
+        is_causal=is_causal,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()

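The change above precomputes the causal flag as a plain bool before the call, matching how upstream transformers computes it: with no explicit mask and more than one query position, the SDPA kernel applies its own causal masking. A standalone sketch of the call, assuming PyTorch >= 2.0 (toy shapes, illustrative only):

import torch
import torch.nn.functional as F

bsz, n_heads, q_len, head_dim = 1, 4, 8, 16
query_states = torch.randn(bsz, n_heads, q_len, head_dim)
key_states = torch.randn(bsz, n_heads, q_len, head_dim)
value_states = torch.randn(bsz, n_heads, q_len, head_dim)

causal_mask = None  # no explicit mask in this toy example
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=causal_mask,
    dropout_p=0.0,
    is_causal=is_causal,  # let the kernel build the causal mask itself
)
print(attn_output.shape)  # torch.Size([1, 4, 8, 16])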
@@ -323,7 +326,7 @@ def llama_sdpa_attention_forward(


 def _apply_llama_patch() -> None:
-    require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
+    require_version("transformers>=4.41.2,<=4.42.3", "To fix: pip install transformers>=4.41.2,<=4.42.3")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
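`_apply_llama_patch` works by monkey patching: assigning new functions to the `forward` attribute of the attention classes so every instance picks up the replacement, with `require_version` guarding the range of transformers releases the patch was written against. A minimal sketch of the pattern on a hypothetical class (the `DummyAttention` name is made up for illustration):

from transformers.utils.versions import require_version

# Fail fast if the installed transformers falls outside the tested range.
require_version("transformers>=4.41.2,<=4.42.3", "To fix: pip install transformers>=4.41.2,<=4.42.3")

class DummyAttention:  # hypothetical stand-in for LlamaAttention
    def forward(self, x):
        return x

def patched_forward(self, x):
    return 2 * x  # replacement behavior, applied to all existing and future instances

DummyAttention.forward = patched_forward
print(DummyAttention().forward(3))  # 6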
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Tuple
 import torch
 import transformers.models
 from transformers.activations import ACT2FN
+from transformers.utils import logging

 from ...extras.logging import get_logger

@@ -31,6 +32,7 @@ if TYPE_CHECKING:


 logger = get_logger(__name__)
+transformers_logger = logging.get_logger(__name__)


 class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
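The module now carries two loggers: `logger` from the project's own `get_logger` for regular messages, and `transformers_logger` from `transformers.utils.logging`, whose `warning_once` deduplicates a warning that would otherwise fire on every forward pass. A short sketch of the deduplication (assumes transformers is installed):

from transformers.utils import logging

transformers_logger = logging.get_logger(__name__)

for _ in range(3):
    # Emitted only once per process: warning_once caches the message text.
    transformers_logger.warning_once("The hidden states seems to be silently casted in float32.")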
@@ -61,7 +63,7 @@ class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
             else:
                 target_dtype = self.linear_1.weight.dtype

-            logger.warning_once("The hidden states seems to be silently casted in float32.")
+            transformers_logger.warning_once("The hidden states seems to be silently casted in float32.")
             hidden_states = hidden_states.to(target_dtype)

         return hidden_states
@@ -54,7 +54,7 @@ def run_ppo(
     reward_model = create_reward_model(model, model_args, finetuning_args)

     # Initialize our Trainer
-    ppo_trainer = CustomPPOTrainer(
+    ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer(
         model_args=model_args,
         training_args=training_args,
         finetuning_args=finetuning_args,