update arg name
parent 575a02a23d
commit 8a6a7b9c8a
@@ -78,22 +78,22 @@ TRAINING_STAGES = {

 STAGES_USE_PAIR_DATA = {"rm", "dpo"}

-SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
-
-SUPPORTED_CLASS_EFFECIENT_PACKING = [
-    "llama",
-    "mistral",
-    "mixtral",
-    "qwen2",
-    "qwen2_moe",
-    "falcon",
-    "phi",
-    "phi3",
-    "gemma",
-    "gemmoe",
-    "starcoder2",
-    "jamba"
-]
+SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
+    "falcon",
+    "gemma",
+    "gemma2",
+    "jamba",
+    "llama",
+    "mistral",
+    "mixtral",
+    "phi",
+    "phi3",
+    "qwen2",
+    "qwen2_moe",
+    "starcoder2",
+}
+
+SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

 V_HEAD_WEIGHTS_NAME = "value_head.bin"
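The renamed constant is consumed as a plain membership test on the config's model_type. A minimal, self-contained sketch of that check follows; the helper function name is illustrative and not part of this change.

# Sketch only: mirrors how configure_packing gates on the supported-architecture set.
SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
    "falcon", "gemma", "gemma2", "jamba", "llama", "mistral",
    "mixtral", "phi", "phi3", "qwen2", "qwen2_moe", "starcoder2",
}

def supports_block_diag_attn(config) -> bool:
    # config stands in for a transformers PretrainedConfig; any object with a
    # model_type attribute works for this check.
    return getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN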
@@ -283,28 +283,15 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
     )


-def patch_remote(model_name, config_name, modeling_name):
-    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-    # we need to load the model here in order for modeling_* to be available
-    with init_empty_weights():
-        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
-    module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
-    modeling_arch = importlib.import_module(module_name)
-    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
-    # check exist load_balancing_loss_func for moe model
-    if hasattr(modeling_arch, "load_balancing_loss_func"):
-        modeling_arch.load_balancing_loss_func = load_balancing_loss_func
-
-
-def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
-    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
-        attn_implementation = getattr(config, "attn_implementation", "")
-    else:
-        attn_implementation = getattr(config, "_attn_implementation", "")
-
-    if getattr(config, "model_type", None) in SUPPORTED_CLASS_EFFECIENT_PACKING:
-        patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation)
+def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.block_diag_attn:
+        return
+
+    model_type = getattr(config, "model_type", None)
+    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
+        patch_for_block_diag_attn(model_type)
         logger.info("Using packing sequences without cross-contamination attention for efficient training.")
     else:
-        raise ValueError("Current model does not support packing sequences for efficient training. Please set config `efficient_packing` is False")
+        raise ValueError("Current model does not support packing sequences for efficient training.")
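For reference, a minimal usage sketch of the new entry point. The stand-in objects only carry the attributes the function reads; the real call site passes a transformers PretrainedConfig and the project's ModelArguments, and it assumes configure_packing has been imported from the packing module as in the patcher hunk below.

# Sketch only: illustrates the gating behaviour of the new signature.
from types import SimpleNamespace

config = SimpleNamespace(model_type="llama")        # stands in for a PretrainedConfig
model_args = SimpleNamespace(block_diag_attn=True)  # stands in for ModelArguments

configure_packing(config, model_args, is_trainable=True)   # patches attention for packed sequences
configure_packing(config, model_args, is_trainable=False)  # no-op: returns before any patching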
@@ -29,11 +29,12 @@ from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
+from .model_utils.packing import configure_packing
 from .model_utils.quantization import configure_quantization
 from .model_utils.rope import configure_rope
 from .model_utils.valuehead import prepare_valuehead_model
 from .model_utils.visual import autocast_projector_dtype, configure_visual_model
-from .model_utils.packing import configure_packing


 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
@@ -73,6 +74,7 @@ def patch_config(
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
+    configure_packing(config, model_args, is_trainable)

     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
@@ -102,9 +104,6 @@ def patch_config(
     if init_kwargs.get("device_map", None) == "auto":
         init_kwargs["offload_folder"] = model_args.offload_folder

-    if model_args.efficient_packing:
-        configure_packing(config, model_args)
-

 def patch_model(
     model: "PreTrainedModel",
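With the caller-side flag renamed from efficient_packing to block_diag_attn, the corresponding ModelArguments field would be declared roughly as below. This is an illustrative sketch only; the actual dataclass is not part of this diff and holds many more fields.

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ModelArguments:
    # Illustrative subset: only the fields relevant to configure_packing.
    model_name_or_path: Optional[str] = None
    block_diag_attn: bool = field(
        default=False,
        metadata={"help": "Whether to use block-diagonal attention when packing sequences for training."},
    )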