modify some style

This commit is contained in:
BUAADreamer 2024-04-25 22:40:53 +08:00
parent d29f3798f6
commit ece78a6d6a
1 changed files with 38 additions and 37 deletions

View File

@ -6,6 +6,7 @@ from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
from .utils import Role
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import AutoProcessor, PreTrainedTokenizer
@ -34,7 +35,7 @@ def preprocess_pretrain_dataset(
block_size = data_args.cutoff_len
total_length = (total_length // block_size) * block_size
result = {
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
if data_args.template == "gemma":
@ -132,10 +133,10 @@ def preprocess_packed_supervised_dataset(
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
for i in range(0, total_length, block_size):
if not all(label == IGNORE_INDEX for label in labels[i: i + block_size]):
model_inputs["input_ids"].append(input_ids[i: i + block_size])
if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
model_inputs["input_ids"].append(input_ids[i : i + block_size])
model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i: i + block_size])
model_inputs["labels"].append(labels[i : i + block_size])
return model_inputs