diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py
index 327dfd44..1554345f 100644
--- a/src/llmtuner/dsets/preprocess.py
+++ b/src/llmtuner/dsets/preprocess.py
@@ -1,7 +1,7 @@
 import os
 import tiktoken
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
 
 from datasets import load_from_disk
 
@@ -19,6 +19,22 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
+    for i in range(len(examples["prompt"])):
+        query, response = examples["prompt"][i], examples["response"][i]
+        query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
+        history = examples["history"][i] if "history" in examples else None
+        system = examples["system"][i] if "system" in examples else None
+        yield query, response, history, system
+
+
+def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
+    max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
+    max_target_len = max(max_target_len, data_args.reserved_label_len)
+    max_source_len = data_args.cutoff_len - max_target_len
+    return max_source_len, max_target_len
+
+
 def preprocess_dataset(
     dataset: Union["Dataset", "IterableDataset"],
     tokenizer: "PreTrainedTokenizer",
@@ -31,14 +47,6 @@ def preprocess_dataset(
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
 
-    def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
-        for i in range(len(examples["prompt"])):
-            query, response = examples["prompt"][i], examples["response"][i]
-            query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
-            history = examples["history"][i] if "history" in examples else None
-            system = examples["system"][i] if "system" in examples else None
-            yield query, response, history, system
-
     def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
         # build grouped texts with format `X1 X2 X3 ...`
         if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):  # for tiktoken tokenizer (Qwen)
@@ -79,13 +87,11 @@ def preprocess_dataset(
             for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                 tokenizer, query, response, history, system
             )):
-                total_len = len(source_ids) + len(target_ids)
-                max_source_len = int(data_args.cutoff_len * (len(source_ids) / total_len))
-                max_target_len = int(data_args.cutoff_len * (len(target_ids) / total_len))
-
-                if len(source_ids) > max_source_len:
+                source_len, target_len = len(source_ids), len(target_ids)
+                max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
+                if source_len > max_source_len:
                     source_ids = source_ids[:max_source_len]
-                if len(target_ids) > max_target_len:
+                if target_len > max_target_len:
                     target_ids = target_ids[:max_target_len]
 
                 if data_args.train_on_prompt:
@@ -187,15 +193,12 @@ def preprocess_dataset(
                 chosen_ids += [tokenizer.eos_token_id]
                 rejected_ids += [tokenizer.eos_token_id]
 
-            total_len = len(prompt_ids) + max(len(chosen_ids), len(rejected_ids))
-            max_source_len = int(data_args.cutoff_len * (len(prompt_ids) / total_len))
-            max_target_len = int(data_args.cutoff_len * (max(len(chosen_ids), len(rejected_ids)) / total_len))
-
-            if len(prompt_ids) > max_source_len:
+            source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
+            max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
+            if source_len > max_source_len:
                 prompt_ids = prompt_ids[:max_source_len]
-            if len(chosen_ids) > max_target_len:
+            if target_len > max_target_len:
                 chosen_ids = chosen_ids[:max_target_len]
-            if len(rejected_ids) > max_target_len:
                 rejected_ids = rejected_ids[:max_target_len]
 
             model_inputs["prompt_ids"].append(prompt_ids)
diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index fb8a0abc..0b74c3cb 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -52,6 +52,10 @@ class DataArguments:
         default=1024,
         metadata={"help": "The maximum length of the model inputs after tokenization."}
     )
+    reserved_label_len: Optional[int] = field(
+        default=1,
+        metadata={"help": "The maximum length reserved for label after tokenization."}
+    )
     train_on_prompt: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to disable the mask on the prompt or not."}
@@ -110,6 +114,9 @@ class DataArguments:
     )
 
     def __post_init__(self):
+        if self.reserved_label_len >= self.cutoff_len:
+            raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
+
         if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
             raise ValueError("Streaming mode should have an integer val size.")
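A minimal, standalone sketch of the truncation budget that the new `infer_max_len` helper computes: the target side receives its proportional share of `cutoff_len` but never less than `reserved_label_len`, and the source side gets the remainder. The `DataArgs` stand-in, the sample lengths, and the printed values below are illustrative and not part of the patch.

# illustrative sketch of the infer_max_len budget split (hypothetical values)
from dataclasses import dataclass
from typing import Tuple


@dataclass
class DataArgs:  # stand-in for llmtuner's DataArguments
    cutoff_len: int = 1024
    reserved_label_len: int = 1


def infer_max_len(source_len: int, target_len: int, data_args: DataArgs) -> Tuple[int, int]:
    # proportional share of the budget for the target, floored at reserved_label_len
    max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
    max_target_len = max(max_target_len, data_args.reserved_label_len)
    max_source_len = data_args.cutoff_len - max_target_len
    return max_source_len, max_target_len


args = DataArgs(cutoff_len=1024, reserved_label_len=1)
print(infer_max_len(900, 100, args))  # (922, 102): target keeps int(1024 * 100 / 1000) = 102 tokens
print(infer_max_len(5000, 2, args))   # (1023, 1): reserved_label_len stops the label being cut to 0 tokens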