diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index 9a8b97f3..a22c7c11 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -15,6 +15,8 @@ from functools import partial from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple +from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence + from .processors.feedback import preprocess_feedback_dataset from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example from .processors.pretrain import preprocess_pretrain_dataset @@ -50,6 +52,18 @@ def get_preprocess_and_print_func( print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) elif stage == "sft" and not training_args.predict_with_generate: if data_args.packing: + if data_args.neat_packing: + + def __init__(self, data, **kwargs): + return TypedSequence.__init__( + self, + data, + type=kwargs.pop("type", None), + try_type=kwargs.pop("try_type", None), + optimized_int_type=kwargs.pop("optimized_int_type", None), + ) + + OptimizedTypedSequence.__init__ = __init__ preprocess_func = partial( preprocess_packed_supervised_dataset, template=template,