lint

commit 7daf8366db (parent f2580ad403)
@@ -6,9 +6,10 @@ import peft
 import torch
 import transformers
 import trl
+from transformers.integrations import is_deepspeed_available
 from transformers.utils import is_bitsandbytes_available, is_torch_cuda_available, is_torch_npu_available

-from .packages import is_deepspeed_available, is_vllm_available
+from .packages import is_vllm_available


 VERSION = "0.7.2.dev0"
@@ -20,10 +20,6 @@ def _get_package_version(name: str) -> "Version":
     return version.parse("0.0.0")


-def is_deepspeed_available():
-    return _is_package_available("deepspeed")
-
-
 def is_fastapi_available():
     return _is_package_available("fastapi")

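These two hunks drop the project's hand-rolled `is_deepspeed_available()` in favor of the one `transformers.integrations` already exports. A minimal sketch of the duplicated pattern being deleted, assuming the local `_is_package_available` is a thin wrapper over `importlib` (the real helper may differ):

```python
# Sketch of the availability-check pattern removed by this commit: the local
# helper duplicated a check that transformers already provides, so the lint
# pass deletes it and imports the upstream function instead.
import importlib.util


def _is_package_available(name: str) -> bool:
    # True if the package can be found on the import path.
    return importlib.util.find_spec(name) is not None


def is_deepspeed_available():  # removed by this commit
    return _is_package_available("deepspeed")


# After this commit, callers use the canonical upstream check:
from transformers.integrations import is_deepspeed_available  # noqa: E402, F811
```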
@@ -6,7 +6,7 @@ from ...extras.constants import TRAINING_STAGES
 from ...extras.misc import get_device_count
 from ...extras.packages import is_gradio_available
 from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
-from ..utils import change_stage, check_output_dir, list_output_dirs, list_config_paths
+from ..utils import change_stage, check_output_dir, list_config_paths, list_output_dirs
 from .data import create_preview_box


@@ -257,7 +257,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
         with gr.Column(scale=3):
             with gr.Row():
-                initial_dir = gr.Textbox(visible=False, interactive=False)
+                current_time = gr.Textbox(visible=False, interactive=False)
                 output_dir = gr.Dropdown(allow_custom_value=True)
                 config_path = gr.Dropdown(allow_custom_value=True)

@@ -284,7 +284,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             arg_load_btn=arg_load_btn,
             start_btn=start_btn,
             stop_btn=stop_btn,
-            initial_dir=initial_dir,
+            current_time=current_time,
             output_dir=output_dir,
             config_path=config_path,
             device_count=device_count,
@@ -315,11 +315,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False)
     training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False)
     reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False)
-    model_name.change(list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], queue=False)
-    finetuning_type.change(list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], queue=False)
+    model_name.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False)
+    finetuning_type.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False)
     output_dir.change(
-        list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], concurrency_limit=None
+        list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], concurrency_limit=None
     ).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None)
-    config_path.change(list_config_paths, outputs=[config_path], concurrency_limit=None)
+    config_path.change(list_config_paths, [current_time], [config_path], queue=False)

     return elem_dict
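The rename from `initial_dir` to `current_time` also changes what the hidden state component carries: the raw launch timestamp rather than a pre-built directory name, so both `list_output_dirs` and `list_config_paths` can derive their own defaults from it. A self-contained sketch of this wiring pattern (toy component values, not the project's real UI):

```python
import gradio as gr


def list_output_dirs(model_name: str, finetuning_type: str, current_time: str) -> "gr.Dropdown":
    # Derive the default choice from the timestamp, as the new code does.
    return gr.Dropdown(choices=["train_{}".format(current_time)])


with gr.Blocks() as demo:
    # Hidden component holding the launch timestamp for the callbacks.
    current_time = gr.Textbox(value="2024-05-20-12-00-00", visible=False, interactive=False)
    model_name = gr.Dropdown(choices=["demo-model"], value="demo-model")
    finetuning_type = gr.Dropdown(choices=["lora"], value="lora")
    output_dir = gr.Dropdown(allow_custom_value=True)

    # Changing the model refreshes the list of resumable output dirs.
    model_name.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False)
```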
@@ -41,7 +41,7 @@ class Engine:

         if not self.pure_chat:
             current_time = get_time()
-            init_dict["train.initial_dir"] = {"value": "train_{}".format(current_time)}
+            init_dict["train.current_time"] = {"value": current_time}
             init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)}
             init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)}
             init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)}
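In `Engine`, a single timestamp now seeds every default: `train.current_time` stores it verbatim, while the output dir, config path, and eval dir are formatted from it, keeping all four in sync. A sketch assuming `get_time()` returns a formatted timestamp string (the exact format is an assumption, not shown in the diff):

```python
from datetime import datetime


def get_time() -> str:
    # Hypothetical stand-in for the project's get_time() helper.
    return datetime.now().strftime("%Y-%m-%d-%H-%M-%S")


current_time = get_time()
init_dict = {
    "train.current_time": {"value": current_time},
    "train.output_dir": {"value": "train_{}".format(current_time)},
    "train.config_path": {"value": "{}.yaml".format(current_time)},
    "eval.output_dir": {"value": "eval_{}".format(current_time)},
}
```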
@@ -174,11 +174,24 @@ def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
     return str(get_arg_save_path(config_path))


-def list_output_dirs(model_name: str, finetuning_type: str, initial_dir: str) -> "gr.Dropdown":
+def list_config_paths(current_time: str) -> "gr.Dropdown":
+    r"""
+    Lists all the saved configuration files.
+    """
+    config_files = ["{}.yaml".format(current_time)]
+    if os.path.isdir(DEFAULT_CONFIG_DIR):
+        for file_name in os.listdir(DEFAULT_CONFIG_DIR):
+            if file_name.endswith(".yaml"):
+                config_files.append(file_name)
+
+    return gr.Dropdown(choices=config_files)
+
+
+def list_output_dirs(model_name: str, finetuning_type: str, current_time: str) -> "gr.Dropdown":
     r"""
     Lists all the directories that can resume from.
     """
-    output_dirs = [initial_dir]
+    output_dirs = ["train_{}".format(current_time)]
     if model_name:
         save_dir = get_save_dir(model_name, finetuning_type)
         if save_dir and os.path.isdir(save_dir):
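Besides taking the timestamp as an argument, the rewritten `list_config_paths` guarantees a non-empty dropdown: it seeds the choices with a fresh `"{current_time}.yaml"` entry before scanning `DEFAULT_CONFIG_DIR`, whereas the old version (removed in the next hunk) could return an empty list. A usage sketch with the constant stubbed out (the real `DEFAULT_CONFIG_DIR` is defined elsewhere in the project):

```python
import os

import gradio as gr

DEFAULT_CONFIG_DIR = "config"  # assumed value, for illustration only


def list_config_paths(current_time: str) -> "gr.Dropdown":
    # Always offer a fresh "<timestamp>.yaml" entry, then append saved configs.
    config_files = ["{}.yaml".format(current_time)]
    if os.path.isdir(DEFAULT_CONFIG_DIR):
        for file_name in os.listdir(DEFAULT_CONFIG_DIR):
            if file_name.endswith(".yaml"):
                config_files.append(file_name)

    return gr.Dropdown(choices=config_files)


dropdown = list_config_paths("2024-05-20-12-00-00")  # choices include "2024-05-20-12-00-00.yaml"
```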
@@ -190,18 +203,6 @@ def list_output_dirs(model_name: str, finetuning_type: str, initial_dir: str) -> "gr.Dropdown":
     return gr.Dropdown(choices=output_dirs)


-def list_config_paths() -> "gr.Dropdown":
-    """
-    Lists all the saved configuration files that can be loaded.
-    """
-    if os.path.exists(DEFAULT_CONFIG_DIR) and os.path.isdir(DEFAULT_CONFIG_DIR):
-        config_files = [file_name for file_name in os.listdir(DEFAULT_CONFIG_DIR) if file_name.endswith(".yaml")]
-    else:
-        config_files = []
-
-    return gr.Dropdown(choices=config_files)
-
-
 def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None:
     r"""
     Check if output dir exists.