modify some style
This commit is contained in:
parent
31420f7b31
commit
d29f3798f6
|
@ -6,7 +6,6 @@ from ..extras.constants import IGNORE_INDEX
|
|||
from ..extras.logging import get_logger
|
||||
from .utils import Role
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import AutoProcessor, PreTrainedTokenizer
|
||||
|
@ -35,7 +34,7 @@ def preprocess_pretrain_dataset(
|
|||
block_size = data_args.cutoff_len
|
||||
total_length = (total_length // block_size) * block_size
|
||||
result = {
|
||||
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
if data_args.template == "gemma":
|
||||
|
@ -133,10 +132,10 @@ def preprocess_packed_supervised_dataset(
|
|||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of cutoff_len
|
||||
for i in range(0, total_length, block_size):
|
||||
if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
|
||||
model_inputs["input_ids"].append(input_ids[i : i + block_size])
|
||||
if not all(label == IGNORE_INDEX for label in labels[i: i + block_size]):
|
||||
model_inputs["input_ids"].append(input_ids[i: i + block_size])
|
||||
model_inputs["attention_mask"].append([1] * block_size)
|
||||
model_inputs["labels"].append(labels[i : i + block_size])
|
||||
model_inputs["labels"].append(labels[i: i + block_size])
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
@ -280,5 +279,4 @@ def get_preprocess_and_print_func(
|
|||
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
|
||||
return preprocess_func, print_function
|
||||
|
|
Loading…
Reference in New Issue