From 7daf8366db0e161d46993fd87cf983a27a0ce2a3 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Thu, 6 Jun 2024 03:33:44 +0800 Subject: [PATCH] lint --- src/llamafactory/extras/env.py | 3 ++- src/llamafactory/extras/packages.py | 4 --- src/llamafactory/webui/components/train.py | 14 +++++------ src/llamafactory/webui/engine.py | 2 +- src/llamafactory/webui/utils.py | 29 +++++++++++----------- 5 files changed, 25 insertions(+), 27 deletions(-) diff --git a/src/llamafactory/extras/env.py b/src/llamafactory/extras/env.py index 059730f1..fdccf86b 100644 --- a/src/llamafactory/extras/env.py +++ b/src/llamafactory/extras/env.py @@ -6,9 +6,10 @@ import peft import torch import transformers import trl +from transformers.integrations import is_deepspeed_available from transformers.utils import is_bitsandbytes_available, is_torch_cuda_available, is_torch_npu_available -from .packages import is_deepspeed_available, is_vllm_available +from .packages import is_vllm_available VERSION = "0.7.2.dev0" diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py index fe056e2d..4c9e6492 100644 --- a/src/llamafactory/extras/packages.py +++ b/src/llamafactory/extras/packages.py @@ -20,10 +20,6 @@ def _get_package_version(name: str) -> "Version": return version.parse("0.0.0") -def is_deepspeed_available(): - return _is_package_available("deepspeed") - - def is_fastapi_available(): return _is_package_available("fastapi") diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index eca8f9b3..74f8ef2a 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -6,7 +6,7 @@ from ...extras.constants import TRAINING_STAGES from ...extras.misc import get_device_count from ...extras.packages import is_gradio_available from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets -from ..utils import change_stage, check_output_dir, list_output_dirs, list_config_paths +from ..utils import change_stage, check_output_dir, list_config_paths, list_output_dirs from .data import create_preview_box @@ -257,7 +257,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): with gr.Column(scale=3): with gr.Row(): - initial_dir = gr.Textbox(visible=False, interactive=False) + current_time = gr.Textbox(visible=False, interactive=False) output_dir = gr.Dropdown(allow_custom_value=True) config_path = gr.Dropdown(allow_custom_value=True) @@ -284,7 +284,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: arg_load_btn=arg_load_btn, start_btn=start_btn, stop_btn=stop_btn, - initial_dir=initial_dir, + current_time=current_time, output_dir=output_dir, config_path=config_path, device_count=device_count, @@ -315,11 +315,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False) training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False) reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False) - model_name.change(list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], queue=False) - finetuning_type.change(list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], queue=False) + model_name.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False) + finetuning_type.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False) output_dir.change( - list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], concurrency_limit=None + list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], concurrency_limit=None ).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None) - config_path.change(list_config_paths, outputs=[config_path], concurrency_limit=None) + config_path.change(list_config_paths, [current_time], [config_path], queue=False) return elem_dict diff --git a/src/llamafactory/webui/engine.py b/src/llamafactory/webui/engine.py index 00877115..eb6142d3 100644 --- a/src/llamafactory/webui/engine.py +++ b/src/llamafactory/webui/engine.py @@ -41,7 +41,7 @@ class Engine: if not self.pure_chat: current_time = get_time() - init_dict["train.initial_dir"] = {"value": "train_{}".format(current_time)} + init_dict["train.current_time"] = {"value": current_time} init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)} init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)} init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)} diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py index 0303aa31..23e62dca 100644 --- a/src/llamafactory/webui/utils.py +++ b/src/llamafactory/webui/utils.py @@ -174,11 +174,24 @@ def save_args(config_path: str, config_dict: Dict[str, Any]) -> str: return str(get_arg_save_path(config_path)) -def list_output_dirs(model_name: str, finetuning_type: str, initial_dir: str) -> "gr.Dropdown": +def list_config_paths(current_time: str) -> "gr.Dropdown": + r""" + Lists all the saved configuration files. + """ + config_files = ["{}.yaml".format(current_time)] + if os.path.isdir(DEFAULT_CONFIG_DIR): + for file_name in os.listdir(DEFAULT_CONFIG_DIR): + if file_name.endswith(".yaml"): + config_files.append(file_name) + + return gr.Dropdown(choices=config_files) + + +def list_output_dirs(model_name: str, finetuning_type: str, current_time: str) -> "gr.Dropdown": r""" Lists all the directories that can resume from. """ - output_dirs = [initial_dir] + output_dirs = ["train_{}".format(current_time)] if model_name: save_dir = get_save_dir(model_name, finetuning_type) if save_dir and os.path.isdir(save_dir): @@ -190,18 +203,6 @@ def list_output_dirs(model_name: str, finetuning_type: str, initial_dir: str) -> return gr.Dropdown(choices=output_dirs) -def list_config_paths() -> "gr.Dropdown": - """ - Lists all the saved configuration files that can be loaded. - """ - if os.path.exists(DEFAULT_CONFIG_DIR) and os.path.isdir(DEFAULT_CONFIG_DIR): - config_files = [file_name for file_name in os.listdir(DEFAULT_CONFIG_DIR) if file_name.endswith(".yaml")] - else: - config_files = [] - - return gr.Dropdown(choices=config_files) - - def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None: r""" Check if output dir exists.