update packing with sdpa and eager attention mode
parent 04315c3d92
commit 238f5c3d99
@@ -66,6 +66,21 @@ STAGES_USE_PAIR_DATA = {"rm", "dpo"}
 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
 
+SUPPORTED_CLASS_FOR_MULTIPACK = [
+    "llama",
+    "mistral",
+    "mixtral",
+    "qwen2",
+    "qwen2_moe",
+    "falcon",
+    "phi",
+    "phi3",
+    "gemma",
+    "gemmoe",
+    "starcoder2",
+    "jamba"
+]
+
 V_HEAD_WEIGHTS_NAME = "value_head.bin"
 
 V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
 
@@ -12,7 +12,14 @@ import importlib
 import transformers
 from accelerate import init_empty_weights
 from transformers import AutoConfig, AutoModelForCausalLM
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from transformers.utils import is_torch_bf16_gpu_available
+
 from ...extras.logging import get_logger
+from ...extras.constants import SUPPORTED_CLASS_FOR_MULTIPACK
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -20,19 +27,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments, DataArguments
 
 
-SUPPORTED_MULTIPACK_MODEL_TYPES = [
-    "llama",
-    "mistral",
-    "mixtral",
-    "qwen2",
-    "qwen2_moe",
-    "falcon",
-    "phi",
-    "phi3",
-    "gemma",
-    "gemmoe",
-    "starcoder2",
-]
 logger = get_logger(__name__)
 
 
 @torch.jit.script
@@ -67,6 +62,64 @@ def get_unpad_data(attention_mask: torch.Tensor):
         max_seqlen_in_batch,
     )
 
 
+def mask_2d_to_4d(
+    mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+):
+    """
+    Expands the attention mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    Positions that belong to the same packed sequence share the same integer value in the 2D mask,
+    so they may attend to each other, while positions from different sequences may not.
+    The result is also made lower triangular to prevent attending to future positions.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    mask = mask.unsqueeze(1).unsqueeze(2)
+    mask = mask.expand(bsz, 1, tgt_len, src_len)
+
+    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
+    binary_mask = torch.where(
+        mask != 0,
+        torch.tensor(1, device=mask.device).to(dtype),
+        torch.tensor(0, device=mask.device).to(dtype),
+    )
+
+    # Create a block-diagonal mask.
+    # We multiply by the binary mask so that 0's in the original mask are correctly excluded
+    zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
+
+    # Now create a lower triangular mask of ones that will zero out the upper triangular part
+    lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
+        mask.device
+    )
+
+    # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
+    masked_zero_one_mask = zero_one_mask * lower_triangular_ones
+
+    return masked_zero_one_mask
+
+
+def patched_prepare_4d_causal_attention_mask(
+    attention_mask: Optional[torch.Tensor],
+    *args,
+):
+    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
+    return _prepare_4d_causal_attention_mask(
+        mask_2d_to_4d(attention_mask, dtype=dtype),
+        *args,
+    )
+
+
+def patched_prepare_4d_causal_attention_mask_for_sdpa(
+    attention_mask: Optional[torch.Tensor],
+    *args,
+):
+    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
+    return _prepare_4d_causal_attention_mask_for_sdpa(
+        mask_2d_to_4d(attention_mask, dtype=dtype),
+        *args,
+    )
+
+
 def set_module_name(model, name, value):
     if "." in name:
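To make the mask expansion concrete, here is a small sketch (not part of the commit) that calls mask_2d_to_4d from the hunk above on a toy packed batch; the convention that packed positions carry a per-sequence index (1, 2, ...) with 0 for padding is assumed from the surrounding packing code.

import torch

# Two sequences packed into one row (ids 1 and 2), plus one padding position (0).
packed_mask = torch.tensor([[1, 1, 2, 2, 2, 0]])           # [bsz=1, seq_len=6]
mask_4d = mask_2d_to_4d(packed_mask, dtype=torch.float32)  # [1, 1, 6, 6]

# Each sequence gets its own causal (lower-triangular) block; tokens never attend
# across sequence boundaries, and the padding position attends to nothing:
# [[1, 0, 0, 0, 0, 0],
#  [1, 1, 0, 0, 0, 0],
#  [0, 0, 1, 0, 0, 0],
#  [0, 0, 1, 1, 0, 0],
#  [0, 0, 1, 1, 1, 0],
#  [0, 0, 0, 0, 0, 0]]
print(mask_4d[0, 0])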
@@ -169,57 +222,65 @@ def load_balancing_loss_func(
     return overall_loss * num_experts
 
 
-def patch_for_multipack(model_type, model_name=None):
-    if model_type == "llama":
-        transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "mistral":
-        transformers.models.mistral.modeling_mistral._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "mixtral":
-        transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-        transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = (  # pylint: disable=protected-access
-            load_balancing_loss_func
-        )
-    elif model_type == "qwen2":
-        transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "qwen2_moe":
-        transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-        transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = (  # pylint: disable=protected-access
-            load_balancing_loss_func
-        )
-    elif model_type == "falcon":
-        transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "phi":
-        transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "phi3":
-        transformers.models.phi3.modeling_phi3._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "gemma":
-        transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "starcoder2":
-        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "gemmoe":
-        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
-    elif model_type == "jamba":
-        patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
+def patch_for_multipack(model_type, model_name, attn_implementation):
+    if attn_implementation == "flash_attention_2":
+        if model_type == "llama":
+            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "mistral":
+            transformers.models.mistral.modeling_mistral._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "mixtral":
+            transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+            transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = (  # pylint: disable=protected-access
+                load_balancing_loss_func
+            )
+        elif model_type == "qwen2":
+            transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "qwen2_moe":
+            transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+            transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = (  # pylint: disable=protected-access
+                load_balancing_loss_func
+            )
+        elif model_type == "falcon":
+            transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "phi":
+            transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "phi3":
+            transformers.models.phi3.modeling_phi3._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "gemma":
+            transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "starcoder2":
+            transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "gemmoe":
+            patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
+        elif model_type == "jamba":
+            patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
+    else:
+        transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
+            patched_prepare_4d_causal_attention_mask_for_sdpa
+        )
+        transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
+            patched_prepare_4d_causal_attention_mask
+        )
 
 
 def patch_remote(model_name, config_name, modeling_name):
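A rough usage sketch (not part of the commit), assuming patch_for_multipack, get_unpad_data, and the patched mask helpers above are importable, and a transformers version that still exposes the per-model _get_unpad_data hooks; it shows which objects each branch replaces.

import transformers

# flash_attention_2 path: the per-model unpad helper is swapped, so flash-attn
# unpads each packed sub-sequence separately.
patch_for_multipack("llama", None, "flash_attention_2")
assert transformers.models.llama.modeling_llama._get_unpad_data is get_unpad_data

# sdpa / eager path: the shared mask builders are swapped instead, so every
# architecture goes through mask_2d_to_4d when its 4D attention mask is built.
patch_for_multipack("llama", None, "sdpa")
assert (
    transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa
    is patched_prepare_4d_causal_attention_mask_for_sdpa
)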
@@ -231,20 +292,19 @@ def patch_remote(model_name, config_name, modeling_name):
     modeling_arch = importlib.import_module(module_name)
     modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
 
     # check whether load_balancing_loss_func exists (MoE models)
     if hasattr(modeling_arch, "load_balancing_loss_func"):
         modeling_arch.load_balancing_loss_func = load_balancing_loss_func
 
 
-def configure_packing(config: "PretrainedConfig") -> None:
+def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
     if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
-        attn_implementation = getattr(config, "attn_implementation", None)
+        attn_implementation = getattr(config, "attn_implementation", "")
     else:
-        attn_implementation = getattr(config, "_attn_implementation", None)
+        attn_implementation = getattr(config, "_attn_implementation", "")
 
     if attn_implementation != "flash_attention_2":
         raise ValueError("Efficient packing only supports for flash_attention_2. Please set config `flash_attn` is fa2")
 
     logger = get_logger(__name__)
 
-    if getattr(config, "model_type", None) in SUPPORTED_MULTIPACK_MODEL_TYPES:
-        patch_for_multipack(getattr(config, "model_type", None))
+    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_MULTIPACK:
+        patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation)
         logger.info("Using packing sequences without cross-contamination attention for efficient training.")
     else:
         raise ValueError("Current model does not support packing sequences for efficient training. Please set config `efficient_packing` is False")
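For orientation, a minimal sketch (not part of the commit) of the attribute lookups configure_packing now relies on; LlamaConfig is used only as a convenient stand-in, and model_args stands for the project's ModelArguments with model_name_or_path set.

from transformers import LlamaConfig

# Toy config, no weights loaded; attn_implementation lands in config._attn_implementation.
config = LlamaConfig(attn_implementation="sdpa")

model_type = getattr(config, "model_type", None)                    # -> "llama"
attn_implementation = getattr(config, "_attn_implementation", "")   # -> "sdpa"

# For a model type listed in SUPPORTED_CLASS_FOR_MULTIPACK, the new call becomes:
# patch_for_multipack(model_type, model_args.model_name_or_path, attn_implementation)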
@@ -34,7 +34,7 @@ def run_sft(
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
     if data_args.efficient_packing:
-        configure_packing(model.config)
+        configure_packing(model.config, model_args)
 
     if training_args.predict_with_generate:
         tokenizer.padding_side = "left"  # use left-padding in generation