From 96ce76cd2753bc91c781ad13aa8f7a972abe815a Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 13 Mar 2024 01:21:50 +0800
Subject: [PATCH] fix kv cache

---
 src/llmtuner/hparams/model_args.py |  4 ++++
 src/llmtuner/model/patcher.py      | 23 ++++++++++++++---------
 2 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index 20b02219..a3719586 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -101,6 +101,10 @@ class ModelArguments:
         default="offload",
         metadata={"help": "Path to offload model weights."},
     )
+    use_cache: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use KV cache in generation."},
+    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."},
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 7335b1c1..0d8b9d79 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -115,6 +115,9 @@ def _configure_attn_implementation(model_args: "ModelArguments", init_kwargs: Di
 
 
 def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if model_args.rope_scaling is None:
+        return
+
     if not hasattr(config, "rope_scaling"):
         logger.warning("Current model does not support RoPE scaling.")
         return
@@ -141,7 +144,10 @@ def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is
     )
 
 
-def _configure_longlora(config: "PretrainedConfig") -> None:
+def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.shift_attn:
+        return
+
     if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
         setattr(config, "group_size_ratio", 0.25)
         apply_llama_patch()
@@ -242,7 +248,7 @@ def _prepare_model_for_training(
         # According to: https://github.com/huggingface/transformers/issues/28339
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
         model.enable_input_require_grads()
-        model.config.use_cache = False  # turn off when gradient checkpointing is enabled
+        setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")
 
     if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
@@ -276,15 +282,14 @@ def patch_config(
         setattr(config, dtype_name, model_args.compute_dtype == dtype)
 
     _configure_attn_implementation(model_args, init_kwargs)
-
-    if model_args.rope_scaling is not None:
-        _configure_rope(config, model_args, is_trainable)
-
-    if is_trainable and model_args.shift_attn:
-        _configure_longlora(config)
-
+    _configure_rope(config, model_args, is_trainable)
+    _configure_longlora(config, model_args, is_trainable)
     _configure_quantization(config, tokenizer, model_args, init_kwargs)
 
+    if model_args.use_cache and not is_trainable:
+        setattr(config, "use_cache", True)
+        logger.info("Using KV cache for faster generation.")
+
     init_kwargs["torch_dtype"] = model_args.compute_dtype
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
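
Note (reviewer sketch, not part of the applied patch): the change routes every
KV-cache decision through the config object. At inference time (`not
is_trainable`) with `use_cache=True`, `patch_config` turns the cache on; during
training, `_prepare_model_for_training` still forces it off because gradient
checkpointing recomputes activations and cannot coexist with a KV cache. The
standalone sketch below mirrors that gating with mock objects; the names
`MockModelArgs`, `patch_config_sketch`, and `prepare_for_training_sketch` are
hypothetical stand-ins, not LLaMA-Factory's actual classes or functions.

    from dataclasses import dataclass
    from types import SimpleNamespace


    @dataclass
    class MockModelArgs:
        """Stand-in for ModelArguments; only the field this patch adds."""

        use_cache: bool = True


    def patch_config_sketch(config: SimpleNamespace, model_args: MockModelArgs, is_trainable: bool) -> None:
        # Mirrors the new block in patch_config(): the KV cache is only
        # enabled explicitly when the model is loaded for generation.
        if model_args.use_cache and not is_trainable:
            setattr(config, "use_cache", True)


    def prepare_for_training_sketch(config: SimpleNamespace) -> None:
        # Mirrors _prepare_model_for_training(): gradient checkpointing
        # recomputes activations, so the KV cache must stay off.
        setattr(config, "use_cache", False)


    # Inference path: cache is switched on.
    config = SimpleNamespace(use_cache=None)
    patch_config_sketch(config, MockModelArgs(use_cache=True), is_trainable=False)
    assert config.use_cache is True

    # Training path: checkpointing wins, cache stays off.
    config = SimpleNamespace(use_cache=None)
    patch_config_sketch(config, MockModelArgs(use_cache=True), is_trainable=True)
    prepare_for_training_sketch(config)
    assert config.use_cache is False

A side effect worth noting: moving the `rope_scaling is None` and
`is_trainable and shift_attn` guards into `_configure_rope` and
`_configure_longlora` themselves turns `patch_config` into straight-line
calls, so each helper now owns its own precondition.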