Update packing.py
parent c346f79f99
commit a36e8f2dd5
@@ -257,7 +257,7 @@ def load_balancing_loss_func(
     return overall_loss * num_experts


-def patch_for_multipack(model_type, model_name, attn_implementation):
+def patch_for_block_diag_attn(model_type, model_name, attn_implementation):
     if attn_implementation == "flash_attention_2":
         if model_type == "llama":
             transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
@@ -305,10 +305,6 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
             transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
                 get_unpad_data
             )
-        elif model_type == "gemmoe":
-            patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
-        elif model_type == "jamba":
-            patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
     else:
         transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
             patched_prepare_4d_causal_attention_mask_for_sdpa
@@ -318,7 +314,6 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
         )


-
 def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return
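Both `flash_attention_2` branches above swap transformers' private `_get_unpad_data` helper for a packing-aware `get_unpad_data`, whose body lies outside this diff. A minimal sketch of the idea, assuming the data collator labels the tokens of the k-th packed sequence with the value k in the attention mask (so a row looks like `[1, 1, 2, 2, 2, 0]` rather than all ones):

# Hedged sketch only -- not the committed helper. Assumes segment ids
# 1, 2, 3, ... in `attention_mask`, with 0 marking padding.
import torch
import torch.nn.functional as F


def get_unpad_data(attention_mask: torch.Tensor):
    # Per-row token counts for each segment id, flattened in row-major
    # order so they line up with the flattened non-padding indices below.
    max_num = int(attention_mask.max().item())
    counts = torch.stack(
        [(attention_mask == i + 1).sum(dim=-1) for i in range(max_num)], dim=-1
    ).flatten()
    seqlens_in_batch = counts[counts.nonzero().squeeze(dim=-1)].to(torch.int32)

    # Same return contract as the original transformers helper:
    # (indices of non-padding tokens, cumulative seqlens, max seqlen).
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

Returning one `cu_seqlens` entry per packed sequence keeps FlashAttention block-diagonal: tokens attend within their own segment but never across a packing boundary.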
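The `else` branch points SDPA at a patched 4-D mask builder; its body is likewise outside the diff. A rough sketch of a block-diagonal causal mask under the same segment-id convention (`block_diag_causal_mask` is a hypothetical name, not the committed function):

# Hedged sketch of a block-diagonal causal mask for SDPA; the committed
# `patched_prepare_4d_causal_attention_mask_for_sdpa` is not shown here.
import torch


def block_diag_causal_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # attention_mask: (bsz, seq_len) of segment ids, 0 = padding.
    bsz, seq_len = attention_mask.size()
    # A query may attend to a key only if both share a segment id ...
    same_segment = attention_mask[:, None, :, None] == attention_mask[:, None, None, :]
    # ... the key is not padding, and the position is causal.
    not_pad = attention_mask[:, None, None, :] != 0
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask.device)
    )
    allowed = same_segment & not_pad & causal
    mask = torch.full(
        (bsz, 1, seq_len, seq_len), torch.finfo(dtype).min,
        dtype=dtype, device=attention_mask.device,
    )
    return mask.masked_fill(allowed, 0)

Blocked positions get the dtype's minimum so softmax zeroes them out, matching the additive-mask convention transformers uses for SDPA.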