fix kv cache

hiyouga 2024-03-13 01:21:50 +08:00
parent 19ef482649
commit 96ce76cd27
2 changed files with 18 additions and 9 deletions

View File

@@ -101,6 +101,10 @@ class ModelArguments:
         default="offload",
         metadata={"help": "Path to offload model weights."},
     )
+    use_cache: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use KV cache in generation."},
+    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."},

View File

@@ -115,6 +115,9 @@ def _configure_attn_implementation(model_args: "ModelArguments", init_kwargs: Di
 def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if model_args.rope_scaling is None:
+        return
     if not hasattr(config, "rope_scaling"):
         logger.warning("Current model does not support RoPE scaling.")
         return
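
The rope_scaling check now lives inside the helper, so patch_config (last hunk below) can call it unconditionally. A simplified sketch of the kind of config mutation the helper performs, using a bare LlamaConfig and an illustrative scaling factor of 2.0 that is not taken from this commit:

from transformers import LlamaConfig

config = LlamaConfig()  # LLaMA-style configs expose a rope_scaling attribute
if hasattr(config, "rope_scaling"):
    # Linear RoPE scaling interpolates positions by the given factor,
    # stretching the usable context window accordingly.
    setattr(config, "rope_scaling", {"type": "linear", "factor": 2.0})
    setattr(config, "max_position_embeddings", config.max_position_embeddings * 2)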
@@ -141,7 +144,10 @@ def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is
     )
-def _configure_longlora(config: "PretrainedConfig") -> None:
+def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.shift_attn:
+        return
     if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
         setattr(config, "group_size_ratio", 0.25)
         apply_llama_patch()
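
Both helpers now follow the same guard pattern: the condition that used to sit at the call site becomes an early return inside the function, so the calls in patch_config need no surrounding if. A schematic sketch of the pattern (configure_feature is hypothetical, not the project's code); per the LongLoRA recipe, group_size_ratio = 0.25 means shifted sparse attention is computed over groups of one quarter of the sequence length:

from types import SimpleNamespace

def configure_feature(config, model_args, is_trainable):
    # Early return replaces the caller-side "if is_trainable and model_args.shift_attn".
    if not is_trainable or not model_args.shift_attn:
        return
    setattr(config, "group_size_ratio", 0.25)  # attend within 1/4-length groups

config = SimpleNamespace(model_type="llama")
configure_feature(config, SimpleNamespace(shift_attn=True), is_trainable=True)
print(getattr(config, "group_size_ratio", None))  # 0.25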
@@ -242,7 +248,7 @@ def _prepare_model_for_training(
         # According to: https://github.com/huggingface/transformers/issues/28339
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
         model.enable_input_require_grads()
-        model.config.use_cache = False  # turn off when gradient checkpointing is enabled
+        setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")
     if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
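
For reference, the incompatibility this one-line change works around: with gradient checkpointing, activations are recomputed during the backward pass, so caching past key values is pointless and triggers warnings, and transformers expects use_cache to be off while training. A minimal sketch (gpt2 is only a small placeholder model):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
# Reentrant checkpointing needs inputs that require grad, hence the extra call.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()
setattr(model.config, "use_cache", False)  # KV cache is incompatible with checkpointed training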
@@ -276,15 +282,14 @@ def patch_config(
         setattr(config, dtype_name, model_args.compute_dtype == dtype)
     _configure_attn_implementation(model_args, init_kwargs)
-    if model_args.rope_scaling is not None:
-        _configure_rope(config, model_args, is_trainable)
-    if is_trainable and model_args.shift_attn:
-        _configure_longlora(config)
+    _configure_rope(config, model_args, is_trainable)
+    _configure_longlora(config, model_args, is_trainable)
     _configure_quantization(config, tokenizer, model_args, init_kwargs)
+    if model_args.use_cache and not is_trainable:
+        setattr(config, "use_cache", True)
+        logger.info("Using KV cache for faster generation.")
     init_kwargs["torch_dtype"] = model_args.compute_dtype
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage