modify some style

commit d29f3798f6
parent 31420f7b31
@@ -6,7 +6,6 @@ from ..extras.constants import IGNORE_INDEX
 from ..extras.logging import get_logger
 from .utils import Role
 
-
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments
     from transformers.tokenization_utils import AutoProcessor, PreTrainedTokenizer
@@ -18,7 +17,7 @@ logger = get_logger(__name__)
 
 
 def preprocess_pretrain_dataset(
     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
     # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
     text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
@@ -35,7 +34,7 @@ def preprocess_pretrain_dataset
         block_size = data_args.cutoff_len
         total_length = (total_length // block_size) * block_size
         result = {
-            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
             for k, t in concatenated_examples.items()
         }
         if data_args.template == "gemma":
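The only visible change in this hunk is the whitespace inside the slice expression. Both spellings are equivalent at runtime; the line splits the concatenated token stream into cutoff_len-sized blocks. A minimal, self-contained check (the values below are hypothetical stand-ins for the real token lists):

    t = list(range(10))  # stand-in for one column of concatenated token ids
    block_size = 4       # stand-in for data_args.cutoff_len
    total_length = (len(t) // block_size) * block_size  # drop the ragged tail, as in the diff
    with_spaces = [t[i : i + block_size] for i in range(0, total_length, block_size)]
    without_spaces = [t[i: i + block_size] for i in range(0, total_length, block_size)]
    assert with_spaces == without_spaces == [[0, 1, 2, 3], [4, 5, 6, 7]]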
@@ -46,11 +45,11 @@ preprocess_pretrain_dataset
 
 
 def preprocess_supervised_dataset(
     examples: Dict[str, List[Any]],
     tokenizer: "PreTrainedTokenizer",
     template: "Template",
     data_args: "DataArguments",
     processor: "AutoProcessor" = None,
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
     # for multiturn examples, we only mask the prompt part in each prompt-response pair.
@@ -63,14 +62,14 @@ def preprocess_supervised_dataset
         messages = examples["prompt"][i] + examples["response"][i]
         input_ids, labels = [], []
         for turn_idx, (source_ids, target_ids) in enumerate(
             template.encode_multiturn(
                 tokenizer,
                 messages,
                 examples["system"][i],
                 examples["tools"][i],
                 data_args.cutoff_len,
                 data_args.reserved_label_len,
             )
         ):
             if data_args.train_on_prompt:
                 source_mask = source_ids
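As the comments above state, only the response part of each turn contributes to the loss: prompt tokens are replaced by IGNORE_INDEX in the labels unless train_on_prompt is set. A rough sketch of that masking, with encode_multiturn replaced by a hypothetical, already-tokenized list of (source_ids, target_ids) pairs and the turn_idx/efficient_eos handling omitted:

    IGNORE_INDEX = -100  # same sentinel the module imports from ..extras.constants

    def build_supervised_example(turns, train_on_prompt=False):
        # turns: hypothetical list of (source_ids, target_ids) pairs
        input_ids, labels = [], []
        for source_ids, target_ids in turns:
            source_mask = source_ids if train_on_prompt else [IGNORE_INDEX] * len(source_ids)
            input_ids += source_ids + target_ids
            labels += source_mask + target_ids
        return input_ids, labels

    ids, lbls = build_supervised_example([([1, 2, 3], [4, 5])])
    assert lbls == [IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 4, 5]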
@@ -96,10 +95,10 @@ def preprocess_supervised_dataset
 
 
 def preprocess_packed_supervised_dataset(
     examples: Dict[str, List[Any]],
     tokenizer: "PreTrainedTokenizer",
     template: "Template",
     data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
@@ -111,7 +110,7 @@ def preprocess_packed_supervised_dataset
 
         messages = examples["prompt"][i] + examples["response"][i]
         for source_ids, target_ids in template.encode_multiturn(
             tokenizer, messages, examples["system"][i], examples["tools"][i]
         ):
             if data_args.train_on_prompt:
                 source_mask = source_ids
@@ -133,19 +132,19 @@ def preprocess_packed_supervised_dataset
     total_length = (total_length // block_size) * block_size
     # split by chunks of cutoff_len
     for i in range(0, total_length, block_size):
-        if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
-            model_inputs["input_ids"].append(input_ids[i : i + block_size])
+        if not all(label == IGNORE_INDEX for label in labels[i: i + block_size]):
+            model_inputs["input_ids"].append(input_ids[i: i + block_size])
             model_inputs["attention_mask"].append([1] * block_size)
-            model_inputs["labels"].append(labels[i : i + block_size])
+            model_inputs["labels"].append(labels[i: i + block_size])
 
     return model_inputs
 
 
 def preprocess_unsupervised_dataset(
     examples: Dict[str, List[Any]],
     tokenizer: "PreTrainedTokenizer",
     template: "Template",
     data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X` and labels with format `Y <eos>`
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
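The three changed lines above belong to the chunking step of the packed supervised path: the concatenated input_ids/labels are cut into fixed blocks, and any block whose labels are all IGNORE_INDEX (no response tokens at all) is dropped. A stripped-down sketch of just that step:

    IGNORE_INDEX = -100

    def chunk_packed(input_ids, labels, block_size):
        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
        total_length = (len(input_ids) // block_size) * block_size  # drop the remainder
        for i in range(0, total_length, block_size):
            # keep only blocks that still contain trainable (non-ignored) labels
            if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
                model_inputs["input_ids"].append(input_ids[i : i + block_size])
                model_inputs["attention_mask"].append([1] * block_size)
                model_inputs["labels"].append(labels[i : i + block_size])
        return model_inputs

    out = chunk_packed(list(range(8)), [IGNORE_INDEX] * 4 + [4, 5, 6, 7], block_size=4)
    assert out["labels"] == [[4, 5, 6, 7]]  # the all-ignored first block is dropped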
@@ -179,10 +178,10 @@ def preprocess_unsupervised_dataset
 
 
 def preprocess_pairwise_dataset(
     examples: Dict[str, List[Any]],
     tokenizer: "PreTrainedTokenizer",
     template: "Template",
     data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
     model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
@@ -246,12 +245,12 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
 
 
 def get_preprocess_and_print_func(
     tokenizer: "PreTrainedTokenizer",
     template: "Template",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"],
     processor: Optional["AutoProcessor"] = None,
 ) -> Tuple[Callable, Callable]:
     if stage == "pt":
         preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
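get_preprocess_and_print_func does no preprocessing itself; it only binds the stage-appropriate worker to its arguments with functools.partial and pairs it with a matching print function. A reduced sketch of that dispatch pattern (the workers here are placeholders, not the project's real ones):

    from functools import partial

    def preprocess_pretrain(examples, cutoff_len):  # placeholder worker
        return {"length": [min(len(t), cutoff_len) for t in examples["text"]]}

    def preprocess_supervised(examples, cutoff_len):  # placeholder worker
        return {"length": [min(len(t) + 1, cutoff_len) for t in examples["text"]]}

    def get_preprocess_func(stage, cutoff_len):
        if stage == "pt":
            return partial(preprocess_pretrain, cutoff_len=cutoff_len)
        return partial(preprocess_supervised, cutoff_len=cutoff_len)

    preprocess_func = get_preprocess_func("pt", cutoff_len=8)
    assert preprocess_func({"text": ["hello world"]}) == {"length": [8]}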
@@ -280,5 +279,4 @@ def get_preprocess_and_print_func
             preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
         )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
-
     return preprocess_func, print_function
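Downstream, the returned preprocess_func is the kind of batched callable that plugs into datasets.Dataset.map; the snippet below is an assumed usage sketch, not part of this commit, and the toy preprocess_func and column names are hypothetical:

    from datasets import Dataset

    def preprocess_func(examples):  # toy stand-in for the callable returned above
        return {"input_ids": [[len(m[0]["content"])] for m in examples["prompt"]]}

    raw = Dataset.from_dict({"prompt": [[{"role": "user", "content": "hi"}]]})
    processed = raw.map(preprocess_func, batched=True, remove_columns=["prompt"])
    print(processed[0])  # {'input_ids': [2]}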