update arg name

This commit is contained in:
hiyouga 2024-07-03 23:23:24 +08:00
parent 575a02a23d
commit 8a6a7b9c8a
3 changed files with 20 additions and 34 deletions

View File

@ -78,22 +78,22 @@ TRAINING_STAGES = {
STAGES_USE_PAIR_DATA = {"rm", "dpo"} STAGES_USE_PAIR_DATA = {"rm", "dpo"}
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
"falcon",
SUPPORTED_CLASS_EFFECIENT_PACKING = [ "gemma",
"gemma2",
"jamba",
"llama", "llama",
"mistral", "mistral",
"mixtral", "mixtral",
"qwen2",
"qwen2_moe",
"falcon",
"phi", "phi",
"phi3", "phi3",
"gemma", "qwen2",
"gemmoe", "qwen2_moe",
"starcoder2", "starcoder2",
"jamba" }
]
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
V_HEAD_WEIGHTS_NAME = "value_head.bin" V_HEAD_WEIGHTS_NAME = "value_head.bin"

View File

@ -283,28 +283,15 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
) )
def patch_remote(model_name, config_name, modeling_name):
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_* to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
modeling_arch = importlib.import_module(module_name)
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
# check exist load_balancing_loss_func for moe model def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if hasattr(modeling_arch, "load_balancing_loss_func"): if not is_trainable or not model_args.block_diag_attn:
modeling_arch.load_balancing_loss_func = load_balancing_loss_func return
model_type = getattr(config, "model_type", None)
def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments") -> None: if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
if getattr(config, "model_type", None) == "internlm2": # special case for custom models patch_for_block_diag_attn(model_type)
attn_implementation = getattr(config, "attn_implementation", "")
else:
attn_implementation = getattr(config, "_attn_implementation", "")
if getattr(config, "model_type", None) in SUPPORTED_CLASS_EFFECIENT_PACKING:
patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation)
logger.info("Using packing sequences without cross-contamination attention for efficient training.") logger.info("Using packing sequences without cross-contamination attention for efficient training.")
else: else:
raise ValueError("Current model does not support packing sequences for efficient training. Please set config `efficient_packing` is False") raise ValueError("Current model does not support packing sequences for efficient training.")

View File

@ -29,11 +29,12 @@ from .model_utils.checkpointing import prepare_model_for_training
from .model_utils.embedding import resize_embedding_layer from .model_utils.embedding import resize_embedding_layer
from .model_utils.longlora import configure_longlora from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe from .model_utils.moe import add_z3_leaf_module, configure_moe
from .model_utils.packing import configure_packing
from .model_utils.quantization import configure_quantization from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model from .model_utils.valuehead import prepare_valuehead_model
from .model_utils.visual import autocast_projector_dtype, configure_visual_model from .model_utils.visual import autocast_projector_dtype, configure_visual_model
from .model_utils.packing import configure_packing
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer from transformers import PretrainedConfig, PreTrainedTokenizer
@ -73,6 +74,7 @@ def patch_config(
configure_quantization(config, tokenizer, model_args, init_kwargs) configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable) configure_moe(config, model_args, is_trainable)
configure_visual_model(config) configure_visual_model(config)
configure_packing(config, model_args, is_trainable)
if model_args.use_cache and not is_trainable: if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True) setattr(config, "use_cache", True)
@ -101,9 +103,6 @@ def patch_config(
if init_kwargs.get("device_map", None) == "auto": if init_kwargs.get("device_map", None) == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder init_kwargs["offload_folder"] = model_args.offload_folder
if model_args.efficient_packing:
configure_packing(config, model_args)
def patch_model( def patch_model(