@@ -37,281 +37,102 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Tuple

import torch
import torch.nn.functional as F
import importlib
import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.utils import is_torch_bf16_gpu_available
import transformers.models

from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments, DataArguments
    from ...hparams import ModelArguments


logger = get_logger(__name__)


@torch.jit.script
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
    max_num = int(torch.max(attention_mask).item())
    batch_size, _ = attention_mask.shape
    counts = torch.zeros((batch_size, max_num), dtype=torch.int32)

    for i in range(1, max_num + 1):
        mask = attention_mask == i
        counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)

    result = counts.flatten()
    nonzero_indices = torch.nonzero(result).squeeze(-1)
    return result[nonzero_indices]


@torch.jit.script
def get_unpad_data(attention_mask: torch.Tensor):
    device = attention_mask.device
    seqlens_in_batch = get_max_seqlen_in_batch(attention_mask)
    indices = torch.nonzero(attention_mask.flatten()).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = (
        F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        .to(device=device)
        .detach()
    )
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
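
# Editor's illustrative sketch (not part of the original patch): a doctest-style check of
# the two helpers above on a packed attention mask (the same example reused later in this
# file for the rewritten helpers).
#
# >>> mask = torch.tensor([[1, 1, 2, 2, 2, 0], [1, 2, 2, 3, 3, 3]])
# >>> get_max_seqlen_in_batch(mask).tolist()
# [2, 3, 1, 2, 3]
# >>> indices, cu_seqlens, max_seqlen = get_unpad_data(mask)
# >>> indices.tolist()
# [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
# >>> cu_seqlens.tolist()
# [0, 2, 5, 6, 8, 11]
# >>> max_seqlen
# 3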


def mask_2d_to_4d(
    mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    This expansion handles packed sequences: positions that share the same attention mask integer value
    belong to the same sequence and may only attend to each other within that sequence.
    The mask is also made lower triangular to prevent attending to future positions.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    mask = mask.unsqueeze(1).unsqueeze(2)
    mask = mask.expand(bsz, 1, tgt_len, src_len)

    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
    binary_mask = torch.where(
        mask != 0,
        torch.tensor(1, device=mask.device).to(dtype),
        torch.tensor(0, device=mask.device).to(dtype),
    )

    # Create a block-diagonal mask.
    # We multiply by the binary mask so that 0's in the original mask are correctly excluded.
    zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask

    # Create a lower triangular mask of ones that will zero out the upper triangular part
    lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
        mask.device
    )

    # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
    masked_zero_one_mask = zero_one_mask * lower_triangular_ones

    return masked_zero_one_mask
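
# Editor's illustrative sketch (not part of the original patch): for a single packed
# row [1, 1, 2, 0], mask_2d_to_4d yields a lower-triangular, block-diagonal mask in
# which tokens attend only within their own sequence and padding attends to nothing:
#
#     mask_2d_to_4d(torch.tensor([[1, 1, 2, 0]]), dtype=torch.float32)[0, 0]
#     # -> [[1., 0., 0., 0.],
#     #     [1., 1., 0., 0.],
#     #     [0., 0., 1., 0.],
#     #     [0., 0., 0., 0.]]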


def patched_prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    *args,
):
    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
    return _prepare_4d_causal_attention_mask(
        mask_2d_to_4d(attention_mask, dtype=dtype),
        *args,
    )


def patched_prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask: Optional[torch.Tensor],
    *args,
):
    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
    return _prepare_4d_causal_attention_mask_for_sdpa(
        mask_2d_to_4d(attention_mask, dtype=dtype),
        *args,
    )


def set_module_name(model, name, value):
    if "." in name:
        parent_name = name.rsplit(".", 1)[0]
        child_name = name[len(parent_name) + 1 :]
        parent = model.get_submodule(parent_name)
    else:
        parent_name = ""
        parent = model
        child_name = name

    setattr(parent, child_name, value)
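
# Editor's note (illustrative, not part of the original patch): set_module_name swaps a
# submodule in place given its dotted path. The path and module below are hypothetical
# and only illustrate the intended usage:
#
#     set_module_name(model, "model.layers.0.self_attn", custom_attention_module)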


# Copied from the original implementation of modeling_mixtral.py in transformers;
# the only change is that new_attention_mask is used in place of attention_mask.
def load_balancing_loss_func(
    gate_logits: torch.Tensor,
    num_experts: torch.Tensor = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> float:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in PyTorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts].
        attention_mask (`torch.Tensor`, None):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.
        num_experts (`int`, *optional*):
            Number of experts

    Returns:
        The auxiliary loss.
    """
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    if isinstance(gate_logits, tuple):
        compute_device = gate_logits[0].device
        concatenated_gate_logits = torch.cat(
            [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
        )

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        # ONLY ADD THIS LINE OF CODE, AND REPLACE attention_mask WITH new_attention_mask
        new_attention_mask = (attention_mask != 0).int().to(attention_mask.device)
        batch_size, sequence_length = new_attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (
            batch_size * sequence_length
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            new_attention_mask[None, :, :, None, None]
            .expand(
                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
            )
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.sum(
            expert_mask.float() * expert_attention_mask, dim=0
        ) / torch.sum(expert_attention_mask, dim=0)

        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
        router_per_expert_attention_mask = (
            new_attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(
            routing_weights * router_per_expert_attention_mask, dim=0
        ) / torch.sum(router_per_expert_attention_mask, dim=0)

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts


def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
    r"""
    Gets the sequence lengths in the current batch.

    e.g.
    ```
    [
        [1, 1, 2, 2, 2, 0],
        [1, 2, 2, 3, 3, 3],
    ]
    ```
    ->
    ```
    [2, 3, 1, 2, 3]
    ```
    """
    bsz = attention_mask.size(0)
    dtype, device = attention_mask.dtype, attention_mask.device
    max_num = torch.max(attention_mask)
    counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
    for i in range(max_num):
        counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)

    counts = counts.flatten()
    seqlens = counts[counts.nonzero().squeeze()]
    return seqlens


def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
    r"""
    Prepares the indices and seqlens for flash attn varlen function.

    Returns:
        indices: indices of non-masked tokens from the flattened sequence.
        cu_seqlens: the cumulative sequence lengths in the current batch, always starts from 0.
        max_seqlen_in_batch: the largest seqlen in the current batch.

    e.g.
    ```
    [
        [1, 1, 2, 2, 2, 0],
        [1, 2, 2, 3, 3, 3],
    ]
    ```
    ->
    ```
    [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
    [0, 2, 5, 6, 8, 11]
    3
    ```
    """
    seqlens_in_batch = get_seqlens_in_batch(attention_mask)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch
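
# Editor's note (context, not part of the original patch): in the transformers versions
# targeted here, the FlashAttention-2 code path calls _get_unpad_data(attention_mask)
# while unpadding the batch and feeds the returned indices, cu_seqlens and
# max_seqlen_in_batch to the flash-attn varlen kernels, so overriding that single helper
# is enough to keep packed sequences attending only within their own segment.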


def patch_for_block_diag_attn(model_type, model_name, attn_implementation):
    if attn_implementation == "flash_attention_2":
        if model_type == "llama":
            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "mistral":
            transformers.models.mistral.modeling_mistral._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "mixtral":
            transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
            transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = (  # pylint: disable=protected-access
                load_balancing_loss_func
            )
        elif model_type == "qwen2":
            transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "qwen2_moe":
            transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
            transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = (  # pylint: disable=protected-access
                load_balancing_loss_func
            )
        elif model_type == "falcon":
            transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "phi":
            transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "phi3":
            transformers.models.phi3.modeling_phi3._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "gemma":
            transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "starcoder2":
            transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
    else:
        transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
            patched_prepare_4d_causal_attention_mask_for_sdpa
        )
        transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
            patched_prepare_4d_causal_attention_mask
        )


def patch_for_block_diag_attn(model_type: str) -> None:
    if model_type == "falcon":
        transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
    elif model_type == "gemma":
        transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
    elif model_type == "gemma2":
        transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
    elif model_type == "llama":
        transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
    elif model_type == "mistral":
        transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data
    elif model_type == "phi":
        transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data
    elif model_type == "phi3":
        transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
    elif model_type == "qwen2":
        transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
    elif model_type == "starcoder2":
        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data


def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
@@ -319,9 +140,8 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
        return

    model_type = getattr(config, "model_type", None)

    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
    if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
        patch_for_block_diag_attn(model_type)
        logger.info("Using packing sequences without cross-contamination attention for efficient training.")
        logger.info("Using block diagonal attention for sequence packing without cross-attention.")
    else:
        raise ValueError("Current model does not support packing sequences for efficient training.")
        raise ValueError("Current model does not support block diagonal attention.")
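
# Editor's sketch (hypothetical call site, not part of this patch): configure_packing is
# intended to run once while the model is loaded for training, for example:
#
#     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
#     configure_packing(config, model_args, is_trainable=True)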