fix llava rlhf

This commit is contained in:
hiyouga 2024-04-28 03:01:49 +08:00
parent 4dbbce21d5
commit b3e33c703e
5 changed files with 79 additions and 43 deletions

View File

@ -1,5 +1,6 @@
from .loader import load_config, load_model, load_tokenizer from .loader import load_config, load_model, load_tokenizer
from .utils.misc import find_all_linear_modules, load_valuehead_params from .utils.misc import find_all_linear_modules
from .utils.valuehead import load_valuehead_params
__all__ = [ __all__ = [

View File

@ -7,9 +7,10 @@ from ..extras.logging import get_logger
from ..extras.misc import count_parameters, try_download_model_from_ms from ..extras.misc import count_parameters, try_download_model_from_ms
from .adapter import init_adapter from .adapter import init_adapter
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .utils.misc import load_valuehead_params, register_autoclass from .utils.misc import register_autoclass
from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .utils.unsloth import load_unsloth_pretrained_model from .utils.unsloth import load_unsloth_pretrained_model
from .utils.valuehead import load_valuehead_params
if TYPE_CHECKING: if TYPE_CHECKING:
@ -105,7 +106,7 @@ def load_model(
""" """
init_kwargs = _get_init_kwargs(model_args) init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args) config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead)
model = None model = None
lazy_load = False lazy_load = False
@ -130,7 +131,7 @@ def load_model(
model = convert_pretrained_model_to_mod(model, config, model_args) model = convert_pretrained_model_to_mod(model, config, model_args)
if not lazy_load: if not lazy_load:
patch_model(model, tokenizer, model_args, is_trainable) patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
register_autoclass(config, model, tokenizer) register_autoclass(config, model, tokenizer)
model = init_adapter(config, model, model_args, finetuning_args, is_trainable) model = init_adapter(config, model, model_args, finetuning_args, is_trainable)

View File

@ -15,6 +15,7 @@ from .utils.longlora import configure_longlora
from .utils.moe import add_z3_leaf_module, configure_moe from .utils.moe import add_z3_leaf_module, configure_moe
from .utils.quantization import configure_quantization from .utils.quantization import configure_quantization
from .utils.rope import configure_rope from .utils.rope import configure_rope
from .utils.valuehead import configure_valuehead, prepare_valuehead_model
from .utils.visual import autocast_projector_dtype from .utils.visual import autocast_projector_dtype
@ -39,6 +40,7 @@ def patch_config(
model_args: "ModelArguments", model_args: "ModelArguments",
init_kwargs: Dict[str, Any], init_kwargs: Dict[str, Any],
is_trainable: bool, is_trainable: bool,
add_valuehead: bool,
) -> None: ) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
@ -49,6 +51,9 @@ def patch_config(
configure_quantization(config, tokenizer, model_args, init_kwargs) configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable) configure_moe(config, model_args, is_trainable)
if add_valuehead:
configure_valuehead(config)
if model_args.use_cache and not is_trainable: if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True) setattr(config, "use_cache", True)
logger.info("Using KV cache for faster generation.") logger.info("Using KV cache for faster generation.")
@ -73,7 +78,11 @@ def patch_config(
def patch_model( def patch_model(
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
is_trainable: bool,
add_valuehead: bool,
) -> None: ) -> None:
gen_config = model.generation_config # check and fix generation config gen_config = model.generation_config # check and fix generation config
if not gen_config.do_sample and ( if not gen_config.do_sample and (
@ -86,9 +95,8 @@ def patch_model(
if "GenerationMixin" not in str(model.generate.__func__): if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model) model.generate = MethodType(PreTrainedModel.generate, model)
if is_trainable and getattr(model.config, "model_type", None) == "chatglm": if add_valuehead:
setattr(model, "lm_head", model.transformer.output_layer) prepare_valuehead_model(model)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
if model_args.resize_vocab: if model_args.resize_vocab:
resize_embedding_layer(model, tokenizer) resize_embedding_layer(model, tokenizer)

View File

@ -1,18 +1,13 @@
from typing import TYPE_CHECKING, Dict, List from typing import TYPE_CHECKING, List
import torch import torch
from transformers import PreTrainedModel
from transformers.utils import cached_file
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ...extras.logging import get_logger from ...extras.logging import get_logger
from .quantization import QuantizationMethod from .quantization import QuantizationMethod
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
from ...hparams import ModelArguments
logger = get_logger(__name__) logger = get_logger(__name__)
@ -74,34 +69,6 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
return module_names return module_names
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
try:
from safetensors import safe_open
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()}
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
try:
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
return torch.load(vhead_file, map_location="cpu")
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
return None
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"): def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
if "AutoConfig" in getattr(config, "auto_map", {}): if "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class() config.__class__.register_for_auto_class()

View File

@ -0,0 +1,59 @@
from typing import TYPE_CHECKING, Dict
import torch
from transformers.utils import cached_file
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_valuehead(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "llava":
setattr(config, "hidden_size", getattr(config.vision_config, "intermediate_size", None))
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
try:
from safetensors import safe_open
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()}
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
try:
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
return torch.load(vhead_file, map_location="cpu")
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
return None
def prepare_valuehead_model(model: "PreTrainedModel") -> None:
if getattr(model.config, "model_type", None) == "llava":
setattr(model, "lm_head", model.language_model.get_output_embeddings())
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
if getattr(model.config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])