From 0ff9a1fb4fdf42d91095f2dc113d9844bfe4368a Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 19 Jan 2024 23:29:54 +0800
Subject: [PATCH] set use_reentrant=False

---
 src/llmtuner/model/patcher.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index d21d87dc..5ce8d604 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -205,8 +205,7 @@ def _prepare_model_for_training(
     if not getattr(model, "supports_gradient_checkpointing", False):
         logger.warning("Current model does not support gradient checkpointing.")
     else:
-        model.enable_input_require_grads()
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         model.config.use_cache = False  # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")