diff --git a/src/llmtuner/model/__init__.py b/src/llmtuner/model/__init__.py index 1824f084..88f666c8 100644 --- a/src/llmtuner/model/__init__.py +++ b/src/llmtuner/model/__init__.py @@ -1,5 +1,6 @@ from .loader import load_config, load_model, load_tokenizer -from .utils.misc import find_all_linear_modules, load_valuehead_params +from .utils.misc import find_all_linear_modules +from .utils.valuehead import load_valuehead_params __all__ = [ diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index 0ff7a350..ead6178f 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -7,9 +7,10 @@ from ..extras.logging import get_logger from ..extras.misc import count_parameters, try_download_model_from_ms from .adapter import init_adapter from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model -from .utils.misc import load_valuehead_params, register_autoclass +from .utils.misc import register_autoclass from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model from .utils.unsloth import load_unsloth_pretrained_model +from .utils.valuehead import load_valuehead_params if TYPE_CHECKING: @@ -105,7 +106,7 @@ def load_model( """ init_kwargs = _get_init_kwargs(model_args) config = load_config(model_args) - patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) + patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead) model = None lazy_load = False @@ -130,7 +131,7 @@ def load_model( model = convert_pretrained_model_to_mod(model, config, model_args) if not lazy_load: - patch_model(model, tokenizer, model_args, is_trainable) + patch_model(model, tokenizer, model_args, is_trainable, add_valuehead) register_autoclass(config, model, tokenizer) model = init_adapter(config, model, model_args, finetuning_args, is_trainable) diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index 94d99644..31cba492 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -15,6 +15,7 @@ from .utils.longlora import configure_longlora from .utils.moe import add_z3_leaf_module, configure_moe from .utils.quantization import configure_quantization from .utils.rope import configure_rope +from .utils.valuehead import configure_valuehead, prepare_valuehead_model from .utils.visual import autocast_projector_dtype @@ -39,6 +40,7 @@ def patch_config( model_args: "ModelArguments", init_kwargs: Dict[str, Any], is_trainable: bool, + add_valuehead: bool, ) -> None: if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) @@ -49,6 +51,9 @@ def patch_config( configure_quantization(config, tokenizer, model_args, init_kwargs) configure_moe(config, model_args, is_trainable) + if add_valuehead: + configure_valuehead(config) + if model_args.use_cache and not is_trainable: setattr(config, "use_cache", True) logger.info("Using KV cache for faster generation.") @@ -73,7 +78,11 @@ def patch_config( def patch_model( - model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + is_trainable: bool, + add_valuehead: bool, ) -> None: gen_config = model.generation_config # check and fix generation config if not gen_config.do_sample and ( @@ -86,9 +95,8 @@ def patch_model( if "GenerationMixin" not in str(model.generate.__func__): model.generate = MethodType(PreTrainedModel.generate, model) - if is_trainable and getattr(model.config, "model_type", None) == "chatglm": - setattr(model, "lm_head", model.transformer.output_layer) - setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) + if add_valuehead: + prepare_valuehead_model(model) if model_args.resize_vocab: resize_embedding_layer(model, tokenizer) diff --git a/src/llmtuner/model/utils/misc.py b/src/llmtuner/model/utils/misc.py index 57e772f7..eca68866 100644 --- a/src/llmtuner/model/utils/misc.py +++ b/src/llmtuner/model/utils/misc.py @@ -1,18 +1,13 @@ -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING, List import torch -from transformers import PreTrainedModel -from transformers.utils import cached_file -from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ...extras.logging import get_logger from .quantization import QuantizationMethod if TYPE_CHECKING: - from transformers import PretrainedConfig, PreTrainedTokenizer - - from ...hparams import ModelArguments + from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer logger = get_logger(__name__) @@ -74,34 +69,6 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n return module_names -def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: - r""" - Loads value head parameters from Hugging Face Hub or local disk. - - Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`. - """ - kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token} - - try: - from safetensors import safe_open - - vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs) - with safe_open(vhead_file, framework="pt", device="cpu") as f: - return {key: f.get_tensor(key) for key in f.keys()} - except Exception as err: - logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err))) - - try: - vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs) - return torch.load(vhead_file, map_location="cpu") - except Exception as err: - logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err))) - - logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id)) - logger.info("Ignore these messages if you are not resuming the training of a value head model.") - return None - - def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"): if "AutoConfig" in getattr(config, "auto_map", {}): config.__class__.register_for_auto_class() diff --git a/src/llmtuner/model/utils/valuehead.py b/src/llmtuner/model/utils/valuehead.py new file mode 100644 index 00000000..a192dcfa --- /dev/null +++ b/src/llmtuner/model/utils/valuehead.py @@ -0,0 +1,59 @@ +from typing import TYPE_CHECKING, Dict + +import torch +from transformers.utils import cached_file + +from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def configure_valuehead(config: "PretrainedConfig") -> None: + if getattr(config, "model_type", None) == "llava": + setattr(config, "hidden_size", getattr(config.vision_config, "intermediate_size", None)) + + +def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: + r""" + Loads value head parameters from Hugging Face Hub or local disk. + + Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`. + """ + kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token} + + try: + from safetensors import safe_open + + vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs) + with safe_open(vhead_file, framework="pt", device="cpu") as f: + return {key: f.get_tensor(key) for key in f.keys()} + except Exception as err: + logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err))) + + try: + vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs) + return torch.load(vhead_file, map_location="cpu") + except Exception as err: + logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err))) + + logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id)) + logger.info("Ignore these messages if you are not resuming the training of a value head model.") + return None + + +def prepare_valuehead_model(model: "PreTrainedModel") -> None: + if getattr(model.config, "model_type", None) == "llava": + setattr(model, "lm_head", model.language_model.get_output_embeddings()) + setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) + + if getattr(model.config, "model_type", None) == "chatglm": + setattr(model, "lm_head", model.transformer.output_layer) + setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])