diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index e4859ff5..6d176313 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -1,4 +1,7 @@ -# Copyright 2024 the LlamaFactory team. +# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team. +# +# This code is inspired by the OpenAccess AI Collective's axolotl library. +# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +22,46 @@ import torch from transformers import DataCollatorForSeq2Seq +def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor": + r""" + Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len), + while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking. + + e.g. + ``` + [[1, 1, 2, 2, 2, 0]] + ``` + -> + ``` + [ + [ + [ + [o, x, x, x, x, x], + [o, o, x, x, x, x], + [x, x, o, x, x, x], + [x, x, o, o, x, x], + [x, x, o, o, o, x], + [x, x, o, x, x, x], + ] + ] + ] + ``` + where `o` equals to `0.0`, `x` equals to `min_dtype`. + """ + bsz, seq_len = attention_mask_with_indices.size() + min_dtype = torch.finfo(dtype).min + expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len) + # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one + padding_mask = torch.where(expanded_mask != 0, 1, 0) + # Create a block-diagonal mask. + attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask + # Use the lower triangular mask to zero out the upper triangular part + attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long)) + # Invert the attention mask. + attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype) + return attention_mask_4d + + @dataclass class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): r""" diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index b283542d..747a0c1b 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -160,22 +160,30 @@ def preprocess_packed_supervised_dataset( model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} knapsacks = greedy_knapsack(lengths, data_args.cutoff_len) for knapsack in knapsacks: - packed_input_ids, packed_labels = [], [] - for length in knapsack: + packed_input_ids, packed_attention_masks, packed_labels = [], [], [] + for i, length in enumerate(knapsack): index = length2indexes[length].pop() packed_input_ids += batch_input_ids[index] packed_labels += batch_labels[index] + if data_args.neat_packing: + packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1 + else: + packed_attention_masks += [1] * len(batch_input_ids[index]) if len(packed_input_ids) < data_args.cutoff_len: pad_length = data_args.cutoff_len - len(packed_input_ids) packed_input_ids += [tokenizer.pad_token_id] * pad_length packed_labels += [IGNORE_INDEX] * pad_length + if data_args.neat_packing: + packed_attention_masks += [0] * pad_length + else: + packed_attention_masks += [1] * pad_length # more efficient flash_attn if len(packed_input_ids) != data_args.cutoff_len: raise ValueError("The length of packed example should be identical to the cutoff length.") model_inputs["input_ids"].append(packed_input_ids) - model_inputs["attention_mask"].append([1] * data_args.cutoff_len) + model_inputs["attention_mask"].append(packed_attention_masks) model_inputs["labels"].append(packed_labels) return model_inputs diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index d8892a96..47781791 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -78,6 +78,18 @@ TRAINING_STAGES = { STAGES_USE_PAIR_DATA = {"rm", "dpo"} +SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = { + "falcon", + "gemma", + "gemma2", + "llama", + "mistral", + "phi", + "phi3", + "qwen2", + "starcoder2", +} + SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} V_HEAD_WEIGHTS_NAME = "value_head.bin" diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 880be84a..45c1079b 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -83,9 +83,7 @@ class DataArguments: ) ignore_pad_token_for_loss: bool = field( default=True, - metadata={ - "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation." - }, + metadata={"help": "Whether or not to ignore the tokens corresponding to the pad label in loss computation."}, ) val_size: float = field( default=0.0, @@ -93,9 +91,11 @@ class DataArguments: ) packing: Optional[bool] = field( default=None, - metadata={ - "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training." - }, + metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."}, + ) + neat_packing: bool = field( + default=False, + metadata={"help": "Enable sequence packing without cross-attention."}, ) tool_format: Optional[str] = field( default=None, @@ -112,3 +112,6 @@ class DataArguments: if self.streaming and self.max_samples is not None: raise ValueError("`max_samples` is incompatible with `streaming`.") + + if self.neat_packing and not self.packing: + raise ValueError("`neat_packing` requires `packing` is True.") diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 087c8c38..4ac47512 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -226,6 +226,7 @@ class ModelArguments: self.compute_dtype: Optional["torch.dtype"] = None self.device_map: Optional[Union[str, Dict[str, Any]]] = None self.model_max_length: Optional[int] = None + self.block_diag_attn: bool = False if self.split_special_tokens and self.use_fast_tokenizer: raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") @@ -253,4 +254,5 @@ class ModelArguments: new_arg.compute_dtype = old_arg.compute_dtype new_arg.device_map = old_arg.device_map new_arg.model_max_length = old_arg.model_max_length + new_arg.block_diag_attn = old_arg.block_diag_attn return new_arg diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 8b2ea4c1..73abc0bb 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -158,6 +158,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if finetuning_args.stage != "sft" and training_args.predict_with_generate: raise ValueError("`predict_with_generate` cannot be set as True except SFT.") + if finetuning_args.stage != "sft" and data_args.neat_packing: + raise ValueError("`neat_packing` cannot be set as True except SFT.") + if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: raise ValueError("Please enable `predict_with_generate` to save model predictions.") @@ -311,6 +314,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: model_args.device_map = {"": get_current_device()} model_args.model_max_length = data_args.cutoff_len + model_args.block_diag_attn = data_args.neat_packing data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt" # Log on each process the small summary diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py new file mode 100644 index 00000000..ba614515 --- /dev/null +++ b/src/llamafactory/model/model_utils/packing.py @@ -0,0 +1,147 @@ +# Copyright 2024 Musab Gultekin and the LlamaFactory team. +# +# This code is based on the Musab Gultekin's functionary library. +# https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License +# +# Copyright (c) 2023 Musab Gultekin +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from typing import TYPE_CHECKING, Tuple + +import torch +import torch.nn.functional as F +import transformers.models + +from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor": + r""" + Gets the sequnce lengths in the current batch. + + e.g. + ``` + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + ``` + -> + ``` + [2, 3, 1, 2, 3] + ``` + """ + bsz = attention_mask.size(0) + dtype, device = attention_mask.dtype, attention_mask.device + max_num = torch.max(attention_mask) + counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device) + for i in range(max_num): + counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1) + + counts = counts.flatten() + seqlens = counts[counts.nonzero().squeeze()] + return seqlens + + +def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]: + r""" + Prepares the indices and seqlens for flash attn varlen function. + + Returns: + indices: indices of non-masked tokens from the flattened sequence. + cu_seqlens: the cumulative sequence lengths in the current batch, always starts from 0. + max_seqlen_in_batch: the largest seqlen in the current batch. + + e.g. + ``` + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + ``` + -> + ``` + [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11] + [0, 2, 5, 6, 8, 11] + 3 + ``` + """ + seqlens_in_batch = get_seqlens_in_batch(attention_mask) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return indices, cu_seqlens, max_seqlen_in_batch + + +def patch_for_block_diag_attn(model_type: str) -> None: + if model_type == "falcon": + transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data + elif model_type == "gemma": + transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data + elif model_type == "gemma2": + transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data + elif model_type == "llama": + transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data + elif model_type == "mistral": + transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data + elif model_type == "phi": + transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data + elif model_type == "phi3": + transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data + elif model_type == "qwen2": + transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data + elif model_type == "starcoder2": + transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data + + +def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if not is_trainable or not model_args.block_diag_attn: + return + + model_type = getattr(config, "model_type", None) + if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN: + patch_for_block_diag_attn(model_type) + logger.info("Using block diagonal attention for sequence packing without cross-attention.") + else: + raise ValueError("Current model does not support block diagonal attention.") diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 4eae0bb4..a99d38e0 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -29,6 +29,7 @@ from .model_utils.checkpointing import prepare_model_for_training from .model_utils.embedding import resize_embedding_layer from .model_utils.longlora import configure_longlora from .model_utils.moe import add_z3_leaf_module, configure_moe +from .model_utils.packing import configure_packing from .model_utils.quantization import configure_quantization from .model_utils.rope import configure_rope from .model_utils.valuehead import prepare_valuehead_model @@ -73,6 +74,7 @@ def patch_config( configure_quantization(config, tokenizer, model_args, init_kwargs) configure_moe(config, model_args, is_trainable) configure_visual_model(config) + configure_packing(config, model_args, is_trainable) if model_args.use_cache and not is_trainable: setattr(config, "use_cache", True) diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index d4832dd3..9f7e0d2a 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -95,11 +95,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): with gr.Column(): - resize_vocab = gr.Checkbox() packing = gr.Checkbox() + neat_packing = gr.Checkbox() with gr.Column(): - upcast_layernorm = gr.Checkbox() + resize_vocab = gr.Checkbox() use_llama_pro = gr.Checkbox() with gr.Column(): @@ -113,9 +113,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: warmup_steps, neftune_alpha, optim, - resize_vocab, packing, - upcast_layernorm, + neat_packing, + resize_vocab, use_llama_pro, shift_attn, report_to, @@ -129,9 +129,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: warmup_steps=warmup_steps, neftune_alpha=neftune_alpha, optim=optim, - resize_vocab=resize_vocab, packing=packing, - upcast_layernorm=upcast_layernorm, + neat_packing=neat_packing, + resize_vocab=resize_vocab, use_llama_pro=use_llama_pro, shift_attn=shift_attn, report_to=report_to, diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 435876e7..affc832f 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -494,20 +494,6 @@ LOCALES = { "info": "使用的优化器:adamw_torch、adamw_8bit 或 adafactor。", }, }, - "resize_vocab": { - "en": { - "label": "Resize token embeddings", - "info": "Resize the tokenizer vocab and the embedding layers.", - }, - "ru": { - "label": "Изменение размера токенных эмбеддингов", - "info": "Изменить размер словаря токенизатора и слоев эмбеддинга.", - }, - "zh": { - "label": "更改词表大小", - "info": "更改分词器词表和嵌入层的大小。", - }, - }, "packing": { "en": { "label": "Pack sequences", @@ -522,18 +508,32 @@ LOCALES = { "info": "将序列打包为等长样本。", }, }, - "upcast_layernorm": { + "neat_packing": { "en": { - "label": "Upcast LayerNorm", - "info": "Upcast weights of layernorm in float32.", + "label": "Use neat packing", + "info": "Avoid cross-attention between packed sequences.", }, "ru": { - "label": "Приведение весов LayerNorm", - "info": "Приведение весов LayerNorm к float32.", + "label": "Используйте аккуратную упаковку", + "info": "избегайте перекрестного внимания между упакованными последовательностями.", }, "zh": { - "label": "缩放归一化层", - "info": "将归一化层权重缩放至 32 位精度。", + "label": "使用无污染打包", + "info": "避免打包后的序列产生交叉注意力。", + }, + }, + "resize_vocab": { + "en": { + "label": "Resize token embeddings", + "info": "Resize the tokenizer vocab and the embedding layers.", + }, + "ru": { + "label": "Изменение размера токенных эмбеддингов", + "info": "Изменить размер словаря токенизатора и слоев эмбеддинга.", + }, + "zh": { + "label": "更改词表大小", + "info": "更改分词器词表和嵌入层的大小。", }, }, "use_llama_pro": { diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 131d180d..e23f4d15 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -138,9 +138,9 @@ class Runner: warmup_steps=get("train.warmup_steps"), neftune_noise_alpha=get("train.neftune_alpha") or None, optim=get("train.optim"), + packing=get("train.packing") or get("train.neat_packing"), + neat_packing=get("train.neat_packing"), resize_vocab=get("train.resize_vocab"), - packing=get("train.packing"), - upcast_layernorm=get("train.upcast_layernorm"), use_llama_pro=get("train.use_llama_pro"), shift_attn=get("train.shift_attn"), report_to="all" if get("train.report_to") else "none", diff --git a/tests/data/test_collator.py b/tests/data/test_collator.py new file mode 100644 index 00000000..58035ac2 --- /dev/null +++ b/tests/data/test_collator.py @@ -0,0 +1,56 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from llamafactory.data.collator import prepare_4d_attention_mask + + +def test_4d_attention_mask(): + o = 0.0 + x = torch.finfo(torch.float16).min + attention_mask_with_indices = torch.tensor( + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + ) + attention_mask_computed = prepare_4d_attention_mask(attention_mask_with_indices, torch.float16) + attention_mask_expected = torch.tensor( + [ + [ + [ + [o, x, x, x, x, x], + [o, o, x, x, x, x], + [x, x, o, x, x, x], + [x, x, o, o, x, x], + [x, x, o, o, o, x], + [x, x, x, x, x, x], + ] + ], + [ + [ + [o, x, x, x, x, x], + [x, o, x, x, x, x], + [x, o, o, x, x, x], + [x, x, x, o, x, x], + [x, x, x, o, o, x], + [x, x, x, o, o, o], + ] + ], + ], + dtype=torch.float16, + ) + assert list(attention_mask_computed.size()) == [2, 1, 6, 6] + assert torch.all(attention_mask_computed == attention_mask_expected) diff --git a/tests/model/model_utils/test_packing.py b/tests/model/model_utils/test_packing.py new file mode 100644 index 00000000..6fd9ba3b --- /dev/null +++ b/tests/model/model_utils/test_packing.py @@ -0,0 +1,42 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data + + +def test_get_seqlens_in_batch(): + attention_mask_with_indices = torch.tensor( + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + ) + seqlens_in_batch = get_seqlens_in_batch(attention_mask_with_indices) + assert list(seqlens_in_batch.size()) == [5] + assert torch.all(seqlens_in_batch == torch.tensor([2, 3, 1, 2, 3])) + + +def test_get_unpad_data(): + attention_mask_with_indices = torch.tensor( + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + ) + indices, cu_seqlens, max_seqlen_in_batch = get_unpad_data(attention_mask_with_indices) + assert torch.all(indices == torch.tensor([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11])) + assert torch.all(cu_seqlens == torch.tensor([0, 2, 5, 6, 8, 11], dtype=torch.int32)) + assert max_seqlen_in_batch == 3