From a153039380d8aa2cdbf434f71f304b1c53ce09f2 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Mon, 6 May 2024 23:33:06 +0800 Subject: [PATCH] fix gradio args --- src/llmtuner/webui/interface.py | 12 ++++++++++-- src/webui.py | 7 ++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py index b293db90..969ce6bd 100644 --- a/src/llmtuner/webui/interface.py +++ b/src/llmtuner/webui/interface.py @@ -1,3 +1,5 @@ +import os + from ..extras.packages import is_gradio_available from .common import save_config from .components import ( @@ -69,8 +71,14 @@ def create_web_demo() -> gr.Blocks: def run_web_ui() -> None: - create_ui().queue().launch() + server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") + server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860")) + gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0"))) + create_ui().queue().launch(share=gradio_share, server_name=server_name, server_port=server_port) def run_web_demo() -> None: - create_web_demo().queue().launch() + server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") + server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860")) + gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0"))) + create_web_demo().queue().launch(share=gradio_share, server_name=server_name, server_port=server_port) diff --git a/src/webui.py b/src/webui.py index c225c710..b9385259 100644 --- a/src/webui.py +++ b/src/webui.py @@ -1,8 +1,13 @@ +import os + from llmtuner.webui.interface import create_ui def main(): - create_ui().queue().launch(server_name="0.0.0.0", server_port=None, share=False) + server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") + server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860")) + gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0"))) + create_ui().queue().launch(share=gradio_share, server_name=server_name, server_port=server_port) if __name__ == "__main__":