From 7ebd63a609d4bc4f8a645c5c60be77842ebac825 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 7 Nov 2023 16:17:22 +0800 Subject: [PATCH] fix #1418 --- requirements.txt | 2 +- src/llmtuner/tuner/core/loader.py | 2 +- src/llmtuner/tuner/ppo/trainer.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 790dce6a..67dcb6b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ transformers>=4.31.0,<4.35.0 datasets>=2.12.0 accelerate>=0.21.0 peft>=0.6.0 -trl>=0.7.2 +trl==0.7.2 gradio>=3.38.0,<4.0.0 scipy sentencepiece diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index 7cd49e79..15a3f36d 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -41,7 +41,7 @@ require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transform require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0") -require_version("trl>=0.7.2", "To fix: pip install trl>=0.7.2") +require_version("trl==0.7.2", "To fix: pip install trl==0.7.2") def load_model_and_tokenizer( diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 372c4891..cdd8f918 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -3,7 +3,7 @@ import sys import math import torch from tqdm import tqdm -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR