update hparams

hiyouga 2024-07-03 23:18:58 +08:00
parent 7f770f6895
commit 575a02a23d
8 changed files with 72 additions and 28 deletions

View File

@@ -1,4 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
+#
+# This code is inspired by the OpenAccess AI Collective's axolotl library.
+# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,6 +22,44 @@ import torch
 from transformers import DataCollatorForSeq2Seq


+def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
+    r"""
+    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
+    while handling packed sequences and transforming the mask into lower-triangular form to prevent future peeking.
+
+    e.g.
+    ```
+    [1, 1, 2, 2, 2, 0]
+    ```
+    ->
+    ```
+    [[
+        [
+            [o, x, x, x, x, x],
+            [o, o, x, x, x, x],
+            [x, x, o, x, x, x],
+            [x, x, o, o, x, x],
+            [x, x, o, o, o, x],
+            [x, x, x, x, x, x],
+        ]
+    ]]
+    ```
+    where `o` equals `0.0` and `x` equals `min_dtype`.
+    """
+    bsz, seq_len = attention_mask_with_indices.size()
+    min_dtype = torch.finfo(dtype).min
+    expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
+    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one.
+    padding_mask = torch.where(expanded_mask != 0, 1, 0)
+    # Create a block-diagonal mask so that tokens only attend within their own packed sequence.
+    attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
+    # Use the lower-triangular mask to zero out the upper-triangular part (no future peeking).
+    attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
+    # Invert the mask: allowed positions become 0.0, masked positions become min_dtype.
+    attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
+    return attention_mask_4d
+
+
 @dataclass
 class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
     r"""

View File

@@ -177,7 +177,7 @@ def get_dataset(
     with training_args.main_process_first(desc="pre-process dataset"):
         preprocess_func, print_function = get_preprocess_and_print_func(
-            data_args, model_args, training_args, stage, template, tokenizer, processor
+            data_args, training_args, stage, template, tokenizer, processor
         )
         column_names = list(next(iter(dataset)).keys())
         kwargs = {}

View File

@@ -29,13 +29,12 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

-    from ..hparams import DataArguments, ModelArguments
+    from ..hparams import DataArguments
     from .template import Template


 def get_preprocess_and_print_func(
     data_args: "DataArguments",
-    model_args: "ModelArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     template: "Template",
@@ -50,7 +49,7 @@ def get_preprocess_and_print_func(
         )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
-        if data_args.packing or model_args.efficient_packing:
+        if data_args.packing:
             preprocess_func = partial(
                 preprocess_packed_supervised_dataset,
                 template=template,

View File

@@ -23,7 +23,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin

-    from ...hparams import DataArguments, ModelArguments
+    from ...hparams import DataArguments
     from ..template import Template
@@ -125,7 +125,6 @@ def preprocess_packed_supervised_dataset(
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     data_args: "DataArguments",
-    model_args: "ModelArguments",
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
@@ -161,26 +160,30 @@ def preprocess_packed_supervised_dataset(
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
     knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
     for knapsack in knapsacks:
-        packed_input_ids, packed_attention_mask, packed_labels = [], [], []
+        packed_input_ids, packed_attention_masks, packed_labels = [], [], []
         for i, length in enumerate(knapsack):
             index = length2indexes[length].pop()
             packed_input_ids += batch_input_ids[index]
             packed_labels += batch_labels[index]
-            packed_attention_mask += [i + 1] * len(batch_input_ids[index])
+            if data_args.neat_packing:
+                packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
+            else:
+                packed_attention_masks += [1] * len(batch_input_ids[index])

         if len(packed_input_ids) < data_args.cutoff_len:
             pad_length = data_args.cutoff_len - len(packed_input_ids)
             packed_input_ids += [tokenizer.pad_token_id] * pad_length
             packed_labels += [IGNORE_INDEX] * pad_length
+            if data_args.neat_packing:
+                packed_attention_masks += [0] * pad_length
+            else:
+                packed_attention_masks += [1] * pad_length  # more efficient flash_attn

         if len(packed_input_ids) != data_args.cutoff_len:
             raise ValueError("The length of packed example should be identical to the cutoff length.")

         model_inputs["input_ids"].append(packed_input_ids)
-        if model_args.efficient_packing:
-            model_inputs["attention_mask"].append(packed_attention_mask)
-        else:
-            model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
+        model_inputs["attention_mask"].append(packed_attention_masks)
         model_inputs["labels"].append(packed_labels)

     return model_inputs
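To make the two branches above concrete, here is a hypothetical standalone sketch (not part of the commit) of the attention mask built for one knapsack; the sequence lengths and `cutoff_len` are made-up values:

```
def build_packed_attention_mask(seq_lens, cutoff_len, neat_packing):
    """Mirror of the mask-building branches above for one knapsack (illustrative only)."""
    mask = []
    for i, length in enumerate(seq_lens):
        # neat packing tags each sequence with its own index, starting from 1
        mask += ([i + 1] if neat_packing else [1]) * length
    pad_length = cutoff_len - len(mask)
    # neat packing marks padding with 0; plain packing keeps 1 (more efficient for flash_attn)
    mask += ([0] if neat_packing else [1]) * pad_length
    return mask

print(build_packed_attention_mask([3, 2], 7, neat_packing=True))   # [1, 1, 1, 2, 2, 0, 0]
print(build_packed_attention_mask([3, 2], 7, neat_packing=False))  # [1, 1, 1, 1, 1, 1, 1]
```

With `neat_packing`, each sequence keeps its own index and padding is marked with 0, which is exactly the indexed format the new `prepare_4d_attention_mask` helper expects; without it, the mask stays all ones.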

View File

@@ -83,9 +83,7 @@ class DataArguments:
     )
     ignore_pad_token_for_loss: bool = field(
         default=True,
-        metadata={
-            "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
-        },
+        metadata={"help": "Whether or not to ignore the tokens corresponding to the pad tokens in loss computation."},
     )
     val_size: float = field(
         default=0.0,
@@ -93,9 +91,11 @@ class DataArguments:
     )
     packing: Optional[bool] = field(
         default=None,
-        metadata={
-            "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
-        },
+        metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
+    )
+    neat_packing: bool = field(
+        default=False,
+        metadata={"help": "Enable sequence packing without cross-attention."},
     )
     tool_format: Optional[str] = field(
         default=None,
@@ -112,3 +112,6 @@ class DataArguments:
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
+
+        if self.neat_packing and not self.packing:
+            raise ValueError("`neat_packing` requires `packing` is True.")

View File

@@ -109,12 +109,6 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
     )
-    efficient_packing: Optional[bool] = field(
-        default=None,
-        metadata={
-            "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
-        },
-    )
     mixture_of_depths: Optional[Literal["convert", "load"]] = field(
         default=None,
         metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
@@ -232,6 +226,7 @@ class ModelArguments:
         self.compute_dtype: Optional["torch.dtype"] = None
         self.device_map: Optional[Union[str, Dict[str, Any]]] = None
         self.model_max_length: Optional[int] = None
+        self.block_diag_attn: bool = False

         if self.split_special_tokens and self.use_fast_tokenizer:
             raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
@@ -259,4 +254,5 @@ class ModelArguments:
         new_arg.compute_dtype = old_arg.compute_dtype
         new_arg.device_map = old_arg.device_map
         new_arg.model_max_length = old_arg.model_max_length
+        new_arg.block_diag_attn = old_arg.block_diag_attn
         return new_arg

View File

@@ -158,6 +158,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage != "sft" and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True except SFT.")

+    if finetuning_args.stage != "sft" and data_args.neat_packing:
+        raise ValueError("`neat_packing` cannot be set as True except SFT.")
+
     if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
@@ -170,9 +173,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")

-    if finetuning_args.stage != "sft" and model_args.efficient_packing:
-        raise ValueError("`efficient_packing` cannot be set as True except SFT.")
-
     if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
         raise ValueError("Unsloth does not support lora reward model.")
@@ -314,6 +314,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     model_args.device_map = {"": get_current_device()}
     model_args.model_max_length = data_args.cutoff_len
+    model_args.block_diag_attn = data_args.neat_packing
     data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"

     # Log on each process the small summary

View File

@@ -28,6 +28,7 @@ from ..trainer_utils import create_modelcard_and_push
 from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor
 from .trainer import CustomSeq2SeqTrainer


 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback