From 6b48308ef9be34d072f3e6bb2444e186a38c2779 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Sat, 13 Jul 2024 22:07:58 +0800 Subject: [PATCH] fix #4792 --- src/llamafactory/chat/chat_model.py | 11 +++++------ src/llamafactory/data/tool_utils.py | 4 ++-- src/llamafactory/webui/utils.py | 1 + tests/data/test_formatter.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py index 5c83fa67..3ea3b44f 100644 --- a/src/llamafactory/chat/chat_model.py +++ b/src/llamafactory/chat/chat_model.py @@ -16,6 +16,7 @@ # limitations under the License. import asyncio +import os from threading import Thread from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence @@ -115,13 +116,11 @@ class ChatModel: def run_chat() -> None: - try: - import platform - - if platform.system() != "Windows": + if os.name != "nt": + try: import readline # noqa: F401 - except ImportError: - print("Install `readline` for a better experience.") + except ImportError: + print("Install `readline` for a better experience.") chat_model = ChatModel() messages = [] diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py index ac5565d5..efda86f5 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -25,9 +25,9 @@ DEFAULT_TOOL_PROMPT = ( "You have access to the following tools:\n{tool_text}" "Use the following format if using a tool:\n" "```\n" - "Action: tool name (one of [{tool_names}]).\n" + "Action: tool name (one of [{tool_names}])\n" "Action Input: the input to the tool, in a JSON format representing the kwargs " - """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n""" + """(e.g. ```{{"input": "hello world", "num_beams": 5}}```)\n""" "```\n" ) diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py index 80f53b6a..c52c0887 100644 --- a/src/llamafactory/webui/utils.py +++ b/src/llamafactory/webui/utils.py @@ -117,6 +117,7 @@ def gen_cmd(args: Dict[str, Any]) -> str: cmd_text = "`\n".join(cmd_lines) else: cmd_text = "\\\n".join(cmd_lines) + cmd_text = "```bash\n{}\n```".format(cmd_text) return cmd_text diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py index 1845df24..051bc120 100644 --- a/tests/data/test_formatter.py +++ b/tests/data/test_formatter.py @@ -69,9 +69,9 @@ def test_default_tool_formatter(): " - bar (number): bar_desc\n\n" "Use the following format if using a tool:\n" "```\n" - "Action: tool name (one of [test_tool]).\n" + "Action: tool name (one of [test_tool])\n" "Action Input: the input to the tool, in a JSON format representing the kwargs " - """(e.g. ```{"input": "hello world", "num_beams": 5}```).\n""" + """(e.g. ```{"input": "hello world", "num_beams": 5}```)\n""" "```\n" ]