diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 8e7062db..5f116e4e 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -177,7 +177,7 @@ def get_dataset(
 
     with training_args.main_process_first(desc="pre-process dataset"):
         preprocess_func, print_function = get_preprocess_and_print_func(
-            data_args, training_args, stage, template, tokenizer, processor
+            data_args, model_args, training_args, stage, template, tokenizer, processor
         )
         column_names = list(next(iter(dataset)).keys())
         kwargs = {}
diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py
index 3a80900c..ae69e84e 100644
--- a/src/llamafactory/data/preprocess.py
+++ b/src/llamafactory/data/preprocess.py
@@ -29,12 +29,13 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
 
-    from ..hparams import DataArguments
+    from ..hparams import DataArguments, ModelArguments
     from .template import Template
 
 
 def get_preprocess_and_print_func(
     data_args: "DataArguments",
+    model_args: "ModelArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     template: "Template",
@@ -49,7 +50,7 @@
         )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
-        if data_args.packing or data_args.efficient_packing:
+        if data_args.packing or model_args.efficient_packing:
             preprocess_func = partial(
                 preprocess_packed_supervised_dataset,
                 template=template,
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index 8ef55321..78811477 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -23,7 +23,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, gre
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin
 
-    from ...hparams import DataArguments
+    from ...hparams import DataArguments, ModelArguments
     from ..template import Template
 
 
@@ -125,6 +125,7 @@ def preprocess_packed_supervised_dataset(
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     data_args: "DataArguments",
+    model_args: "ModelArguments"
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... Y1 <eos> <ignore> ... Y2 <eos>`
@@ -176,7 +177,7 @@
             raise ValueError("The length of packed example should be identical to the cutoff length.")
 
         model_inputs["input_ids"].append(packed_input_ids)
-        if data_args.efficient_packing:
+        if model_args.efficient_packing:
             model_inputs["attention_mask"].append(packed_attention_mask)
         else:
             model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index e351fccf..880be84a 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -97,12 +97,6 @@ class DataArguments:
             "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
         },
     )
-    efficient_packing: Optional[bool] = field(
-        default=None,
-        metadata={
-            "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
-        },
-    )
     tool_format: Optional[str] = field(
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 087c8c38..49503022 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -109,6 +109,12 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
     )
+    efficient_packing: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
+        },
+    )
     mixture_of_depths: Optional[Literal["convert", "load"]] = field(
         default=None,
         metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 8b2ea4c1..507f7fef 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -170,6 +170,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")
 
+    if finetuning_args.stage != "sft" and model_args.efficient_packing:
+        raise ValueError("`efficient_packing` is only available for SFT training.")
+
     if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
         raise ValueError("Unsloth does not support lora reward model.")
 
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 43e65d52..fe700d53 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -31,7 +31,7 @@ from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 
-    from ..hparams import FinetuningArguments, ModelArguments, DataArguments
+    from ..hparams import FinetuningArguments, ModelArguments
 
 
 logger = get_logger(__name__)
@@ -120,7 +120,6 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
 def load_model(
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    data_args: "DataArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool = False,
     add_valuehead: bool = False,
@@ -130,7 +129,7 @@
     """
     init_kwargs = _get_init_kwargs(model_args)
    config = load_config(model_args)
-    patch_config(config, tokenizer, model_args, data_args, finetuning_args, init_kwargs, is_trainable)
+    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
 
     model = None
     lazy_load = False
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index f1831ced..2ddfd21a 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead
 
-    from ..hparams import ModelArguments, DataArguments, FinetuningArguments
+    from ..hparams import ModelArguments
 
 
 logger = get_logger(__name__)
@@ -54,8 +54,6 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    data_args: "DataArguments",
-    finetune_args: "FinetuningArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
@@ -104,7 +102,7 @@ def patch_config(
             if init_kwargs.get("device_map", None) == "auto":
                 init_kwargs["offload_folder"] = model_args.offload_folder
 
-    if finetune_args.stage == "sft" and data_args.efficient_packing:
+    if model_args.efficient_packing:
         configure_packing(config, model_args)
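
Usage note (not part of the patch): a minimal sketch of where the relocated flag lands after this change, assuming `llamafactory.hparams` exports `ModelArguments` and `DataArguments` as the imports above suggest; the model id is a placeholder. After this diff, `efficient_packing` is parsed into `ModelArguments` (and `get_train_args` rejects it for any stage other than SFT), while `DataArguments` no longer defines the field.

```python
# Sketch only: shows that `efficient_packing` now belongs to ModelArguments.
from transformers import HfArgumentParser

from llamafactory.hparams import DataArguments, ModelArguments

parser = HfArgumentParser((ModelArguments, DataArguments))
model_args, data_args = parser.parse_dict(
    {
        "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model id
        "efficient_packing": True,  # now a ModelArguments field (previously DataArguments)
        "cutoff_len": 2048,  # still a DataArguments field
    }
)

assert model_args.efficient_packing is True
assert not hasattr(data_args, "efficient_packing")  # removed from DataArguments by this patch
```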