From c94e6c9411cc1165ea8180d377611a3ae47956e6 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Thu, 13 Jun 2024 03:19:18 +0800 Subject: [PATCH] add quant check in webui export tab --- src/llamafactory/webui/components/export.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py index 7e1493c8..9d756a38 100644 --- a/src/llamafactory/webui/components/export.py +++ b/src/llamafactory/webui/components/export.py @@ -21,6 +21,13 @@ if TYPE_CHECKING: GPTQ_BITS = ["8", "4", "3", "2"] +def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown": + if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0: + return gr.Dropdown(value="none", interactive=False) + else: + return gr.Dropdown(interactive=True) + + def save_model( lang: str, model_name: str, @@ -96,6 +103,9 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: export_dir = gr.Textbox() export_hub_model_id = gr.Textbox() + checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path") + checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False) + export_btn = gr.Button() info_box = gr.Textbox(show_label=False, interactive=False)