diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index d90a32ac..188c9f80 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -1,4 +1,6 @@
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+import bisect
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
@@ -16,6 +18,80 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
+    r"""
+    Finds the index of largest number that fits into the knapsack with the given capacity.
+    """
+    index = bisect.bisect(numbers, capacity)
+    return -1 if index == 0 else (index - 1)
+
+
+def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
+    r"""
+    An efficient greedy algorithm with binary search for the knapsack problem.
+    """
+    numbers.sort()  # sort numbers in ascending order for binary search
+    knapsacks = []
+
+    while numbers:
+        current_knapsack = []
+        remaining_capacity = capacity
+
+        while True:
+            index = search_for_fit(numbers, remaining_capacity)
+            if index == -1:
+                break  # no more numbers fit in this knapsack
+
+            remaining_capacity -= numbers[index]  # update the remaining capacity
+            current_knapsack.append(numbers.pop(index))  # add the number to knapsack
+
+        knapsacks.append(current_knapsack)
+
+    return knapsacks
+
+
+def _encode_supervised_example(
+    prompt: Sequence[Dict[str, str]],
+    response: Sequence[Dict[str, str]],
+    system: Optional[str],
+    tools: Optional[str],
+    template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
+    data_args: "DataArguments",
+) -> Tuple[List[int], List[int]]:
+    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
+        prompt[0]["content"] = template.image_token + prompt[0]["content"]
+
+    messages = prompt + response
+    input_ids, labels = [], []
+
+    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
+        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
+        input_ids += [image_token_id] * getattr(processor, "image_seq_length")
+        labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
+
+    encoded_pairs = template.encode_multiturn(
+        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
+    )
+    for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
+        if data_args.train_on_prompt:
+            source_mask = source_ids
+        elif turn_idx != 0 and template.efficient_eos:
+            source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
+        else:
+            source_mask = [IGNORE_INDEX] * len(source_ids)
+
+        input_ids += source_ids + target_ids
+        labels += source_mask + target_ids
+
+    if template.efficient_eos:
+        input_ids += [tokenizer.eos_token_id]
+        labels += [tokenizer.eos_token_id]
+
+    return input_ids, labels
+
+
 def preprocess_supervised_dataset(
     examples: Dict[str, List[Any]],
     template: "Template",
@@ -36,41 +112,16 @@ def preprocess_supervised_dataset(
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
 
-        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-            examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
-
-        messages = examples["prompt"][i] + examples["response"][i]
examples["prompt"][i] + examples["response"][i] - input_ids, labels = [], [] - - if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models - image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) - input_ids += [image_token_id] * getattr(processor, "image_seq_length") - labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length") - - for turn_idx, (source_ids, target_ids) in enumerate( - template.encode_multiturn( - tokenizer, - messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - ): - if data_args.train_on_prompt: - source_mask = source_ids - elif turn_idx != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) - else: - source_mask = [IGNORE_INDEX] * len(source_ids) - - input_ids += source_ids + target_ids - labels += source_mask + target_ids - - if template.efficient_eos: - input_ids += [tokenizer.eos_token_id] - labels += [tokenizer.eos_token_id] - + input_ids, labels = _encode_supervised_example( + prompt=examples["prompt"][i], + response=examples["response"][i], + system=examples["system"][i], + tools=examples["tools"][i], + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, + ) model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) @@ -90,41 +141,55 @@ def preprocess_packed_supervised_dataset( ) -> Dict[str, List[List[int]]]: # build inputs with format ` X1 Y1 X2 Y2 ` # and labels with format ` ... Y1 ... Y2 ` - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - input_ids, labels = [], [] + valid_num = 0 + batch_input_ids, batch_labels = [], [] + lengths = [] + length2indexes = defaultdict(list) for i in range(len(examples["prompt"])): if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - messages = examples["prompt"][i] + examples["response"][i] - for source_ids, target_ids in template.encode_multiturn( - tokenizer, messages, examples["system"][i], examples["tools"][i] - ): - if data_args.train_on_prompt: - source_mask = source_ids - elif len(input_ids) != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) - else: - source_mask = [IGNORE_INDEX] * len(source_ids) + input_ids, labels = _encode_supervised_example( + prompt=examples["prompt"][i], + response=examples["response"][i], + system=examples["system"][i], + tools=examples["tools"][i], + template=template, + tokenizer=tokenizer, + processor=None, + data_args=data_args, + ) + length = len(input_ids) + if length > data_args.cutoff_len: + logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len)) + else: + lengths.append(length) + length2indexes[length].append(valid_num) + batch_input_ids.append(input_ids) + batch_labels.append(labels) + valid_num += 1 - input_ids += source_ids + target_ids - labels += source_mask + target_ids + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + knapsacks = greedy_knapsack(lengths, data_args.cutoff_len) + for knapsack in knapsacks: + packed_input_ids, packed_labels = [], [] + for length in knapsack: + index = length2indexes[length].pop() + packed_input_ids += batch_input_ids[index] + packed_labels += batch_labels[index] - if 
-            input_ids += [tokenizer.eos_token_id]
-            labels += [tokenizer.eos_token_id]
+        if len(packed_input_ids) < data_args.cutoff_len:
+            pad_length = data_args.cutoff_len - len(packed_input_ids)
+            packed_input_ids += [tokenizer.pad_token_id] * pad_length
+            packed_labels += [IGNORE_INDEX] * pad_length
 
-        total_length = len(input_ids)
-        block_size = data_args.cutoff_len
-        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
-        total_length = (total_length // block_size) * block_size
-        # split by chunks of cutoff_len
-        for i in range(0, total_length, block_size):
-            if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
-                model_inputs["input_ids"].append(input_ids[i : i + block_size])
-                model_inputs["attention_mask"].append([1] * block_size)
-                model_inputs["labels"].append(labels[i : i + block_size])
+        if len(packed_input_ids) != data_args.cutoff_len:
+            raise ValueError("The length of packed example should be identical to the cutoff length.")
+
+        model_inputs["input_ids"].append(packed_input_ids)
+        model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
+        model_inputs["labels"].append(packed_labels)
 
     return model_inputs
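
The core of this change is the `greedy_knapsack` packing strategy: rather than concatenating every example into one stream and slicing it into fixed blocks (which splits conversations across block boundaries), each example is encoded on its own, over-length examples are dropped, and whole examples are binned into packs whose total length fits `cutoff_len`. The sketch below is a minimal standalone demonstration of that strategy; the two helpers are copied from the diff, while the sample lengths and the `capacity` value are invented for illustration.

```python
import bisect
from typing import List, Sequence


def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
    """Index of the largest number that still fits into `capacity`, or -1."""
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else (index - 1)


def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
    """Greedily fill each knapsack with the largest remaining length that fits."""
    numbers.sort()  # ascending order is required for the binary search
    knapsacks = []
    while numbers:
        current_knapsack = []
        remaining_capacity = capacity
        while True:
            index = search_for_fit(numbers, remaining_capacity)
            if index == -1:
                break  # nothing left fits in this knapsack
            remaining_capacity -= numbers[index]
            current_knapsack.append(numbers.pop(index))
        knapsacks.append(current_knapsack)
    return knapsacks


if __name__ == "__main__":
    # Hypothetical sequence lengths and cutoff, chosen only for illustration.
    lengths = [512, 300, 200, 700, 100, 900, 128]
    print(greedy_knapsack(lengths, capacity=1024))
    # [[900, 100], [700, 300], [512, 200, 128]] -- each pack sums to <= 1024.
```

In `preprocess_packed_supervised_dataset`, each resulting pack is then padded to exactly `cutoff_len` with `pad_token_id`, and the padded positions are excluded from the loss by labeling them `IGNORE_INDEX`.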