diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index e4859ff5..0939925d 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -1,4 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
+#
+# This code is inspired by the OpenAccess AI Collective's axolotl library.
+# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,6 +22,44 @@ import torch
 from transformers import DataCollatorForSeq2Seq
 
 
+def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
+    r"""
+    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
+    while handling packed sequences and transforming the mask into a lower triangular form to prevent future peeking.
+
+    e.g.
+    ```
+    [1, 1, 2, 2, 2, 0]
+    ```
+    ->
+    ```
+    [[
+        [
+            [o, x, x, x, x, x],
+            [o, o, x, x, x, x],
+            [x, x, o, x, x, x],
+            [x, x, o, o, x, x],
+            [x, x, o, o, o, x],
+            [x, x, x, x, x, x],
+        ]
+    ]]
+    ```
+    where `o` equals `0.0` and `x` equals `min_dtype`.
+    """
+    bsz, seq_len = attention_mask_with_indices.size()
+    min_dtype = torch.finfo(dtype).min
+    expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
+    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
+    padding_mask = torch.where(expanded_mask != 0, 1, 0)
+    # Create a block-diagonal mask.
+    attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
+    # Use the lower triangular mask to zero out the upper triangular part
+    attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
+    # Invert the attention mask.
+    attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
+    return attention_mask_4d
+
+
 @dataclass
 class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
     r"""
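As a quick sanity check of the mask construction above, here is a minimal sketch, assuming this patch is applied and `llamafactory` is importable, that expands the index mask from the docstring example into the block-diagonal causal mask:

```python
import torch

# Assumes the patched package is on the path; the function is added by this diff.
from llamafactory.data.collator import prepare_4d_attention_mask

# Two packed sequences (indices 1 and 2) followed by one padding position (0).
mask_with_indices = torch.tensor([[1, 1, 2, 2, 2, 0]])
mask_4d = prepare_4d_attention_mask(mask_with_indices, torch.float32)

print(mask_4d.shape)  # torch.Size([1, 1, 6, 6])
# 1 marks attendable positions (0.0 in the mask); 0 marks blocked ones (min_dtype).
print((mask_4d[0, 0] == 0.0).long())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 0, 0, 0]])
```

Each token only attends to earlier tokens of its own sequence, and the padding row and column are fully masked.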
diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 5f116e4e..8e7062db 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -177,7 +177,7 @@ def get_dataset(
 
     with training_args.main_process_first(desc="pre-process dataset"):
         preprocess_func, print_function = get_preprocess_and_print_func(
-            data_args, model_args, training_args, stage, template, tokenizer, processor
+            data_args, training_args, stage, template, tokenizer, processor
         )
         column_names = list(next(iter(dataset)).keys())
         kwargs = {}
diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py
index ae69e84e..9a8b97f3 100644
--- a/src/llamafactory/data/preprocess.py
+++ b/src/llamafactory/data/preprocess.py
@@ -29,13 +29,12 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
 
-    from ..hparams import DataArguments, ModelArguments
+    from ..hparams import DataArguments
     from .template import Template
 
 
 def get_preprocess_and_print_func(
     data_args: "DataArguments",
-    model_args: "ModelArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     template: "Template",
@@ -50,7 +49,7 @@
         )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
-        if data_args.packing or model_args.efficient_packing:
+        if data_args.packing:
             preprocess_func = partial(
                 preprocess_packed_supervised_dataset,
                 template=template,
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index 78811477..747a0c1b 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -23,7 +23,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, gre
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin
 
-    from ...hparams import DataArguments, ModelArguments
+    from ...hparams import DataArguments
     from ..template import Template
 
 
@@ -125,7 +125,6 @@ def preprocess_packed_supervised_dataset(
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     data_args: "DataArguments",
-    model_args: "ModelArguments"
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X1 Y1 <eos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... Y1 <eos> <ignore> ... Y2 <eos>`
@@ -161,26 +160,30 @@
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
     knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
     for knapsack in knapsacks:
-        packed_input_ids, packed_attention_mask, packed_labels = [], [], []
+        packed_input_ids, packed_attention_masks, packed_labels = [], [], []
         for i, length in enumerate(knapsack):
             index = length2indexes[length].pop()
             packed_input_ids += batch_input_ids[index]
             packed_labels += batch_labels[index]
-            packed_attention_mask += [i+1]*len(batch_input_ids[index])
+            if data_args.neat_packing:
+                packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
+            else:
+                packed_attention_masks += [1] * len(batch_input_ids[index])
 
         if len(packed_input_ids) < data_args.cutoff_len:
             pad_length = data_args.cutoff_len - len(packed_input_ids)
             packed_input_ids += [tokenizer.pad_token_id] * pad_length
             packed_labels += [IGNORE_INDEX] * pad_length
+            if data_args.neat_packing:
+                packed_attention_masks += [0] * pad_length
+            else:
+                packed_attention_masks += [1] * pad_length  # more efficient flash_attn
 
         if len(packed_input_ids) != data_args.cutoff_len:
             raise ValueError("The length of packed example should be identical to the cutoff length.")
 
         model_inputs["input_ids"].append(packed_input_ids)
-        if model_args.efficient_packing:
-            model_inputs["attention_mask"].append(packed_attention_mask)
-        else:
-            model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
+        model_inputs["attention_mask"].append(packed_attention_masks)
         model_inputs["labels"].append(packed_labels)
 
     return model_inputs
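For intuition, here is a small standalone sketch, not part of this patch and with a made-up helper name, of how the attention mask for one packed example is assembled in the branch above, for a knapsack with sequence lengths [3, 2] and cutoff_len = 8:

```python
# Illustration only: mirrors the if/else logic added in preprocess_packed_supervised_dataset.
def build_packed_attention_mask(lengths, cutoff_len, neat_packing):
    mask = []
    for i, length in enumerate(lengths):
        # neat_packing tags each sequence with its own index (starting from 1)
        mask += [i + 1] * length if neat_packing else [1] * length
    pad_length = cutoff_len - len(mask)
    # padding is 0 under neat_packing, 1 otherwise (kept at 1 for flash_attn)
    mask += ([0] if neat_packing else [1]) * pad_length
    return mask

print(build_packed_attention_mask([3, 2], 8, neat_packing=True))
# [1, 1, 1, 2, 2, 0, 0, 0]  -> later expanded by prepare_4d_attention_mask
print(build_packed_attention_mask([3, 2], 8, neat_packing=False))
# [1, 1, 1, 1, 1, 1, 1, 1]  -> ordinary packing, sequences can attend across boundaries
```

Plain packing keeps the historical all-ones mask (the diff notes this is more efficient for flash_attn), while neat_packing records which sequence each token belongs to so that the block-diagonal 4-D mask can be built later.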
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index 880be84a..38bbbb12 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -83,9 +83,7 @@ class DataArguments:
     )
     ignore_pad_token_for_loss: bool = field(
         default=True,
-        metadata={
-            "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
-        },
+        metadata={"help": "Whether or not to ignore the tokens corresponding to the pad tokens in loss computation."},
     )
     val_size: float = field(
         default=0.0,
@@ -93,9 +91,11 @@
     )
     packing: Optional[bool] = field(
         default=None,
-        metadata={
-            "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
-        },
+        metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
+    )
+    neat_packing: bool = field(
+        default=False,
+        metadata={"help": "Enable sequence packing without cross-attention."},
     )
     tool_format: Optional[str] = field(
         default=None,
@@ -112,3 +112,6 @@
 
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
+
+        if self.neat_packing and not self.packing:
+            raise ValueError("`neat_packing` requires `packing` is True.")
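A quick way to exercise the new `__post_init__` check, assuming the patched package is importable and the remaining `DataArguments` fields keep their defaults:

```python
from llamafactory.hparams import DataArguments

try:
    DataArguments(neat_packing=True)  # packing left unset (None)
except ValueError as err:
    print(err)  # `neat_packing` requires `packing` is True.

# The supported combination: neat_packing on top of packing.
data_args = DataArguments(packing=True, neat_packing=True)
print(data_args.packing, data_args.neat_packing)  # True True
```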
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 49503022..4ac47512 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -109,12 +109,6 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
     )
-    efficient_packing: Optional[bool] = field(
-        default=None,
-        metadata={
-            "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
-        },
-    )
     mixture_of_depths: Optional[Literal["convert", "load"]] = field(
         default=None,
         metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
@@ -232,6 +226,7 @@
         self.compute_dtype: Optional["torch.dtype"] = None
         self.device_map: Optional[Union[str, Dict[str, Any]]] = None
         self.model_max_length: Optional[int] = None
+        self.block_diag_attn: bool = False
 
         if self.split_special_tokens and self.use_fast_tokenizer:
             raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
@@ -259,4 +254,5 @@
         new_arg.compute_dtype = old_arg.compute_dtype
         new_arg.device_map = old_arg.device_map
         new_arg.model_max_length = old_arg.model_max_length
+        new_arg.block_diag_attn = old_arg.block_diag_attn
         return new_arg
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 507f7fef..73abc0bb 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -158,6 +158,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage != "sft" and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
 
+    if finetuning_args.stage != "sft" and data_args.neat_packing:
+        raise ValueError("`neat_packing` cannot be set as True except SFT.")
+
     if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
@@ -170,9 +173,6 @@
     if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")
 
-    if finetuning_args.stage != "sft" and model_args.efficient_packing:
-        raise ValueError("`efficient_packing` cannot be set as True except SFT.")
-
     if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
         raise ValueError("Unsloth does not support lora reward model.")
 
@@ -314,6 +314,7 @@
 
     model_args.device_map = {"": get_current_device()}
     model_args.model_max_length = data_args.cutoff_len
+    model_args.block_diag_attn = data_args.neat_packing
     data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
 
     # Log on each process the small summary
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index c12a70aa..0c3f9b11 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -28,6 +28,7 @@
 from ..trainer_utils import create_modelcard_and_push
 from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor
 from .trainer import CustomSeq2SeqTrainer
 
+
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
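The diff stops short of showing where the 4-D mask is actually applied at collation time; the flag simply travels from `data_args.neat_packing` to `model_args.block_diag_attn` in the parser change above. Purely as an illustration, the hypothetical wrapper below (not LlamaFactory's collator) sketches how the index-based `attention_mask` produced under `neat_packing` could be expanded with `prepare_4d_attention_mask` just before a batch reaches the model, gated by the new flag:

```python
from dataclasses import dataclass

import torch
from transformers import DataCollatorForSeq2Seq

from llamafactory.data.collator import prepare_4d_attention_mask  # added by this patch


@dataclass
class BlockDiagAttnCollator(DataCollatorForSeq2Seq):
    # Hypothetical fields mirroring model_args.block_diag_attn and the compute dtype.
    block_diag_attn: bool = False
    compute_dtype: torch.dtype = torch.float32

    def __call__(self, features, return_tensors=None):
        batch = super().__call__(features, return_tensors)
        if self.block_diag_attn:
            # Expand the (bsz, seq_len) sequence-index mask into the block-diagonal causal mask.
            batch["attention_mask"] = prepare_4d_attention_mask(batch["attention_mask"], self.compute_dtype)
        return batch
```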