update packing with sdpa and eager attention mode

This commit is contained in:
ancv 2024-06-16 02:25:47 +07:00
parent 04315c3d92
commit 238f5c3d99
3 changed files with 148 additions and 73 deletions

View File

@ -66,6 +66,21 @@ STAGES_USE_PAIR_DATA = {"rm", "dpo"}
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
SUPPORTED_CLASS_FOR_MULTIPACK = [
"llama",
"mistral",
"mixtral",
"qwen2",
"qwen2_moe",
"falcon",
"phi",
"phi3",
"gemma",
"gemmoe",
"starcoder2",
"jamba"
]
V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
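
A quick illustration (not part of the commit; the checkpoint name is a placeholder and the top-level package name in the import is assumed) of how this constant is meant to be consulted — the real check happens in `configure_packing` further down:

from transformers import AutoConfig

from llamafactory.extras.constants import SUPPORTED_CLASS_FOR_MULTIPACK  # package name assumed

config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B")  # placeholder checkpoint
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_MULTIPACK:
    print(f"{config.model_type} supports packing without cross-contamination")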

View File

@ -12,7 +12,14 @@ import importlib
import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.utils import is_torch_bf16_gpu_available
from ...extras.logging import get_logger
from ...extras.constants import SUPPORTED_CLASS_FOR_MULTIPACK
if TYPE_CHECKING:
from transformers import PretrainedConfig
@ -20,19 +27,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments, DataArguments
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"llama",
"mistral",
"mixtral",
"qwen2",
"qwen2_moe",
"falcon",
"phi",
"phi3",
"gemma",
"gemmoe",
"starcoder2",
]
logger = get_logger(__name__)
@torch.jit.script
@ -67,6 +62,64 @@ def get_unpad_data(attention_mask: torch.Tensor):
max_seqlen_in_batch,
)
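
A quick illustration (not part of the commit; the exact return values are an assumption about the packed-aware helper) of the quantities flash-attention needs when two sequences of lengths 2 and 3 are packed into a single row tagged `[1, 1, 2, 2, 2]`:

import torch

mask = torch.tensor([[1, 1, 2, 2, 2]])                           # packed row: sequence ids, 0 = padding
seqlens = torch.bincount(mask.flatten())[1:]                     # per-sequence lengths -> tensor([2, 3])
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0))  # sequence boundaries -> tensor([0, 2, 5])
max_seqlen_in_batch = int(seqlens.max())                         # longest packed sub-sequence -> 3

With per-sequence boundaries like these, flash-attention treats each packed sub-sequence as its own sample rather than one long sequence.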
def mask_2d_to_4d(
    mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    Packed sequences are handled by giving each sub-sequence its own integer id in the mask,
    so tokens attend only to positions that carry the same id, i.e. to their own sequence.
    The expanded mask is also made lower triangular so tokens cannot attend to future positions.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    mask = mask.unsqueeze(1).unsqueeze(2)
    mask = mask.expand(bsz, 1, tgt_len, src_len)

    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
    binary_mask = torch.where(
        mask != 0,
        torch.tensor(1, device=mask.device).to(dtype),
        torch.tensor(0, device=mask.device).to(dtype),
    )

    # Create a block-diagonal mask: positions match only when they share the same sequence id.
    # We multiply by the binary mask so that 0's (padding) in the original mask are correctly excluded.
    zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask

    # Now let's create a lower triangular mask of ones that will zero out the upper triangular part
    lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
        mask.device
    )

    # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
    masked_zero_one_mask = zero_one_mask * lower_triangular_ones

    return masked_zero_one_mask
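
A worked example (not part of the commit) of what `mask_2d_to_4d` produces for a packed row `[1, 1, 2, 2, 2]`; the core steps are repeated inline so the snippet runs on its own:

import torch

mask = torch.tensor([[1, 1, 2, 2, 2]])                                      # [bsz=1, seq_len=5]
expanded = mask.unsqueeze(1).unsqueeze(2).expand(1, 1, 5, 5)                 # [bsz, 1, tgt_len, src_len]
binary = (expanded != 0).float()                                             # padding (0) stays masked out
same_seq = torch.eq(expanded, expanded.transpose(-1, -2)).float() * binary   # block-diagonal blocks per sequence id
causal = torch.tril(torch.ones(5, 5))                                        # no attending to future positions
print(same_seq * causal)
# tensor([[[[1., 0., 0., 0., 0.],
#           [1., 1., 0., 0., 0.],
#           [0., 0., 1., 0., 0.],
#           [0., 0., 1., 1., 0.],
#           [0., 0., 1., 1., 1.]]]])

Each packed sequence gets its own causal block, so tokens from the first sequence never attend to the second and vice versa.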
def patched_prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    *args,
):
    # Expand the packed 2D mask to a block-causal 4D mask before delegating to the stock helper.
    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
    return _prepare_4d_causal_attention_mask(
        mask_2d_to_4d(attention_mask, dtype=dtype),
        *args,
    )


def patched_prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask: Optional[torch.Tensor],
    *args,
):
    # Same expansion, but for the SDPA variant of the mask-preparation helper.
    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
    return _prepare_4d_causal_attention_mask_for_sdpa(
        mask_2d_to_4d(attention_mask, dtype=dtype),
        *args,
    )
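
A sketch (not part of the commit; shapes are arbitrary placeholders) of why a block-causal 4D layout is sufficient for the SDPA path: `torch.nn.functional.scaled_dot_product_attention` accepts a boolean mask directly, so packed sequences stay isolated. In the patch itself the 4D mask is instead routed through the stock `_prepare_4d_causal_attention_mask*` helpers shown above.

import torch
import torch.nn.functional as F

block_causal = torch.tensor([[[[1, 0, 0, 0, 0],
                               [1, 1, 0, 0, 0],
                               [0, 0, 1, 0, 0],
                               [0, 0, 1, 1, 0],
                               [0, 0, 1, 1, 1]]]], dtype=torch.bool)   # [1, 1, 5, 5]

q = k = v = torch.randn(1, 4, 5, 8)                                    # [bsz, heads, seq_len, head_dim]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=block_causal)  # mask broadcasts over heads
print(out.shape)  # torch.Size([1, 4, 5, 8])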
def set_module_name(model, name, value):
if "." in name:
@ -169,57 +222,65 @@ def load_balancing_loss_func(
return overall_loss * num_experts
def patch_for_multipack(model_type, model_name=None):
    if model_type == "llama":
        transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "mistral":
        transformers.models.mistral.modeling_mistral._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "mixtral":
        transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
        transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = (  # pylint: disable=protected-access
            load_balancing_loss_func
        )
    elif model_type == "qwen2":
        transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "qwen2_moe":
        transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
        transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = (  # pylint: disable=protected-access
            load_balancing_loss_func
        )
    elif model_type == "falcon":
        transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "phi":
        transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "phi3":
        transformers.models.phi3.modeling_phi3._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "gemma":
        transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "starcoder2":
        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "gemmoe":
        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
    elif model_type == "jamba":
        patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")

def patch_for_multipack(model_type, model_name, attn_implementation):
    if attn_implementation == "flash_attention_2":
        if model_type == "llama":
            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "mistral":
            transformers.models.mistral.modeling_mistral._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "mixtral":
            transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
            transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = (  # pylint: disable=protected-access
                load_balancing_loss_func
            )
        elif model_type == "qwen2":
            transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "qwen2_moe":
            transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
            transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = (  # pylint: disable=protected-access
                load_balancing_loss_func
            )
        elif model_type == "falcon":
            transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "phi":
            transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "phi3":
            transformers.models.phi3.modeling_phi3._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "gemma":
            transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "starcoder2":
            transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
        elif model_type == "gemmoe":
            patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
        elif model_type == "jamba":
            patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
    else:
        transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
            patched_prepare_4d_causal_attention_mask_for_sdpa
        )
        transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
            patched_prepare_4d_causal_attention_mask
        )
def patch_remote(model_name, config_name, modeling_name):
@ -231,20 +292,19 @@ def patch_remote(model_name, config_name, modeling_name):
    modeling_arch = importlib.import_module(module_name)
    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
    # check whether load_balancing_loss_func exists (only MoE models define it)
    if hasattr(modeling_arch, "load_balancing_loss_func"):
        modeling_arch.load_balancing_loss_func = load_balancing_loss_func
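
An illustration (not part of the commit; the dotted module path below is hypothetical) of the suffix substitution that `patch_remote` relies on for remote-code models such as gemmoe and jamba:

# The config class of a trust_remote_code model lives in a dynamically downloaded module;
# swapping the configuration suffix for the modeling suffix points at the matching modeling file.
config_module = "transformers_modules.some_org.Gemmoe.configuration_gemmoe"   # hypothetical path
module_name = config_module.replace(".configuration_gemmoe", ".modeling_gemmoe")
print(module_name)  # transformers_modules.some_org.Gemmoe.modeling_gemmoe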
def configure_packing(config: "PretrainedConfig") -> None:
def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
        attn_implementation = getattr(config, "attn_implementation", None)
        attn_implementation = getattr(config, "attn_implementation", "")
    else:
        attn_implementation = getattr(config, "_attn_implementation", None)
        attn_implementation = getattr(config, "_attn_implementation", "")

    if attn_implementation != "flash_attention_2":
        raise ValueError("Efficient packing only supports for flash_attention_2. Please set config `flash_attn` is fa2")

    logger = get_logger(__name__)

    if getattr(config, "model_type", None) in SUPPORTED_MULTIPACK_MODEL_TYPES:
        patch_for_multipack(getattr(config, "model_type", None))
    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_MULTIPACK:
        patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation)
        logger.info("Using packing sequences without cross-contamination attention for efficient training.")
    else:
        raise ValueError("Current model does not support packing sequences for efficient training. Please set config `efficient_packing` is False")

View File

@ -34,7 +34,7 @@ def run_sft(
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    if data_args.efficient_packing:
        configure_packing(model.config)
        configure_packing(model.config, model_args)

    if training_args.predict_with_generate:
        tokenizer.padding_side = "left"  # use left-padding in generation