fix data map for packing

This commit is contained in:
hiyouga 2024-07-04 03:01:31 +08:00
parent b03e4a74ba
commit b5d101e1bf
1 changed file with 14 additions and 0 deletions

View File

@ -15,6 +15,8 @@
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
from .processors.feedback import preprocess_feedback_dataset from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset from .processors.pretrain import preprocess_pretrain_dataset
@ -50,6 +52,18 @@ def get_preprocess_and_print_func(
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate: elif stage == "sft" and not training_args.predict_with_generate:
if data_args.packing: if data_args.packing:
if data_args.neat_packing:
def __init__(self, data, **kwargs):
    """Initialize through ``TypedSequence.__init__`` with a fixed set of options.

    Only ``type``, ``try_type`` and ``optimized_int_type`` are read from
    ``kwargs`` (each falling back to ``None`` when absent); any other keyword
    argument is discarded instead of being forwarded.
    """
    forwarded_keys = ("type", "try_type", "optimized_int_type")
    forwarded = {key: kwargs.get(key) for key in forwarded_keys}
    # Hand back whatever the parent initializer returns, unchanged.
    return TypedSequence.__init__(self, data, **forwarded)
OptimizedTypedSequence.__init__ = __init__
preprocess_func = partial( preprocess_func = partial(
preprocess_packed_supervised_dataset, preprocess_packed_supervised_dataset,
template=template, template=template,