diff --git a/requirements.txt b/requirements.txt
index 12e907e5..fb5820ab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
 torch>=1.13.1
 transformers>=4.29.1
 datasets>=2.12.0
-accelerate>=0.19.0
-peft>=0.3.0
+accelerate>=0.21.0
+peft>=0.4.0
 trl>=0.4.7
 sentencepiece
 jieba
diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py
index 46ff63b2..146ad353 100644
--- a/src/llmtuner/__init__.py
+++ b/src/llmtuner/__init__.py
@@ -1,4 +1,4 @@
 from llmtuner.chat import ChatModel


-__version__ = "0.1.2"
+__version__ = "0.1.3"
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 025a37df..a4e2e7ea 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -27,8 +27,8 @@ logger = get_logger(__name__)

 check_min_version("4.29.1")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
-require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
-require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
+require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
+require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
 require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")


@@ -81,9 +81,6 @@ def load_model_and_tokenizer(

         elif model_args.quantization_bit == 4:
             require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-            require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
-            require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
-            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
             config_kwargs["load_in_4bit"] = True
             config_kwargs["quantization_config"] = BitsAndBytesConfig(
                 load_in_4bit=True,
diff --git a/src/llmtuner/webui/chat.py b/src/llmtuner/webui/chat.py
index ad3ebd8a..b721ad40 100644
--- a/src/llmtuner/webui/chat.py
+++ b/src/llmtuner/webui/chat.py
@@ -84,6 +84,12 @@ class WebChatModel(ChatModel):
             query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
+            response = self.postprocess(response)
             new_history = history + [(query, response)]
             chatbot[-1] = [query, response]
             yield chatbot, new_history
+
+    def postprocess(self, response: str) -> str:
+        response = response.replace("<", "&lt;")
+        response = response.replace(">", "&gt;")
+        return response
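
A quick illustration (not part of the patch) of the behavior the new postprocess hook adds: it HTML-escapes angle brackets so that tag-like model output, such as a literal "<s>" token, is displayed verbatim by the web UI's chatbot widget rather than parsed as markup. The function below is a standalone sketch that mirrors the patched method; the sample string and print call are illustrative only.

# Standalone sketch of the postprocess behavior added in webui/chat.py.
# Escaping "<" and ">" keeps tag-like model output from being
# interpreted as HTML by the chat front end.
def postprocess(response: str) -> str:
    response = response.replace("<", "&lt;")
    response = response.replace(">", "&gt;")
    return response

print(postprocess("a < b and <b>bold</b>"))
# -> a &lt; b and &lt;b&gt;bold&lt;/b&gt;

Note that the hook is applied inside the streaming loop, so each partial response is escaped before it is pushed to the chatbot, keeping the rendered text consistent while tokens arrive.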