Update packing.py

This commit is contained in:
hoshi-hiyouga 2024-07-03 23:36:01 +08:00 committed by GitHub
parent c346f79f99
commit a36e8f2dd5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 6 deletions

View File

@ -257,7 +257,7 @@ def load_balancing_loss_func(
return overall_loss * num_experts return overall_loss * num_experts
def patch_for_multipack(model_type, model_name, attn_implementation): def patch_for_block_diag_attn(model_type, model_name, attn_implementation):
if attn_implementation == "flash_attention_2": if attn_implementation == "flash_attention_2":
if model_type == "llama": if model_type == "llama":
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
@ -305,10 +305,6 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data get_unpad_data
) )
elif model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "jamba":
patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
else: else:
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa patched_prepare_4d_causal_attention_mask_for_sdpa
@ -318,7 +314,6 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
) )
def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.block_diag_attn: if not is_trainable or not model_args.block_diag_attn:
return return