diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index 6d06d1d0..21edab2f 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -51,7 +51,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": allow_methods=["*"], allow_headers=["*"], ) - api_key = os.environ.get("API_KEY", None) + api_key = os.environ.get("API_KEY") security = HTTPBearer(auto_error=False) async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]): diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py index 6d24b244..637b786d 100644 --- a/src/llmtuner/extras/callbacks.py +++ b/src/llmtuner/extras/callbacks.py @@ -53,7 +53,7 @@ class LogCallback(TrainerCallback): self.aborted = False self.do_train = False """ Web UI """ - self.webui_mode = bool(int(os.environ.get("LLAMABOARD_ENABLED", "0"))) + self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"] if self.webui_mode: signal.signal(signal.SIGABRT, self._set_abort) self.logger_handler = LoggerHandler(output_dir) diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 8ce25d18..53140efa 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -58,7 +58,7 @@ class AverageMeter: def check_dependencies() -> None: - if int(os.environ.get("DISABLE_VERSION_CHECK", "0")): + if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") else: require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2") diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py index 91709d40..c5a30113 100644 --- a/src/llmtuner/webui/interface.py +++ b/src/llmtuner/webui/interface.py @@ -71,12 +71,12 @@ def create_web_demo() -> gr.Blocks: def run_web_ui() -> None: - gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0"))) + gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"] server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") create_ui().queue().launch(share=gradio_share, server_name=server_name) def run_web_demo() -> None: - gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0"))) + gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"] server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") create_web_demo().queue().launch(share=gradio_share, server_name=server_name) diff --git a/src/webui.py b/src/webui.py index 3f8690d0..7a43039d 100644 --- a/src/webui.py +++ b/src/webui.py @@ -4,7 +4,7 @@ from llmtuner.webui.interface import create_ui def main(): - gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0"))) + gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"] server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") create_ui().queue().launch(share=gradio_share, server_name=server_name)