From 9ce1b0e2f21a4601defe8e8f1f3f312626abe3d8 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Mon, 11 Dec 2023 17:13:40 +0800 Subject: [PATCH] use peft 0.7.0, fix #1561 #1764 --- requirements.txt | 2 +- src/llmtuner/model/adapter.py | 3 +++ src/llmtuner/model/loader.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index bbd1f89d..da0b2d88 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ torch>=1.13.1 transformers>=4.31.0,<4.35.0 datasets>=2.14.3 accelerate>=0.21.0 -peft==0.6.0 +peft>=0.7.0 trl>=0.7.4 gradio>=3.38.0,<4.0.0 scipy diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py index 53dfd6ea..82fa8c7b 100644 --- a/src/llmtuner/model/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -102,6 +102,9 @@ def init_adapter( ) model = get_peft_model(model, lora_config) + for param in filter(lambda p: p.requires_grad, model.parameters()): + param.data = param.data.to(torch.float32) + if model_args.checkpoint_dir is not None: logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir))) diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index 728472da..2434016e 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -41,7 +41,7 @@ logger = get_logger(__name__) require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"") require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") -require_version("peft==0.6.0", "To fix: pip install peft==0.6.0") +require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0") require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")