From cce7083024bed4c7429ddc8288d1c9190fde29f5 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Thu, 4 Jul 2024 01:10:55 +0800 Subject: [PATCH] update packing --- src/llamafactory/data/collator.py | 24 +- src/llamafactory/extras/constants.py | 3 - src/llamafactory/hparams/data_args.py | 2 +- src/llamafactory/model/model_utils/packing.py | 332 ++++-------------- tests/data/test_collator.py | 1 + tests/model/model_utils/test_packing.py | 42 +++ 6 files changed, 133 insertions(+), 271 deletions(-) create mode 100644 tests/model/model_utils/test_packing.py diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 0939925d..6d176313 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -29,20 +29,22 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype e.g. ``` - [1, 1, 2, 2, 2, 0] + [[1, 1, 2, 2, 2, 0]] ``` -> ``` - [[ - [ - [o, x, x, x, x, x], - [o, o, x, x, x, x], - [x, x, o, x, x, x], - [x, x, o, o, x, x], - [x, x, o, o, o, x], - [x, x, o, x, x, x], - ] - ]] + [ + [ + [ + [o, x, x, x, x, x], + [o, o, x, x, x, x], + [x, x, o, x, x, x], + [x, x, o, o, x, x], + [x, x, o, o, o, x], + [x, x, o, x, x, x], + ] + ] + ] ``` where `o` equals to `0.0`, `x` equals to `min_dtype`. """ diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index dc326d01..49aa4dba 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -82,14 +82,11 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = { "falcon", "gemma", "gemma2", - "jamba", "llama", "mistral", - "mixtral", "phi", "phi3", "qwen2", - "qwen2_moe", "starcoder2", } diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 38bbbb12..45c1079b 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -83,7 +83,7 @@ class DataArguments: ) ignore_pad_token_for_loss: bool = field( default=True, - metadata={"help": "Whether or not to ignore the tokens corresponding to the pad tokens in loss computation."}, + metadata={"help": "Whether or not to ignore the tokens corresponding to the pad label in loss computation."}, ) val_size: float = field( default=0.0, diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py index 5b21bf4e..ba614515 100644 --- a/src/llamafactory/model/model_utils/packing.py +++ b/src/llamafactory/model/model_utils/packing.py @@ -37,281 +37,102 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Tuple import torch import torch.nn.functional as F - -import importlib - -import transformers -from accelerate import init_empty_weights -from transformers import AutoConfig, AutoModelForCausalLM -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) -from transformers.utils import is_torch_bf16_gpu_available +import transformers.models from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN from ...extras.logging import get_logger + if TYPE_CHECKING: from transformers import PretrainedConfig - from ...hparams import ModelArguments, DataArguments + from ...hparams import ModelArguments logger = get_logger(__name__) -@torch.jit.script -def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor: - max_num = int(torch.max(attention_mask).item()) - batch_size, _ = attention_mask.shape - counts = torch.zeros((batch_size, max_num), dtype=torch.int32) - - for i in range(1, max_num + 1): - mask = attention_mask == i - counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32) - - result = counts.flatten() - nonzero_indices = torch.nonzero(result).squeeze(-1) - return result[nonzero_indices] - - -@torch.jit.script -def get_unpad_data(attention_mask: torch.Tensor): - device = attention_mask.device - seqlens_in_batch = get_max_seqlen_in_batch(attention_mask) - indices = torch.nonzero(attention_mask.flatten()).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = ( - F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - .to(device=device) - .detach() - ) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - -def mask_2d_to_4d( - mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None -): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - This expansion handles packed sequences so that sequences share the same attention mask integer value - when they attend to each other within that sequence. - This expansion transforms the mask to lower triangular form to prevent future peeking. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - mask = mask.unsqueeze(1).unsqueeze(2) - mask = mask.expand(bsz, 1, tgt_len, src_len) - - # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one - binary_mask = torch.where( - mask != 0, - torch.tensor(1, device=mask.device).to(dtype), - torch.tensor(0, device=mask.device).to(dtype), - ) - - # Create a block-diagonal mask. - # we multiply by the binary mask so that 0's in the original mask are correctly excluded - zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask - - # Now let's create a lower triangular mask of ones that will zero out the upper triangular part - lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to( - mask.device - ) - - # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask - masked_zero_one_mask = zero_one_mask * lower_triangular_ones - - return masked_zero_one_mask - - -def patched_prepare_4d_causal_attention_mask( - attention_mask: Optional[torch.Tensor], - *args, -): - dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 - return _prepare_4d_causal_attention_mask( - mask_2d_to_4d(attention_mask, dtype=dtype), - *args, - ) - - -def patched_prepare_4d_causal_attention_mask_for_sdpa( - attention_mask: Optional[torch.Tensor], - *args, -): - dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 - return _prepare_4d_causal_attention_mask_for_sdpa( - mask_2d_to_4d(attention_mask, dtype=dtype), - *args, - ) - - -def set_module_name(model, name, value): - if "." in name: - parent_name = name.rsplit(".", 1)[0] - child_name = name[len(parent_name) + 1 :] - parent = model.get_submodule(parent_name) - else: - parent_name = "" - parent = model - child_name = name - - setattr(parent, child_name, value) - - -# Copy from original implementation of modeling_mixtral.py from transformers, Just change a little bit with new_attention_mask -def load_balancing_loss_func( - gate_logits: torch.Tensor, - num_experts: torch.Tensor = None, - top_k=2, - attention_mask: Optional[torch.Tensor] = None, -) -> float: +def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor": r""" - Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + Gets the sequnce lengths in the current batch. - See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss - function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between - experts is too unbalanced. + e.g. + ``` + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + ``` + -> + ``` + [2, 3, 1, 2, 3] + ``` + """ + bsz = attention_mask.size(0) + dtype, device = attention_mask.dtype, attention_mask.device + max_num = torch.max(attention_mask) + counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device) + for i in range(max_num): + counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1) - Args: - gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): - Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of - shape [batch_size X sequence_length, num_experts]. - attention_mask (`torch.Tensor`, None): - The attention_mask used in forward function - shape [batch_size X sequence_length] if not None. - num_experts (`int`, *optional*): - Number of experts + counts = counts.flatten() + seqlens = counts[counts.nonzero().squeeze()] + return seqlens + + +def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]: + r""" + Prepares the indices and seqlens for flash attn varlen function. Returns: - The auxiliary loss. + indices: indices of non-masked tokens from the flattened sequence. + cu_seqlens: the cumulative sequence lengths in the current batch, always starts from 0. + max_seqlen_in_batch: the largest seqlen in the current batch. + + e.g. + ``` + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + ``` + -> + ``` + [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11] + [0, 2, 5, 6, 8, 11] + 3 + ``` """ - if gate_logits is None or not isinstance(gate_logits, tuple): - return 0 - - if isinstance(gate_logits, tuple): - compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat( - [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 - ) - - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) - - _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) - - expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) - - if attention_mask is None: - # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.mean(expert_mask.float(), dim=0) - - # Compute the average probability of routing to these experts - router_prob_per_expert = torch.mean(routing_weights, dim=0) - else: - # ONLY ADD THIS LINE OF CODE, AND REPLACE attention_mask WITH new_attention_mask - new_attention_mask = (attention_mask != 0).int().to(attention_mask.device) - batch_size, sequence_length = new_attention_mask.shape - num_hidden_layers = concatenated_gate_logits.shape[0] // ( - batch_size * sequence_length - ) - - # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask - expert_attention_mask = ( - new_attention_mask[None, :, :, None, None] - .expand( - (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) - ) - .reshape(-1, top_k, num_experts) - .to(compute_device) - ) - - # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum( - expert_mask.float() * expert_attention_mask, dim=0 - ) / torch.sum(expert_attention_mask, dim=0) - - # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert - router_per_expert_attention_mask = ( - new_attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) - .to(compute_device) - ) - - # Compute the average probability of routing to these experts - router_prob_per_expert = torch.sum( - routing_weights * router_per_expert_attention_mask, dim=0 - ) / torch.sum(router_per_expert_attention_mask, dim=0) - - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) - return overall_loss * num_experts + seqlens_in_batch = get_seqlens_in_batch(attention_mask) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return indices, cu_seqlens, max_seqlen_in_batch -def patch_for_block_diag_attn(model_type, model_name, attn_implementation): - if attn_implementation == "flash_attention_2": - if model_type == "llama": - transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "mistral": - transformers.models.mistral.modeling_mistral._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "mixtral": - transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = ( # pylint: disable=protected-access - load_balancing_loss_func - ) - elif model_type == "qwen2": - transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "qwen2_moe": - transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = ( # pylint: disable=protected-access - load_balancing_loss_func - ) - elif model_type == "falcon": - transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "phi": - transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "phi3": - transformers.models.phi3.modeling_phi3._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "gemma": - transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "starcoder2": - transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - else: - transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access - patched_prepare_4d_causal_attention_mask_for_sdpa - ) - transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access - patched_prepare_4d_causal_attention_mask - ) +def patch_for_block_diag_attn(model_type: str) -> None: + if model_type == "falcon": + transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data + elif model_type == "gemma": + transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data + elif model_type == "gemma2": + transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data + elif model_type == "llama": + transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data + elif model_type == "mistral": + transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data + elif model_type == "phi": + transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data + elif model_type == "phi3": + transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data + elif model_type == "qwen2": + transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data + elif model_type == "starcoder2": + transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: @@ -319,9 +140,8 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", return model_type = getattr(config, "model_type", None) - - if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN: + if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN: patch_for_block_diag_attn(model_type) - logger.info("Using packing sequences without cross-contamination attention for efficient training.") + logger.info("Using block diagonal attention for sequence packing without cross-attention.") else: - raise ValueError("Current model does not support packing sequences for efficient training.") + raise ValueError("Current model does not support block diagonal attention.") diff --git a/tests/data/test_collator.py b/tests/data/test_collator.py index cb473d4c..58035ac2 100644 --- a/tests/data/test_collator.py +++ b/tests/data/test_collator.py @@ -52,4 +52,5 @@ def test_4d_attention_mask(): ], dtype=torch.float16, ) + assert list(attention_mask_computed.size()) == [2, 1, 6, 6] assert torch.all(attention_mask_computed == attention_mask_expected) diff --git a/tests/model/model_utils/test_packing.py b/tests/model/model_utils/test_packing.py new file mode 100644 index 00000000..6fd9ba3b --- /dev/null +++ b/tests/model/model_utils/test_packing.py @@ -0,0 +1,42 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data + + +def test_get_seqlens_in_batch(): + attention_mask_with_indices = torch.tensor( + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + ) + seqlens_in_batch = get_seqlens_in_batch(attention_mask_with_indices) + assert list(seqlens_in_batch.size()) == [5] + assert torch.all(seqlens_in_batch == torch.tensor([2, 3, 1, 2, 3])) + + +def test_get_unpad_data(): + attention_mask_with_indices = torch.tensor( + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + ) + indices, cu_seqlens, max_seqlen_in_batch = get_unpad_data(attention_mask_with_indices) + assert torch.all(indices == torch.tensor([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11])) + assert torch.all(cu_seqlens == torch.tensor([0, 2, 5, 6, 8, 11], dtype=torch.int32)) + assert max_seqlen_in_batch == 3