supervised packing with greedy knapsack algorithm

This commit is contained in:
ylfeng 2024-05-31 15:33:54 +08:00
parent c4f50865ad
commit f9db439cb7
1 changed files with 92 additions and 10 deletions

View File

@ -1,3 +1,5 @@
import itertools
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
@ -16,6 +18,52 @@ if TYPE_CHECKING:
logger = get_logger(__name__) logger = get_logger(__name__)
def binary_search_for_fit(numbers, capacity):
"""
Perform binary search to find the largest number that fits into the knapsack with the given capacity.
"""
left, right = 0, len(numbers) - 1
result = -1 # If no number fits, return -1
while left <= right:
mid = (left + right) // 2
if numbers[mid] <= capacity:
result = mid
left = mid + 1
else:
right = mid - 1
return result
def efficient_greedy_knapsack(numbers, capacity):
"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
numbers.sort() # Sort numbers in ascending order for binary search
knapsacks = []
while numbers:
current_knapsack = []
remaining_capacity = capacity
while True:
index = binary_search_for_fit(numbers, remaining_capacity)
if index == -1:
break # No more numbers fit in this knapsack
# Add the found number to the knapsack and update the remaining capacity
current_knapsack.append(numbers[index])
remaining_capacity -= numbers[index]
# Remove the number from the list
numbers.pop(index)
knapsacks.append(current_knapsack)
return knapsacks
def preprocess_supervised_dataset( def preprocess_supervised_dataset(
examples: Dict[str, List[Any]], examples: Dict[str, List[Any]],
template: "Template", template: "Template",
@ -115,16 +163,50 @@ def preprocess_packed_supervised_dataset(
input_ids += [tokenizer.eos_token_id] input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id] labels += [tokenizer.eos_token_id]
total_length = len(input_ids) # prepare for packing
block_size = data_args.cutoff_len lengths = []
# we drop the small remainder, and if the total_length < block_size, we exclude this batch length2examples_idx = defaultdict(list)
total_length = (total_length // block_size) * block_size for idx, example in enumerate(input_ids):
# split by chunks of cutoff_len length = len(example)
for i in range(0, total_length, block_size): if length > data_args.cutoff_len:
if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]): logger.warning("Dropped example with length {} > cutoff_len {}".format(length, data_args.cutoff_len))
model_inputs["input_ids"].append(input_ids[i : i + block_size]) continue
model_inputs["attention_mask"].append([1] * block_size) lengths.append(length)
model_inputs["labels"].append(labels[i : i + block_size]) length2examples_idx[length].append(idx)
knapsacks = efficient_greedy_knapsack(lengths, data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids = []
packed_labels = []
total_length = 0
for length in knapsack:
total_length += length
idx = length2examples_idx[length].pop()
packed_input_ids.append(input_ids[idx])
packed_labels.append(labels[idx])
# padding to cutoff_len
if total_length < data_args.cutoff_len:
pad_length = data_args.cutoff_len - total_length
packed_input_ids.append([tokenizer.eos_token_id] * pad_length)
packed_labels.append([IGNORE_INDEX] * pad_length)
elif total_length == data_args.cutoff_len:
pad_length = 0
else:
logger.warning(
"Dropped packed example with total length {} > cutoff_len {}".format(
total_length, data_args.cutoff_len
)
)
continue
# concat all
model_inputs["input_ids"].append(list(itertools.chain(*packed_input_ids)))
model_inputs["labels"].append(list(itertools.chain(*packed_labels)))
model_inputs["attention_mask"].append([1] * total_length + [0] * pad_length)
return model_inputs return model_inputs