Update supervised.py

hoshi-hiyouga 2024-06-07 03:38:04 +08:00 committed by GitHub
parent b47e317447
commit 8cecade708
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 107 additions and 126 deletions


@@ -1,10 +1,10 @@
-import itertools
+import bisect
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
+from .mm_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack

 if TYPE_CHECKING:
@@ -18,29 +18,19 @@ if TYPE_CHECKING:

 logger = get_logger(__name__)


-def binary_search_for_fit(numbers, capacity):
-    """
-    Perform binary search to find the largest number that fits into the knapsack with the given capacity.
-    """
-    left, right = 0, len(numbers) - 1
-    result = -1  # If no number fits, return -1
-    while left <= right:
-        mid = (left + right) // 2
-        if numbers[mid] <= capacity:
-            result = mid
-            left = mid + 1
-        else:
-            right = mid - 1
-    return result
+def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
+    r"""
+    Finds the index of largest number that fits into the knapsack with the given capacity.
+    """
+    index = bisect.bisect(numbers, capacity)
+    return -1 if index == 0 else (index - 1)


-def efficient_greedy_knapsack(numbers, capacity):
-    """
-    An efficient greedy algorithm with binary search for the knapsack problem.
-    """
-    numbers.sort()  # Sort numbers in ascending order for binary search
+def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
+    r"""
+    An efficient greedy algorithm with binary search for the knapsack problem.
+    """
+    numbers.sort()  # sort numbers in ascending order for binary search
     knapsacks = []

     while numbers:
@@ -48,22 +38,60 @@ def efficient_greedy_knapsack(numbers, capacity):
         remaining_capacity = capacity

         while True:
-            index = binary_search_for_fit(numbers, remaining_capacity)
+            index = search_for_fit(numbers, remaining_capacity)
             if index == -1:
-                break  # No more numbers fit in this knapsack
-
-            # Add the found number to the knapsack and update the remaining capacity
-            current_knapsack.append(numbers[index])
-            remaining_capacity -= numbers[index]
-
-            # Remove the number from the list
-            numbers.pop(index)
+                break  # no more numbers fit in this knapsack
+
+            remaining_capacity -= numbers[index]  # update the remaining capacity
+            current_knapsack.append(numbers.pop(index))  # add the number to knapsack

         knapsacks.append(current_knapsack)

     return knapsacks
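The two knapsack helpers above are self-contained, so their behavior is easy to check outside the training pipeline. A minimal standalone sketch (the sample lengths and the capacity of 8 are made up for illustration):

```python
import bisect
from typing import List, Sequence


def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
    # numbers must be sorted ascending; return the index of the largest
    # value that still fits into `capacity`, or -1 if none fits
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else (index - 1)


def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
    numbers.sort()  # keep the list sorted so bisect stays valid
    knapsacks = []
    while numbers:
        current_knapsack = []
        remaining_capacity = capacity
        while True:
            index = search_for_fit(numbers, remaining_capacity)
            if index == -1:
                break  # nothing left that fits into this knapsack
            remaining_capacity -= numbers[index]
            current_knapsack.append(numbers.pop(index))
        knapsacks.append(current_knapsack)
    return knapsacks


# hypothetical sequence lengths with a cutoff of 8
print(greedy_knapsack([3, 5, 2, 7, 1], 8))  # [[7, 1], [5, 3], [2]]
```

Because popping from a sorted list keeps it sorted, every fit lookup remains a plain binary search, which is what makes this rewrite cheaper than scanning all candidates per slot.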
+def _encode_supervised_example(
+    prompt: Sequence[Dict[str, str]],
+    response: Sequence[Dict[str, str]],
+    system: Optional[str],
+    tools: Optional[str],
+    template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
+    data_args: "DataArguments",
+) -> Tuple[List[int], List[int]]:
+    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
+        prompt[0]["content"] = template.image_token + prompt[0]["content"]
+
+    messages = prompt + response
+    input_ids, labels = [], []
+
+    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
+        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
+        input_ids += [image_token_id] * getattr(processor, "image_seq_length")
+        labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
+
+    encoded_pairs = template.encode_multiturn(
+        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
+    )
+    for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
+        if data_args.train_on_prompt:
+            source_mask = source_ids
+        elif turn_idx != 0 and template.efficient_eos:
+            source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
+        else:
+            source_mask = [IGNORE_INDEX] * len(source_ids)
+
+        input_ids += source_ids + target_ids
+        labels += source_mask + target_ids
+
+    if template.efficient_eos:
+        input_ids += [tokenizer.eos_token_id]
+        labels += [tokenizer.eos_token_id]
+
+    return input_ids, labels
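The masking in `_encode_supervised_example` is the usual supervised fine-tuning scheme: prompt tokens contribute to the inputs but are hidden from the loss via `IGNORE_INDEX`, while response tokens appear in both streams. A toy illustration with made-up token ids (`IGNORE_INDEX` is -100 in the library's constants, matching the default ignore index of PyTorch's cross-entropy):

```python
IGNORE_INDEX = -100  # sentinel the loss function skips

# one (source_ids, target_ids) turn with made-up token ids
source_ids = [101, 7592, 2088]  # prompt tokens
target_ids = [3183, 2003, 102]  # response tokens

input_ids = source_ids + target_ids
labels = [IGNORE_INDEX] * len(source_ids) + target_ids

print(input_ids)  # [101, 7592, 2088, 3183, 2003, 102]
print(labels)     # [-100, -100, -100, 3183, 2003, 102]
```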

 def preprocess_supervised_dataset(
     examples: Dict[str, List[Any]],
     template: "Template",
@@ -84,41 +112,16 @@ def preprocess_supervised_dataset(
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue

-        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-            examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
-
-        messages = examples["prompt"][i] + examples["response"][i]
-        input_ids, labels = [], []
-
-        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-            image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-            input_ids += [image_token_id] * getattr(processor, "image_seq_length")
-            labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
-
-        for turn_idx, (source_ids, target_ids) in enumerate(
-            template.encode_multiturn(
-                tokenizer,
-                messages,
-                examples["system"][i],
-                examples["tools"][i],
-                data_args.cutoff_len,
-                data_args.reserved_label_len,
-            )
-        ):
-            if data_args.train_on_prompt:
-                source_mask = source_ids
-            elif turn_idx != 0 and template.efficient_eos:
-                source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
-            else:
-                source_mask = [IGNORE_INDEX] * len(source_ids)
-
-            input_ids += source_ids + target_ids
-            labels += source_mask + target_ids
-
-        if template.efficient_eos:
-            input_ids += [tokenizer.eos_token_id]
-            labels += [tokenizer.eos_token_id]
-
+        input_ids, labels = _encode_supervised_example(
+            prompt=examples["prompt"][i],
+            response=examples["response"][i],
+            system=examples["system"][i],
+            tools=examples["tools"][i],
+            template=template,
+            tokenizer=tokenizer,
+            processor=processor,
+            data_args=data_args,
+        )
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
@@ -138,76 +141,54 @@ def preprocess_packed_supervised_dataset(
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
-    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-    input_ids, labels = [], []
+    valid_num = 0
+    batch_input_ids, batch_labels = [], []
+    lengths = []
+    length2indexes = defaultdict(list)
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue

-        messages = examples["prompt"][i] + examples["response"][i]
-        for source_ids, target_ids in template.encode_multiturn(
-            tokenizer, messages, examples["system"][i], examples["tools"][i]
-        ):
-            if data_args.train_on_prompt:
-                source_mask = source_ids
-            else:
-                source_mask = [IGNORE_INDEX] * len(source_ids)
-
-            input_ids.append(source_ids + target_ids)
-            labels.append(source_mask + target_ids)
-
-    # prepare for packing
-    lengths = []
-    length2examples_idx = defaultdict(list)
-    for idx, example in enumerate(input_ids):
-        length = len(example)
-        if length > data_args.cutoff_len:
-            logger.warning("Dropped example with length {} > cutoff_len {}".format(length, data_args.cutoff_len))
-            continue
-
-        lengths.append(length)
-        length2examples_idx[length].append(idx)
-
-    # cutoff_len - 1 for efficient_eos
-    knapsacks = efficient_greedy_knapsack(lengths, data_args.cutoff_len - int(template.efficient_eos))
-
-    for knapsack in knapsacks:
-        packed_input_ids = []
-        packed_labels = []
-        total_length = 0
-        for length in knapsack:
-            total_length += length
-            idx = length2examples_idx[length].pop()
-            packed_input_ids.append(input_ids[idx])
-            packed_labels.append(labels[idx])
-
-        # padding to cutoff_len
-        if total_length < data_args.cutoff_len:
-            pad_length = data_args.cutoff_len - total_length
-            if template.efficient_eos:
-                # make sure there is an eos token
-                packed_input_ids.append([tokenizer.eos_token_id] * pad_length)
-                packed_labels.append([tokenizer.eos_token_id] + [IGNORE_INDEX] * (pad_length - 1))
-            else:
-                # without an eos token, pad with 0?
-                packed_input_ids.append([0] * pad_length)
-                packed_labels.append([tokenizer.eos_token_id] + [IGNORE_INDEX] * (pad_length - 1))
-        elif total_length == data_args.cutoff_len:
-            pad_length = 0
-        else:
-            logger.warning(
-                "Dropped packed example with total length {} > cutoff_len {}".format(
-                    total_length, data_args.cutoff_len
-                )
-            )
-            continue
-
-        # concat all
-        model_inputs["input_ids"].append(list(itertools.chain(*packed_input_ids)))
-        model_inputs["labels"].append(list(itertools.chain(*packed_labels)))
-        model_inputs["attention_mask"].append([1] * total_length + [0] * pad_length)
+        input_ids, labels = _encode_supervised_example(
+            prompt=examples["prompt"][i],
+            response=examples["response"][i],
+            system=examples["system"][i],
+            tools=examples["tools"][i],
+            template=template,
+            tokenizer=tokenizer,
+            processor=None,
+            data_args=data_args,
+        )
+        length = len(input_ids)
+        if length > data_args.cutoff_len:
+            logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
+            continue
+
+        lengths.append(length)
+        length2indexes[length].append(valid_num)
+        batch_input_ids.append(input_ids)
+        batch_labels.append(labels)
+        valid_num += 1
+
+    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
+    for knapsack in knapsacks:
+        packed_input_ids, packed_labels = [], []
+        for length in knapsack:
+            index = length2indexes[length].pop()
+            packed_input_ids += batch_input_ids[index]
+            packed_labels += batch_labels[index]
+
+        if len(packed_input_ids) <= data_args.cutoff_len:
+            pad_length = data_args.cutoff_len - len(packed_input_ids)
+            packed_input_ids += [tokenizer.pad_token_id] * pad_length
+            packed_labels += [IGNORE_INDEX] * pad_length
+        else:
+            raise ValueError("The length of packed example exceeds the cutoff length.")
+
+        model_inputs["input_ids"].append(packed_input_ids)
+        model_inputs["attention_mask"].append([1] * len(packed_input_ids))
+        model_inputs["labels"].append(packed_labels)

     return model_inputs
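Tying the pieces together: the new packed path encodes every example once, buckets the results by length, lets `greedy_knapsack` group lengths under `cutoff_len`, and pads each pack to exactly `cutoff_len`. A rough end-to-end sketch reusing `greedy_knapsack` from the earlier snippet; the encoded examples and `pad_token_id = 0` are made up, and in practice the values come from `_encode_supervised_example` and the tokenizer:

```python
from collections import defaultdict

IGNORE_INDEX = -100
cutoff_len = 8
pad_token_id = 0  # assumption for the sketch; use tokenizer.pad_token_id in practice

# made-up (input_ids, labels) pairs standing in for _encode_supervised_example output
batch = [
    ([11, 12, 13], [-100, 12, 13]),
    ([21, 22], [-100, 22]),
    ([31, 32, 33, 34, 35], [-100, -100, 33, 34, 35]),
]

lengths = []
length2indexes = defaultdict(list)
for index, (ids, _) in enumerate(batch):
    lengths.append(len(ids))
    length2indexes[len(ids)].append(index)

for knapsack in greedy_knapsack(lengths, cutoff_len):  # [[5, 3], [2]]
    packed_input_ids, packed_labels = [], []
    for length in knapsack:
        index = length2indexes[length].pop()  # map a length back to an example
        packed_input_ids += batch[index][0]
        packed_labels += batch[index][1]

    pad_length = cutoff_len - len(packed_input_ids)  # pad each pack to cutoff_len
    packed_input_ids += [pad_token_id] * pad_length
    packed_labels += [IGNORE_INDEX] * pad_length
    print(packed_input_ids, packed_labels)
```

Padding with `pad_token_id` and `IGNORE_INDEX` (rather than the old eos/zero mix) keeps the padded positions out of the loss while producing fixed-width rows.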