diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py index ee28603e..82e7b7f1 100644 --- a/src/llmtuner/webui/chatter.py +++ b/src/llmtuner/webui/chatter.py @@ -31,7 +31,10 @@ class WebChatModel(ChatModel): if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"): # load demo model model_name_or_path = os.environ.get("DEMO_MODEL") template = os.environ.get("DEMO_TEMPLATE") - super().__init__(dict(model_name_or_path=model_name_or_path, template=template)) + infer_backend = os.environ.get("DEMO_BACKEND", "huggingface") + super().__init__( + dict(model_name_or_path=model_name_or_path, template=template, infer_backend=infer_backend) + ) @property def loaded(self) -> bool: