diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 6029d84f..dc326d01 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -78,22 +78,22 @@ TRAINING_STAGES = {
 STAGES_USE_PAIR_DATA = {"rm", "dpo"}
 
-SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
-
-SUPPORTED_CLASS_EFFECIENT_PACKING = [
+SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
+    "falcon",
+    "gemma",
+    "gemma2",
+    "jamba",
     "llama",
     "mistral",
     "mixtral",
-    "qwen2",
-    "qwen2_moe",
-    "falcon",
     "phi",
     "phi3",
-    "gemma",
-    "gemmoe",
+    "qwen2",
+    "qwen2_moe",
     "starcoder2",
-    "jamba"
-]
+}
+
+SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
 
 V_HEAD_WEIGHTS_NAME = "value_head.bin"
diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py
index 606cd03b..c60547d4 100644
--- a/src/llamafactory/model/model_utils/packing.py
+++ b/src/llamafactory/model/model_utils/packing.py
@@ -283,28 +283,15 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
     )
 
 
-def patch_remote(model_name, config_name, modeling_name):
-    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-    # we need to load the model here in order for modeling_* to be available
-    with init_empty_weights():
-        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
-    module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
-    modeling_arch = importlib.import_module(module_name)
-    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
-    # check exist load_balancing_loss_func for moe model
-    if hasattr(modeling_arch, "load_balancing_loss_func"):
-        modeling_arch.load_balancing_loss_func = load_balancing_loss_func
+def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.block_diag_attn:
+        return
 
+    model_type = getattr(config, "model_type", None)
 
-def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
-    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
-        attn_implementation = getattr(config, "attn_implementation", "")
-    else:
-        attn_implementation = getattr(config, "_attn_implementation", "")
-
-    if getattr(config, "model_type", None) in SUPPORTED_CLASS_EFFECIENT_PACKING:
-        patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation)
+    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
+        patch_for_block_diag_attn(model_type)
         logger.info("Using packing sequences without cross-contamination attention for efficient training.")
     else:
-        raise ValueError("Current model does not support packing sequences for efficient training. Please set config `efficient_packing` is False")
\ No newline at end of file
+        raise ValueError("Current model does not support packing sequences for efficient training.")
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 2ddfd21a..a99d38e0 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -29,11 +29,12 @@ from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
+from .model_utils.packing import configure_packing
 from .model_utils.quantization import configure_quantization
 from .model_utils.rope import configure_rope
 from .model_utils.valuehead import prepare_valuehead_model
 from .model_utils.visual import autocast_projector_dtype, configure_visual_model
-from .model_utils.packing import configure_packing
+
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
@@ -73,6 +74,7 @@ def patch_config(
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
+    configure_packing(config, model_args, is_trainable)
 
     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
@@ -101,9 +103,6 @@ def patch_config(
 
     if init_kwargs.get("device_map", None) == "auto":
         init_kwargs["offload_folder"] = model_args.offload_folder
-
-    if model_args.efficient_packing:
-        configure_packing(config, model_args)
 
 
 def patch_model(