move efficient_packing from data_args to model_args
parent e8e6af2651 · commit e8e13b0942
@@ -177,7 +177,7 @@ def get_dataset(
     with training_args.main_process_first(desc="pre-process dataset"):
         preprocess_func, print_function = get_preprocess_and_print_func(
-            data_args, training_args, stage, template, tokenizer, processor
+            data_args, model_args, training_args, stage, template, tokenizer, processor
         )
         column_names = list(next(iter(dataset)).keys())
         kwargs = {}
@@ -29,12 +29,13 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
 
-    from ..hparams import DataArguments
+    from ..hparams import DataArguments, ModelArguments
     from .template import Template
 
 
 def get_preprocess_and_print_func(
     data_args: "DataArguments",
+    model_args: "ModelArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     template: "Template",
@@ -49,7 +50,7 @@ def get_preprocess_and_print_func(
         )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
-        if data_args.packing or data_args.efficient_packing:
+        if data_args.packing or model_args.efficient_packing:
             preprocess_func = partial(
                 preprocess_packed_supervised_dataset,
                 template=template,
@@ -23,7 +23,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, gre
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin
 
-    from ...hparams import DataArguments
+    from ...hparams import DataArguments, ModelArguments
     from ..template import Template
 
 
@@ -125,6 +125,7 @@ def preprocess_packed_supervised_dataset(
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     data_args: "DataArguments",
+    model_args: "ModelArguments"
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
@@ -176,7 +177,7 @@ def preprocess_packed_supervised_dataset(
             raise ValueError("The length of packed example should be identical to the cutoff length.")
 
         model_inputs["input_ids"].append(packed_input_ids)
-        if data_args.efficient_packing:
+        if model_args.efficient_packing:
             model_inputs["attention_mask"].append(packed_attention_mask)
         else:
             model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
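
For orientation, the branch above is the core of the feature: plain packing appends an all-ones mask of length cutoff_len, so tokens from different packed sequences can attend to one another, while efficient packing appends a packed_attention_mask that keeps the sequences separated. Below is a minimal sketch of one way such a mask can be encoded, using per-sequence segment ids; the helper name and the exact encoding are assumptions for illustration, not code from this commit.

from typing import List

def build_segment_ids(seq_lens: List[int], cutoff_len: int) -> List[int]:
    # Give the i-th packed sequence the segment id i + 1 and pad with 0,
    # so attention can later be restricted to tokens sharing a segment id.
    segment_ids: List[int] = []
    for index, seq_len in enumerate(seq_lens):
        segment_ids.extend([index + 1] * seq_len)
    segment_ids.extend([0] * (cutoff_len - len(segment_ids)))
    return segment_ids

# Two sequences of lengths 3 and 2 packed into a window of 8 tokens:
print(build_segment_ids([3, 2], 8))  # [1, 1, 1, 2, 2, 0, 0, 0]
# Plain packing would instead use [1] * 8, letting the two sequences attend to each other.
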
@@ -97,12 +97,6 @@ class DataArguments:
             "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
         },
     )
-    efficient_packing: Optional[bool] = field(
-        default=None,
-        metadata={
-            "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
-        },
-    )
     tool_format: Optional[str] = field(
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
@@ -109,6 +109,12 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
     )
+    efficient_packing: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
+        },
+    )
     mixture_of_depths: Optional[Literal["convert", "load"]] = field(
         default=None,
         metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
@@ -170,6 +170,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")
 
+    if finetuning_args.stage != "sft" and model_args.efficient_packing:
+        raise ValueError("`efficient_packing` cannot be set as True except SFT.")
+
     if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
         raise ValueError("Unsloth does not support lora reward model.")
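
Taken together, the relocated field (previous hunk) and the new guard above behave roughly as in the following trimmed, self-contained sketch; it is not the repository's actual argument classes or CLI, only an illustration of the constraint.

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ModelArgumentsSketch:  # stand-in for the real ModelArguments
    efficient_packing: Optional[bool] = field(
        default=None,
        metadata={"help": "Pack sequences without cross-contamination attention."},
    )

def validate_stage(stage: str, model_args: ModelArgumentsSketch) -> None:
    # Mirrors the check added to get_train_args: packing without
    # cross-contamination attention is only supported during SFT.
    if stage != "sft" and model_args.efficient_packing:
        raise ValueError("`efficient_packing` cannot be set as True except SFT.")

validate_stage("sft", ModelArgumentsSketch(efficient_packing=True))   # passes
# validate_stage("ppo", ModelArgumentsSketch(efficient_packing=True))  # raises ValueError
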
@@ -31,7 +31,7 @@ from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 
-    from ..hparams import FinetuningArguments, ModelArguments, DataArguments
+    from ..hparams import FinetuningArguments, ModelArguments
 
 
 logger = get_logger(__name__)
@@ -120,7 +120,6 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
 def load_model(
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    data_args: "DataArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool = False,
     add_valuehead: bool = False,
@@ -130,7 +129,7 @@ def load_model(
     """
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
-    patch_config(config, tokenizer, model_args, data_args, finetuning_args, init_kwargs, is_trainable)
+    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
 
     model = None
     lazy_load = False
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead
 
-    from ..hparams import ModelArguments, DataArguments, FinetuningArguments
+    from ..hparams import ModelArguments
 
 
 logger = get_logger(__name__)
@@ -54,8 +54,6 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    data_args: "DataArguments",
-    finetune_args: "FinetuningArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
@@ -104,7 +102,7 @@ def patch_config(
     if init_kwargs.get("device_map", None) == "auto":
         init_kwargs["offload_folder"] = model_args.offload_folder
 
-    if finetune_args.stage == "sft" and data_args.efficient_packing:
+    if model_args.efficient_packing:
         configure_packing(config, model_args)
 
 
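
Note that the stage check can be dropped from `patch_config` because a non-SFT stage combined with `efficient_packing` is now rejected earlier, in `get_train_args`; by the time `patch_config` runs, `model_args.efficient_packing` alone is enough to decide whether `configure_packing` should be applied.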