fix data map for packing
This commit is contained in:
parent
b03e4a74ba
commit
b5d101e1bf
|
@@ -15,6 +15,8 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
|
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
|
||||||
|
|
||||||
|
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
|
||||||
|
|
||||||
from .processors.feedback import preprocess_feedback_dataset
|
from .processors.feedback import preprocess_feedback_dataset
|
||||||
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
|
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
|
||||||
from .processors.pretrain import preprocess_pretrain_dataset
|
from .processors.pretrain import preprocess_pretrain_dataset
|
||||||
|
@@ -50,6 +52,18 @@ def get_preprocess_and_print_func(
|
||||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||||
elif stage == "sft" and not training_args.predict_with_generate:
|
elif stage == "sft" and not training_args.predict_with_generate:
|
||||||
if data_args.packing:
|
if data_args.packing:
|
||||||
|
if data_args.neat_packing:
|
||||||
|
|
||||||
|
def __init__(self, data, **kwargs):
|
||||||
|
return TypedSequence.__init__(
|
||||||
|
self,
|
||||||
|
data,
|
||||||
|
type=kwargs.pop("type", None),
|
||||||
|
try_type=kwargs.pop("try_type", None),
|
||||||
|
optimized_int_type=kwargs.pop("optimized_int_type", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
OptimizedTypedSequence.__init__ = __init__
|
||||||
preprocess_func = partial(
|
preprocess_func = partial(
|
||||||
preprocess_packed_supervised_dataset,
|
preprocess_packed_supervised_dataset,
|
||||||
template=template,
|
template=template,
|
||||||
|
|
Loading…
Reference in New Issue