update hparams
commit 575a02a23d
parent 7f770f6895
@@ -1,4 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
 #
+# This code is inspired by the OpenAccess AI Collective's axolotl library.
+# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,6 +22,44 @@ import torch
 from transformers import DataCollatorForSeq2Seq
 
 
+def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
+    r"""
+    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
+    while handling packed sequences and transforming the mask into lower-triangular form to prevent future peeking.
+
+    e.g.
+    ```
+    [1, 1, 2, 2, 2, 0]
+    ```
+    ->
+    ```
+    [[
+        [
+            [o, x, x, x, x, x],
+            [o, o, x, x, x, x],
+            [x, x, o, x, x, x],
+            [x, x, o, o, x, x],
+            [x, x, o, o, o, x],
+            [x, x, x, x, x, x],
+        ]
+    ]]
+    ```
+    where `o` equals `0.0` and `x` equals `min_dtype`.
+    """
+    bsz, seq_len = attention_mask_with_indices.size()
+    min_dtype = torch.finfo(dtype).min
+    expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
+    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
+    padding_mask = torch.where(expanded_mask != 0, 1, 0)
+    # Create a block-diagonal mask.
+    attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
+    # Use the lower triangular mask to zero out the upper triangular part
+    attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
+    # Invert the attention mask.
+    attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
+    return attention_mask_4d
+
+
 @dataclass
 class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
     r"""
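A quick sanity check of what the new helper returns, using the docstring's own example input. This is a hedged sketch: the import path is assumed, since the diff view omits file names.

```python
import torch

from llamafactory.data.collator import prepare_4d_attention_mask  # assumed module path

mask_with_indices = torch.tensor([[1, 1, 2, 2, 2, 0]])  # two packed sequences plus one pad token
mask_4d = prepare_4d_attention_mask(mask_with_indices, torch.float32)

print(mask_4d.shape)  # torch.Size([1, 1, 6, 6])
print((mask_4d[0, 0] == 0).int())  # 1 marks the `o` (attendable) positions from the docstring
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 0, 0, 0]])
```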
@@ -177,7 +177,7 @@ def get_dataset(
 
     with training_args.main_process_first(desc="pre-process dataset"):
         preprocess_func, print_function = get_preprocess_and_print_func(
-            data_args, model_args, training_args, stage, template, tokenizer, processor
+            data_args, training_args, stage, template, tokenizer, processor
         )
         column_names = list(next(iter(dataset)).keys())
         kwargs = {}
@@ -29,13 +29,12 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
 
-    from ..hparams import DataArguments, ModelArguments
+    from ..hparams import DataArguments
     from .template import Template
 
 
 def get_preprocess_and_print_func(
     data_args: "DataArguments",
-    model_args: "ModelArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     template: "Template",
@@ -50,7 +49,7 @@ def get_preprocess_and_print_func(
         )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
-        if data_args.packing or model_args.efficient_packing:
+        if data_args.packing:
             preprocess_func = partial(
                 preprocess_packed_supervised_dataset,
                 template=template,
@@ -23,7 +23,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin
 
-    from ...hparams import DataArguments, ModelArguments
+    from ...hparams import DataArguments
     from ..template import Template
 
 
@@ -125,7 +125,6 @@ def preprocess_packed_supervised_dataset(
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     data_args: "DataArguments",
-    model_args: "ModelArguments"
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
@@ -161,26 +160,30 @@ def preprocess_packed_supervised_dataset(
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
     knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
     for knapsack in knapsacks:
-        packed_input_ids, packed_attention_mask, packed_labels = [], [], []
+        packed_input_ids, packed_attention_masks, packed_labels = [], [], []
         for i, length in enumerate(knapsack):
             index = length2indexes[length].pop()
             packed_input_ids += batch_input_ids[index]
             packed_labels += batch_labels[index]
-            packed_attention_mask += [i+1]*len(batch_input_ids[index])
+            if data_args.neat_packing:
+                packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
+            else:
+                packed_attention_masks += [1] * len(batch_input_ids[index])
 
         if len(packed_input_ids) < data_args.cutoff_len:
             pad_length = data_args.cutoff_len - len(packed_input_ids)
             packed_input_ids += [tokenizer.pad_token_id] * pad_length
             packed_labels += [IGNORE_INDEX] * pad_length
+            if data_args.neat_packing:
+                packed_attention_masks += [0] * pad_length
+            else:
+                packed_attention_masks += [1] * pad_length  # more efficient flash_attn
 
         if len(packed_input_ids) != data_args.cutoff_len:
             raise ValueError("The length of packed example should be identical to the cutoff length.")
 
         model_inputs["input_ids"].append(packed_input_ids)
-        if model_args.efficient_packing:
-            model_inputs["attention_mask"].append(packed_attention_mask)
-        else:
-            model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
+        model_inputs["attention_mask"].append(packed_attention_masks)
         model_inputs["labels"].append(packed_labels)
 
     return model_inputs
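For reference, a minimal stand-alone sketch of the two mask variants the rewritten loop can now produce; the token ids and cutoff length below are made-up toy values, not real data.

```python
# neat_packing tags each packed sequence with its own index; plain packing keeps an all-ones mask.
batch_input_ids = [[101, 102], [201, 202, 203]]  # two already-tokenized examples
cutoff_len = 6

neat_mask, plain_mask = [], []
for i, ids in enumerate(batch_input_ids):
    neat_mask += [i + 1] * len(ids)   # start from 1, one index per sequence
    plain_mask += [1] * len(ids)

pad_length = cutoff_len - len(neat_mask)
neat_mask += [0] * pad_length         # padding stays 0 so it is always masked out
plain_mask += [1] * pad_length        # kept as 1 for more efficient flash_attn

print(neat_mask)   # [1, 1, 2, 2, 2, 0]
print(plain_mask)  # [1, 1, 1, 1, 1, 1]
```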
@@ -83,9 +83,7 @@ class DataArguments:
     )
     ignore_pad_token_for_loss: bool = field(
         default=True,
-        metadata={
-            "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
-        },
+        metadata={"help": "Whether or not to ignore the tokens corresponding to the pad tokens in loss computation."},
     )
     val_size: float = field(
         default=0.0,
@@ -93,9 +91,11 @@ class DataArguments:
     )
     packing: Optional[bool] = field(
         default=None,
-        metadata={
-            "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
-        },
+        metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
+    )
+    neat_packing: bool = field(
+        default=False,
+        metadata={"help": "Enable sequence packing without cross-attention."},
     )
     tool_format: Optional[str] = field(
         default=None,
@@ -112,3 +112,6 @@ class DataArguments:
 
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
+
+        if self.neat_packing and not self.packing:
+            raise ValueError("`neat_packing` requires `packing` is True.")
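A hedged usage sketch of the new flags and the validation added above. The flag names come from this diff; the absolute import path mirrors the relative `from ..hparams import DataArguments` shown earlier and is otherwise an assumption, as is the premise that the remaining fields keep their defaults.

```python
from llamafactory.hparams import DataArguments  # assumed import path

ok = DataArguments(packing=True, neat_packing=True)   # contamination-free packing enabled

try:
    DataArguments(neat_packing=True)                  # `packing` left unset
except ValueError as err:
    print(err)  # `neat_packing` requires `packing` is True.
```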
@@ -109,12 +109,6 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
     )
-    efficient_packing: Optional[bool] = field(
-        default=None,
-        metadata={
-            "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
-        },
-    )
     mixture_of_depths: Optional[Literal["convert", "load"]] = field(
         default=None,
         metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
@@ -232,6 +226,7 @@ class ModelArguments:
         self.compute_dtype: Optional["torch.dtype"] = None
         self.device_map: Optional[Union[str, Dict[str, Any]]] = None
         self.model_max_length: Optional[int] = None
+        self.block_diag_attn: bool = False
 
         if self.split_special_tokens and self.use_fast_tokenizer:
             raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
@@ -259,4 +254,5 @@ class ModelArguments:
         new_arg.compute_dtype = old_arg.compute_dtype
         new_arg.device_map = old_arg.device_map
         new_arg.model_max_length = old_arg.model_max_length
+        new_arg.block_diag_attn = old_arg.block_diag_attn
         return new_arg
@@ -158,6 +158,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage != "sft" and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
 
+    if finetuning_args.stage != "sft" and data_args.neat_packing:
+        raise ValueError("`neat_packing` cannot be set as True except SFT.")
+
     if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
@@ -170,9 +173,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")
 
-    if finetuning_args.stage != "sft" and model_args.efficient_packing:
-        raise ValueError("`efficient_packing` cannot be set as True except SFT.")
-
     if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
         raise ValueError("Unsloth does not support lora reward model.")
 
@@ -314,6 +314,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 
     model_args.device_map = {"": get_current_device()}
     model_args.model_max_length = data_args.cutoff_len
+    model_args.block_diag_attn = data_args.neat_packing
     data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
 
     # Log on each process the small summary
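A minimal stand-alone sketch of how the two assignments in this hunk behave, using stand-in variables instead of the real parsed dataclasses; `stage` is an example value, not part of the diff.

```python
stage = "pt"
neat_packing = False          # data_args.neat_packing
packing = None                # data_args.packing, left unset by the user

block_diag_attn = neat_packing                               # model_args.block_diag_attn mirrors the data flag
packing = packing if packing is not None else stage == "pt"  # packing defaults on only for pre-training

print(block_diag_attn, packing)  # False True
```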
@@ -28,6 +28,7 @@ from ..trainer_utils import create_modelcard_and_push
 from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor
 from .trainer import CustomSeq2SeqTrainer
 
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback