diff --git a/src/llamafactory/webui/components/eval.py b/src/llamafactory/webui/components/eval.py index 99215fc2..0a7a0f44 100644 --- a/src/llamafactory/webui/components/eval.py +++ b/src/llamafactory/webui/components/eval.py @@ -57,7 +57,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): output_box = gr.Markdown() - output_elems = [output_box, progress_bar] elem_dict.update( dict( cmd_preview_btn=cmd_preview_btn, @@ -68,6 +67,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: output_box=output_box, ) ) + output_elems = [output_box, progress_bar] cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None) start_btn.click(engine.runner.run_eval, input_elems, output_elems) diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 6f742bb1..93e5dfc1 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -298,22 +298,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: ) output_elems = [output_box, progress_bar, loss_viewer] - lang = engine.manager.get_elem_by_id("top.lang") - model_name = engine.manager.get_elem_by_id("top.model_name") - finetuning_type = engine.manager.get_elem_by_id("top.finetuning_type") - cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None) - arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None) - arg_load_btn.click( - engine.runner.load_args, [lang, config_path], list(input_elems) + [output_box], concurrency_limit=None - ) start_btn.click(engine.runner.run_train, input_elems, output_elems) stop_btn.click(engine.runner.set_abort) resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None) - training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False) + lang = engine.manager.get_elem_by_id("top.lang") + model_name: "gr.Dropdown" = engine.manager.get_elem_by_id("top.model_name") + finetuning_type: "gr.Dropdown" = engine.manager.get_elem_by_id("top.finetuning_type") + + arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None) + arg_load_btn.click( + engine.runner.load_args, [lang, config_path], list(input_elems) + [output_box], concurrency_limit=None + ) + dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False) + training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False) reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False) + model_name.change(list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], queue=False) + finetuning_type.change(list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], queue=False) output_dir.change( list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], concurrency_limit=None ).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None) diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 5b11c853..e30feab2 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1475,6 +1475,11 @@ ALERTS = { "ru": "Пожалуйста, выберите адаптер.", "zh": "请选择适配器。", }, + "err_no_output_dir": { + "en": "Please provide output dir.", + "ru": "Пожалуйста, укажите выходную директорию.", + "zh": "请填写输出目录。", + }, "err_no_reward_model": { "en": "Please select a reward model.", "ru": "Пожалуйста, выберите модель вознаграждения.", diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 7a305d62..fe213513 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -64,10 +64,15 @@ class Runner: return ALERTS["err_demo"][lang] if do_train: + if not get("train.output_dir"): + return ALERTS["err_no_output_dir"][lang] + stage = TRAINING_STAGES[get("train.training_stage")] - reward_model = get("train.reward_model") - if stage == "ppo" and not reward_model: + if stage == "ppo" and not get("train.reward_model"): return ALERTS["err_no_reward_model"][lang] + else: + if not get("eval.output_dir"): + return ALERTS["err_no_output_dir"][lang] if not from_preview and not is_gpu_or_npu_available(): gr.Warning(ALERTS["warn_no_cuda"][lang]) diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py index 09cefa0e..0446cb47 100644 --- a/src/llamafactory/webui/utils.py +++ b/src/llamafactory/webui/utils.py @@ -180,7 +180,7 @@ def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_di r""" Check if output dir exists. """ - if os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)): + if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)): gr.Warning(ALERTS["warn_output_dir_exists"][lang])