Update packing.py
This commit is contained in:
parent
c346f79f99
commit
a36e8f2dd5
|
@ -257,7 +257,7 @@ def load_balancing_loss_func(
|
|||
return overall_loss * num_experts
|
||||
|
||||
|
||||
def patch_for_multipack(model_type, model_name, attn_implementation):
|
||||
def patch_for_block_diag_attn(model_type, model_name, attn_implementation):
|
||||
if attn_implementation == "flash_attention_2":
|
||||
if model_type == "llama":
|
||||
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
||||
|
@ -305,10 +305,6 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
|
|||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "gemmoe":
|
||||
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
|
||||
elif model_type == "jamba":
|
||||
patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
|
||||
else:
|
||||
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||
|
@ -318,7 +314,6 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
|
|||
)
|
||||
|
||||
|
||||
|
||||
def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.block_diag_attn:
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue