update packing

hiyouga 2024-07-04 01:10:55 +08:00
parent a36e8f2dd5
commit cce7083024
6 changed files with 133 additions and 271 deletions

View File

@@ -29,11 +29,12 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
e.g.
```
[1, 1, 2, 2, 2, 0]
[[1, 1, 2, 2, 2, 0]]
```
->
```
[[
[
[
[
[o, x, x, x, x, x],
[o, o, x, x, x, x],
@@ -42,7 +43,8 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
[x, x, o, o, o, x],
[x, x, o, x, x, x],
]
]]
]
]
```
where `o` equals `0.0` and `x` equals `min_dtype`.
"""

View File

@@ -82,14 +82,11 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
"falcon",
"gemma",
"gemma2",
"jamba",
"llama",
"mistral",
"mixtral",
"phi",
"phi3",
"qwen2",
"qwen2_moe",
"starcoder2",
}
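For a quick check that a given checkpoint's architecture falls in this set, a small sketch (the checkpoint id is illustrative; the import path assumes the set lives in `llamafactory.extras.constants`, matching the relative import used in `packing.py` below):

```python
from transformers import AutoConfig

from llamafactory.extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN

# model_type comes from the checkpoint's config.json, e.g. "qwen2" here.
config = AutoConfig.from_pretrained("Qwen/Qwen2-0.5B")
print(config.model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN)  # True
```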

View File

@@ -83,7 +83,7 @@ class DataArguments:
)
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={"help": "Whether or not to ignore the tokens corresponding to the pad tokens in loss computation."},
metadata={"help": "Whether or not to ignore the tokens corresponding to the pad label in loss computation."},
)
val_size: float = field(
default=0.0,

View File

@@ -37,281 +37,102 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Tuple
import torch
import torch.nn.functional as F
import importlib
import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.utils import is_torch_bf16_gpu_available
import transformers.models
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments, DataArguments
from ...hparams import ModelArguments
logger = get_logger(__name__)
@torch.jit.script
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
max_num = int(torch.max(attention_mask).item())
batch_size, _ = attention_mask.shape
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
for i in range(1, max_num + 1):
mask = attention_mask == i
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
result = counts.flatten()
nonzero_indices = torch.nonzero(result).squeeze(-1)
return result[nonzero_indices]
@torch.jit.script
def get_unpad_data(attention_mask: torch.Tensor):
device = attention_mask.device
seqlens_in_batch = get_max_seqlen_in_batch(attention_mask)
indices = torch.nonzero(attention_mask.flatten()).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = (
F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
.to(device=device)
.detach()
)
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def mask_2d_to_4d(
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences: tokens attend to each other only if they share the same
attention-mask integer value, i.e. only within their own packed sequence.
It also makes the mask lower triangular, so tokens cannot attend to future positions.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1, device=mask.device).to(dtype),
torch.tensor(0, device=mask.device).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
return masked_zero_one_mask
def patched_prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def patched_prepare_4d_causal_attention_mask_for_sdpa(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def set_module_name(model, name, value):
if "." in name:
parent_name = name.rsplit(".", 1)[0]
child_name = name[len(parent_name) + 1 :]
parent = model.get_submodule(parent_name)
else:
parent_name = ""
parent = model
child_name = name
setattr(parent, child_name, value)
# Copied from the original implementation in transformers' modeling_mixtral.py; changed only slightly to use new_attention_mask
def load_balancing_loss_func(
gate_logits: torch.Tensor,
num_experts: torch.Tensor = None,
top_k=2,
attention_mask: Optional[torch.Tensor] = None,
) -> float:
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
r"""
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
Gets the sequence lengths in the current batch.
See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
experts is too unbalanced.
e.g.
```
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
]
```
->
```
[2, 3, 1, 2, 3]
```
"""
bsz = attention_mask.size(0)
dtype, device = attention_mask.dtype, attention_mask.device
max_num = torch.max(attention_mask)
counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
for i in range(max_num):
counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
Args:
gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
shape [batch_size X sequence_length, num_experts].
attention_mask (`torch.Tensor`, None):
The attention_mask used in forward function
shape [batch_size X sequence_length] if not None.
num_experts (`int`, *optional*):
Number of experts
counts = counts.flatten()
seqlens = counts[counts.nonzero().squeeze()]
return seqlens
def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
r"""
Prepares the indices and seqlens for the flash-attn varlen function.
Returns:
The auxiliary loss.
indices: indices of non-masked tokens from the flattened sequence.
cu_seqlens: the cumulative sequence lengths in the current batch, always starts from 0.
max_seqlen_in_batch: the largest seqlen in the current batch.
e.g.
```
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
]
```
->
```
[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
[0, 2, 5, 6, 8, 11]
3
```
"""
if gate_logits is None or not isinstance(gate_logits, tuple):
return 0
if isinstance(gate_logits, tuple):
compute_device = gate_logits[0].device
concatenated_gate_logits = torch.cat(
[layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
)
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
if attention_mask is None:
# Compute the percentage of tokens routed to each expert
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
# Compute the average probability of routing to these experts
router_prob_per_expert = torch.mean(routing_weights, dim=0)
else:
# The only change from upstream: add this line and use new_attention_mask in place of attention_mask below
new_attention_mask = (attention_mask != 0).int().to(attention_mask.device)
batch_size, sequence_length = new_attention_mask.shape
num_hidden_layers = concatenated_gate_logits.shape[0] // (
batch_size * sequence_length
)
# Compute the mask that masks all padding tokens as 0, with the same shape as expert_mask
expert_attention_mask = (
new_attention_mask[None, :, :, None, None]
.expand(
(num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
)
.reshape(-1, top_k, num_experts)
.to(compute_device)
)
# Compute the percentage of tokens routed to each expert
tokens_per_expert = torch.sum(
expert_mask.float() * expert_attention_mask, dim=0
) / torch.sum(expert_attention_mask, dim=0)
# Compute the mask that masks all padding tokens as 0, with the same shape as tokens_per_expert
router_per_expert_attention_mask = (
new_attention_mask[None, :, :, None]
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
.reshape(-1, num_experts)
.to(compute_device)
)
# Compute the average probability of routing to these experts
router_prob_per_expert = torch.sum(
routing_weights * router_per_expert_attention_mask, dim=0
) / torch.sum(router_per_expert_attention_mask, dim=0)
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
return overall_loss * num_experts
seqlens_in_batch = get_seqlens_in_batch(attention_mask)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return indices, cu_seqlens, max_seqlen_in_batch
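To show how these three values are typically consumed, a small illustrative sketch (tensor names are made up): the flattened hidden states are gathered with `indices`, and `cu_seqlens` marks the boundaries of each packed sub-sequence inside the unpadded tensor, which is the layout expected by varlen flash-attention kernels.

```python
import torch

from llamafactory.model.model_utils.packing import get_unpad_data

attention_mask = torch.tensor(
    [
        [1, 1, 2, 2, 2, 0],
        [1, 2, 2, 3, 3, 3],
    ]
)
hidden_states = torch.randn(2, 6, 4)  # [bsz, seq_len, hidden_dim]

indices, cu_seqlens, max_seqlen = get_unpad_data(attention_mask)

# Drop the padded position and flatten the batch into [total_tokens, hidden_dim].
unpadded = hidden_states.reshape(-1, hidden_states.size(-1)).index_select(0, indices)

# cu_seqlens marks sub-sequence boundaries; max_seqlen (= 3 here) is passed
# alongside it to the varlen attention kernel.
for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
    print(tuple(unpadded[start:end].shape))  # (2, 4), (3, 4), (1, 4), (2, 4), (3, 4)
```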
def patch_for_block_diag_attn(model_type, model_name, attn_implementation):
if attn_implementation == "flash_attention_2":
if model_type == "llama":
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "mistral":
transformers.models.mistral.modeling_mistral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = ( # pylint: disable=protected-access
load_balancing_loss_func
)
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2_moe":
transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = ( # pylint: disable=protected-access
load_balancing_loss_func
)
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "phi3":
transformers.models.phi3.modeling_phi3._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
def patch_for_block_diag_attn(model_type: str) -> None:
if model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
elif model_type == "gemma2":
transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
elif model_type == "llama":
transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
elif model_type == "mistral":
transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data
elif model_type == "phi3":
transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
else:
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data
def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
@@ -319,9 +140,8 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
return
model_type = getattr(config, "model_type", None)
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
patch_for_block_diag_attn(model_type)
logger.info("Using packing sequences without cross-contamination attention for efficient training.")
logger.info("Using block diagonal attention for sequence packing without cross-attention.")
else:
raise ValueError("Current model does not support packing sequences for efficient training.")
raise ValueError("Current model does not support block diagonal attention.")

View File

@@ -52,4 +52,5 @@ def test_4d_attention_mask():
],
dtype=torch.float16,
)
assert list(attention_mask_computed.size()) == [2, 1, 6, 6]
assert torch.all(attention_mask_computed == attention_mask_expected)

View File

@@ -0,0 +1,42 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data
def test_get_seqlens_in_batch():
attention_mask_with_indices = torch.tensor(
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
]
)
seqlens_in_batch = get_seqlens_in_batch(attention_mask_with_indices)
assert list(seqlens_in_batch.size()) == [5]
assert torch.all(seqlens_in_batch == torch.tensor([2, 3, 1, 2, 3]))
def test_get_unpad_data():
attention_mask_with_indices = torch.tensor(
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
]
)
indices, cu_seqlens, max_seqlen_in_batch = get_unpad_data(attention_mask_with_indices)
assert torch.all(indices == torch.tensor([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]))
assert torch.all(cu_seqlens == torch.tensor([0, 2, 5, 6, 8, 11], dtype=torch.int32))
assert max_seqlen_in_batch == 3