move efficient_packing from data_args to model_args

ancv 2024-07-02 18:37:55 +07:00
parent e8e6af2651
commit e8e13b0942
8 changed files with 20 additions and 18 deletions
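In short: the `efficient_packing` flag is no longer parsed into `DataArguments`; it moves to `ModelArguments`, so every consumer now reads `model_args.efficient_packing`. A minimal sketch of the before/after access pattern (field lists abbreviated; this snippet is an editor's illustration, not part of the commit):

# Editor's sketch, not part of the commit: shows where the flag lives before and after.
from dataclasses import dataclass
from typing import Optional

@dataclass
class DataArguments:
    packing: Optional[bool] = None            # stays here
    # efficient_packing used to live here (removed by this commit)

@dataclass
class ModelArguments:
    efficient_packing: Optional[bool] = None  # new home for the flag

data_args = DataArguments(packing=False)
model_args = ModelArguments(efficient_packing=True)
use_packed_preprocessing = data_args.packing or model_args.efficient_packing  # dispatch condition used below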

View File

@@ -177,7 +177,7 @@ def get_dataset(
     with training_args.main_process_first(desc="pre-process dataset"):
         preprocess_func, print_function = get_preprocess_and_print_func(
-            data_args, training_args, stage, template, tokenizer, processor
+            data_args, model_args, training_args, stage, template, tokenizer, processor
         )
         column_names = list(next(iter(dataset)).keys())
         kwargs = {}

View File

@@ -29,12 +29,13 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

-    from ..hparams import DataArguments
+    from ..hparams import DataArguments, ModelArguments
     from .template import Template


 def get_preprocess_and_print_func(
     data_args: "DataArguments",
+    model_args: "ModelArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     template: "Template",
@@ -49,7 +50,7 @@ def get_preprocess_and_print_func(
         )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
-        if data_args.packing or data_args.efficient_packing:
+        if data_args.packing or model_args.efficient_packing:
             preprocess_func = partial(
                 preprocess_packed_supervised_dataset,
                 template=template,

View File

@@ -23,7 +23,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, gre

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin

-    from ...hparams import DataArguments
+    from ...hparams import DataArguments, ModelArguments
     from ..template import Template

@@ -125,6 +125,7 @@ def preprocess_packed_supervised_dataset(
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     data_args: "DataArguments",
+    model_args: "ModelArguments"
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
@@ -176,7 +177,7 @@ def preprocess_packed_supervised_dataset(
             raise ValueError("The length of packed example should be identical to the cutoff length.")

         model_inputs["input_ids"].append(packed_input_ids)
-        if data_args.efficient_packing:
+        if model_args.efficient_packing:
             model_inputs["attention_mask"].append(packed_attention_mask)
         else:
             model_inputs["attention_mask"].append([1] * data_args.cutoff_len)

View File

@@ -97,12 +97,6 @@ class DataArguments:
             "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
         },
     )
-    efficient_packing: Optional[bool] = field(
-        default=None,
-        metadata={
-            "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
-        },
-    )
     tool_format: Optional[str] = field(
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},

View File

@@ -109,6 +109,12 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
     )
+    efficient_packing: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
+        },
+    )
     mixture_of_depths: Optional[Literal["convert", "load"]] = field(
         default=None,
         metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},

View File

@@ -170,6 +170,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")

+    if finetuning_args.stage != "sft" and model_args.efficient_packing:
+        raise ValueError("`efficient_packing` cannot be set as True except SFT.")
+
     if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
         raise ValueError("Unsloth does not support lora reward model.")

View File

@@ -31,7 +31,7 @@ from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead

 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

-    from ..hparams import FinetuningArguments, ModelArguments, DataArguments
+    from ..hparams import FinetuningArguments, ModelArguments


 logger = get_logger(__name__)

@@ -120,7 +120,6 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
 def load_model(
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    data_args: "DataArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool = False,
     add_valuehead: bool = False,
@@ -130,7 +129,7 @@ def load_model(
     """
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
-    patch_config(config, tokenizer, model_args, data_args, finetuning_args, init_kwargs, is_trainable)
+    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)

     model = None
     lazy_load = False
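Since `load_model` no longer takes `data_args`, call sites shrink by one argument. A hedged before/after of a typical call (the keyword arguments shown are illustrative):

# Editor's sketch, not part of the commit.
# before this commit:
# model = load_model(tokenizer, model_args, data_args, finetuning_args, is_trainable=True)
# after this commit:
# model = load_model(tokenizer, model_args, finetuning_args, is_trainable=True)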

View File

@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import ModelArguments, DataArguments, FinetuningArguments
+    from ..hparams import ModelArguments


 logger = get_logger(__name__)

@@ -54,8 +54,6 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    data_args: "DataArguments",
-    finetune_args: "FinetuningArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
@@ -104,7 +102,7 @@ def patch_config(
     if init_kwargs.get("device_map", None) == "auto":
         init_kwargs["offload_folder"] = model_args.offload_folder

-    if finetune_args.stage == "sft" and data_args.efficient_packing:
+    if model_args.efficient_packing:
         configure_packing(config, model_args)
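`configure_packing(config, model_args)` is now gated only on `model_args.efficient_packing` rather than on the SFT stage plus a data argument. Its body is not part of this diff; the sketch below only illustrates the kind of helper such a patch needs, namely recovering per-sequence lengths from the segment-id mask built during preprocessing (all names are hypothetical):

# Editor's sketch, not part of the commit; `configure_packing` itself is not shown here.
import torch

def seqlens_from_segment_mask(segment_ids: torch.Tensor) -> torch.Tensor:
    # segment_ids: 1-D per-token segment ids for ONE packed example, e.g. [1, 1, 1, 2, 2, 0];
    # a hypothetical padding id of 0 is ignored.
    valid = segment_ids[segment_ids > 0]
    counts = torch.bincount(valid)
    return counts[counts > 0]

print(seqlens_from_segment_mask(torch.tensor([1, 1, 1, 2, 2, 0])))  # tensor([3, 2])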