diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py
index 8ed313be..5b21bf4e 100644
--- a/src/llamafactory/model/model_utils/packing.py
+++ b/src/llamafactory/model/model_utils/packing.py
@@ -257,7 +257,7 @@ def load_balancing_loss_func(
     return overall_loss * num_experts
 
 
-def patch_for_multipack(model_type, model_name, attn_implementation):
+def patch_for_block_diag_attn(model_type, model_name, attn_implementation):
     if attn_implementation == "flash_attention_2":
         if model_type == "llama":
             transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
@@ -305,10 +305,6 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
             transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
                 get_unpad_data
             )
-        elif model_type == "gemmoe":
-            patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
-        elif model_type == "jamba":
-            patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
     else:
         transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
             patched_prepare_4d_causal_attention_mask_for_sdpa
@@ -318,7 +314,6 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
         )
 
 
-
 def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if not is_trainable or not model_args.block_diag_attn:
        return
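
For context, the patch above swaps transformers' `_get_unpad_data` for a block-diagonal-aware version, so that FlashAttention-2's varlen kernels treat each packed sub-sequence as its own attention block. Below is a minimal sketch of that idea, assuming the packed attention mask labels each sub-sequence's tokens with an incrementing positive index (1, 2, 3, ...) and uses 0 for padding; the function name and details are illustrative and not taken from the patch itself.

from typing import Tuple

import torch
import torch.nn.functional as F


def get_unpad_data_sketch(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Hypothetical replacement for transformers' _get_unpad_data under packed batches."""
    # Tokens per packed sub-sequence: for each row, count how many tokens carry
    # each sub-sequence index (mask values 1..max), then drop empty slots.
    max_index = int(attention_mask.max().item())
    counts = torch.stack(
        [(attention_mask == i).sum(dim=-1) for i in range(1, max_index + 1)], dim=-1
    )  # shape: (batch_size, max_index)
    seqlens_in_batch = counts.flatten().to(torch.int32)
    seqlens_in_batch = seqlens_in_batch[seqlens_in_batch > 0]

    # Flattened positions of all real (non-padding) tokens, in batch-major order.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    # Cumulative sequence lengths with a leading zero, as flash-attn's varlen API expects.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

With lengths reported this way, each packed sub-sequence gets its own (start, end) slice in cu_seqlens, which is what makes attention effectively block-diagonal within a packed sample.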