Merge pull request #4053 from hzhaoy/feature/add_select_config_file
Support selecting saved configuration files
This commit is contained in:
commit
0e740aa463
|
@ -6,7 +6,7 @@ from ...extras.constants import TRAINING_STAGES
|
||||||
from ...extras.misc import get_device_count
|
from ...extras.misc import get_device_count
|
||||||
from ...extras.packages import is_gradio_available
|
from ...extras.packages import is_gradio_available
|
||||||
from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
|
from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
|
||||||
from ..utils import change_stage, check_output_dir, list_output_dirs
|
from ..utils import change_stage, check_output_dir, list_output_dirs, list_config_paths
|
||||||
from .data import create_preview_box
|
from .data import create_preview_box
|
||||||
|
|
||||||
|
|
||||||
|
@ -259,7 +259,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
initial_dir = gr.Textbox(visible=False, interactive=False)
|
initial_dir = gr.Textbox(visible=False, interactive=False)
|
||||||
output_dir = gr.Dropdown(allow_custom_value=True)
|
output_dir = gr.Dropdown(allow_custom_value=True)
|
||||||
config_path = gr.Textbox()
|
config_path = gr.Dropdown(allow_custom_value=True)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
device_count = gr.Textbox(value=str(get_device_count() or 1), interactive=False)
|
device_count = gr.Textbox(value=str(get_device_count() or 1), interactive=False)
|
||||||
|
@ -320,5 +320,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
output_dir.change(
|
output_dir.change(
|
||||||
list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], concurrency_limit=None
|
list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], concurrency_limit=None
|
||||||
).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None)
|
).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None)
|
||||||
|
config_path.change(list_config_paths, outputs=[config_path], concurrency_limit=None)
|
||||||
|
|
||||||
return elem_dict
|
return elem_dict
|
||||||
|
|
|
@ -190,6 +190,18 @@ def list_output_dirs(model_name: str, finetuning_type: str, initial_dir: str) ->
|
||||||
return gr.Dropdown(choices=output_dirs)
|
return gr.Dropdown(choices=output_dirs)
|
||||||
|
|
||||||
|
|
||||||
|
def list_config_paths() -> "gr.Dropdown":
|
||||||
|
"""
|
||||||
|
Lists all the saved configuration files that can be loaded.
|
||||||
|
"""
|
||||||
|
if os.path.exists(DEFAULT_CONFIG_DIR) and os.path.isdir(DEFAULT_CONFIG_DIR):
|
||||||
|
config_files = [file_name for file_name in os.listdir(DEFAULT_CONFIG_DIR) if file_name.endswith(".yaml")]
|
||||||
|
else:
|
||||||
|
config_files = []
|
||||||
|
|
||||||
|
return gr.Dropdown(choices=config_files)
|
||||||
|
|
||||||
|
|
||||||
def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None:
|
def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None:
|
||||||
r"""
|
r"""
|
||||||
Check if output dir exists.
|
Check if output dir exists.
|
||||||
|
|
Loading…
Reference in New Issue