update arg name
This commit is contained in:
parent
575a02a23d
commit
8a6a7b9c8a
|
@ -78,22 +78,22 @@ TRAINING_STAGES = {
|
|||
|
||||
STAGES_USE_PAIR_DATA = {"rm", "dpo"}
|
||||
|
||||
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
|
||||
|
||||
SUPPORTED_CLASS_EFFECIENT_PACKING = [
|
||||
SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
|
||||
"falcon",
|
||||
"gemma",
|
||||
"gemma2",
|
||||
"jamba",
|
||||
"llama",
|
||||
"mistral",
|
||||
"mixtral",
|
||||
"qwen2",
|
||||
"qwen2_moe",
|
||||
"falcon",
|
||||
"phi",
|
||||
"phi3",
|
||||
"gemma",
|
||||
"gemmoe",
|
||||
"qwen2",
|
||||
"qwen2_moe",
|
||||
"starcoder2",
|
||||
"jamba"
|
||||
]
|
||||
}
|
||||
|
||||
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
|
||||
|
||||
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
||||
|
||||
|
|
|
@ -283,28 +283,15 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
|
|||
)
|
||||
|
||||
|
||||
def patch_remote(model_name, config_name, modeling_name):
|
||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
# we need to load the model here in order for modeling_* to be available
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
|
||||
modeling_arch = importlib.import_module(module_name)
|
||||
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
|
||||
|
||||
# check whether load_balancing_loss_func exists (MoE models only)
|
||||
if hasattr(modeling_arch, "load_balancing_loss_func"):
|
||||
modeling_arch.load_balancing_loss_func = load_balancing_loss_func
|
||||
def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.block_diag_attn:
|
||||
return
|
||||
|
||||
model_type = getattr(config, "model_type", None)
|
||||
|
||||
def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
|
||||
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
|
||||
attn_implementation = getattr(config, "attn_implementation", "")
|
||||
else:
|
||||
attn_implementation = getattr(config, "_attn_implementation", "")
|
||||
|
||||
if getattr(config, "model_type", None) in SUPPORTED_CLASS_EFFECIENT_PACKING:
|
||||
patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation)
|
||||
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
|
||||
patch_for_block_diag_attn(model_type)
|
||||
logger.info("Using packing sequences without cross-contamination attention for efficient training.")
|
||||
else:
|
||||
raise ValueError("Current model does not support packing sequences for efficient training. Please set config `efficient_packing` is False")
|
||||
raise ValueError("Current model does not support packing sequences for efficient training.")
|
||||
|
|
|
@ -29,11 +29,12 @@ from .model_utils.checkpointing import prepare_model_for_training
|
|||
from .model_utils.embedding import resize_embedding_layer
|
||||
from .model_utils.longlora import configure_longlora
|
||||
from .model_utils.moe import add_z3_leaf_module, configure_moe
|
||||
from .model_utils.packing import configure_packing
|
||||
from .model_utils.quantization import configure_quantization
|
||||
from .model_utils.rope import configure_rope
|
||||
from .model_utils.valuehead import prepare_valuehead_model
|
||||
from .model_utils.visual import autocast_projector_dtype, configure_visual_model
|
||||
from .model_utils.packing import configure_packing
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||
|
@ -73,6 +74,7 @@ def patch_config(
|
|||
configure_quantization(config, tokenizer, model_args, init_kwargs)
|
||||
configure_moe(config, model_args, is_trainable)
|
||||
configure_visual_model(config)
|
||||
configure_packing(config, model_args, is_trainable)
|
||||
|
||||
if model_args.use_cache and not is_trainable:
|
||||
setattr(config, "use_cache", True)
|
||||
|
@ -101,9 +103,6 @@ def patch_config(
|
|||
|
||||
if init_kwargs.get("device_map", None) == "auto":
|
||||
init_kwargs["offload_folder"] = model_args.offload_folder
|
||||
|
||||
if model_args.efficient_packing:
|
||||
configure_packing(config, model_args)
|
||||
|
||||
|
||||
def patch_model(
|
||||
|
|
Loading…
Reference in New Issue