add resume args in webui
This commit is contained in:
parent
8bf9da659c
commit
06e5d136a4
|
@ -35,6 +35,8 @@ IGNORE_INDEX = -100
|
||||||
|
|
||||||
LAYERNORM_NAMES = {"norm", "ln"}
|
LAYERNORM_NAMES = {"norm", "ln"}
|
||||||
|
|
||||||
|
LLAMABOARD_CONFIG = "llamaboard_config.yaml"
|
||||||
|
|
||||||
METHODS = ["full", "freeze", "lora"]
|
METHODS = ["full", "freeze", "lora"]
|
||||||
|
|
||||||
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
|
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
|
||||||
|
@ -47,10 +49,10 @@ SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
||||||
|
|
||||||
SUPPORTED_MODELS = OrderedDict()
|
SUPPORTED_MODELS = OrderedDict()
|
||||||
|
|
||||||
TRAINER_CONFIG = "trainer_config.yaml"
|
|
||||||
|
|
||||||
TRAINER_LOG = "trainer_log.jsonl"
|
TRAINER_LOG = "trainer_log.jsonl"
|
||||||
|
|
||||||
|
TRAINING_ARGS = "training_args.yaml"
|
||||||
|
|
||||||
TRAINING_STAGES = {
|
TRAINING_STAGES = {
|
||||||
"Supervised Fine-Tuning": "sft",
|
"Supervised Fine-Tuning": "sft",
|
||||||
"Reward Modeling": "rm",
|
"Reward Modeling": "rm",
|
||||||
|
|
|
@ -50,7 +50,7 @@ def init_adapter(
|
||||||
logger.info("Upcasting trainable params to float32.")
|
logger.info("Upcasting trainable params to float32.")
|
||||||
cast_trainable_params_to_fp32 = True
|
cast_trainable_params_to_fp32 = True
|
||||||
|
|
||||||
if finetuning_args.finetuning_type == "full" and is_trainable:
|
if is_trainable and finetuning_args.finetuning_type == "full":
|
||||||
logger.info("Fine-tuning method: Full")
|
logger.info("Fine-tuning method: Full")
|
||||||
|
|
||||||
forbidden_modules = set()
|
forbidden_modules = set()
|
||||||
|
@ -67,7 +67,7 @@ def init_adapter(
|
||||||
else:
|
else:
|
||||||
param.requires_grad_(False)
|
param.requires_grad_(False)
|
||||||
|
|
||||||
if finetuning_args.finetuning_type == "freeze" and is_trainable:
|
if is_trainable and finetuning_args.finetuning_type == "freeze":
|
||||||
logger.info("Fine-tuning method: Freeze")
|
logger.info("Fine-tuning method: Freeze")
|
||||||
|
|
||||||
if model_args.visual_inputs:
|
if model_args.visual_inputs:
|
||||||
|
|
|
@ -50,13 +50,6 @@ def get_config_path() -> os.PathLike:
|
||||||
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
|
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
|
||||||
|
|
||||||
|
|
||||||
def get_arg_save_path(config_path: str) -> os.PathLike:
|
|
||||||
r"""
|
|
||||||
Gets the path to saved arguments.
|
|
||||||
"""
|
|
||||||
return os.path.join(DEFAULT_CONFIG_DIR, config_path)
|
|
||||||
|
|
||||||
|
|
||||||
def load_config() -> Dict[str, Any]:
|
def load_config() -> Dict[str, Any]:
|
||||||
r"""
|
r"""
|
||||||
Loads user config if exists.
|
Loads user config if exists.
|
||||||
|
@ -77,24 +70,28 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
|
||||||
user_config["lang"] = lang or user_config["lang"]
|
user_config["lang"] = lang or user_config["lang"]
|
||||||
if model_name:
|
if model_name:
|
||||||
user_config["last_model"] = model_name
|
user_config["last_model"] = model_name
|
||||||
|
|
||||||
|
if model_name and model_path:
|
||||||
user_config["path_dict"][model_name] = model_path
|
user_config["path_dict"][model_name] = model_path
|
||||||
|
|
||||||
with open(get_config_path(), "w", encoding="utf-8") as f:
|
with open(get_config_path(), "w", encoding="utf-8") as f:
|
||||||
safe_dump(user_config, f)
|
safe_dump(user_config, f)
|
||||||
|
|
||||||
|
|
||||||
def get_model_path(model_name: str) -> Optional[str]:
|
def get_model_path(model_name: str) -> str:
|
||||||
r"""
|
r"""
|
||||||
Gets the model path according to the model name.
|
Gets the model path according to the model name.
|
||||||
"""
|
"""
|
||||||
user_config = load_config()
|
user_config = load_config()
|
||||||
path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
|
path_dict: Dict["DownloadSource", str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
|
||||||
model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, None)
|
model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "")
|
||||||
if (
|
if (
|
||||||
use_modelscope()
|
use_modelscope()
|
||||||
and path_dict.get(DownloadSource.MODELSCOPE)
|
and path_dict.get(DownloadSource.MODELSCOPE)
|
||||||
and model_path == path_dict.get(DownloadSource.DEFAULT)
|
and model_path == path_dict.get(DownloadSource.DEFAULT)
|
||||||
): # replace path
|
): # replace path
|
||||||
model_path = path_dict.get(DownloadSource.MODELSCOPE)
|
model_path = path_dict.get(DownloadSource.MODELSCOPE)
|
||||||
|
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,8 @@ def create_top() -> Dict[str, "Component"]:
|
||||||
visual_inputs = gr.Checkbox(scale=1)
|
visual_inputs = gr.Checkbox(scale=1)
|
||||||
|
|
||||||
model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False)
|
model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False)
|
||||||
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
model_name.input(save_config, inputs=[lang, model_name], queue=False)
|
||||||
|
model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
|
||||||
finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False)
|
finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False)
|
||||||
checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
|
checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ from ...extras.constants import TRAINING_STAGES
|
||||||
from ...extras.misc import get_device_count
|
from ...extras.misc import get_device_count
|
||||||
from ...extras.packages import is_gradio_available
|
from ...extras.packages import is_gradio_available
|
||||||
from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
|
from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
|
||||||
from ..utils import change_stage, check_output_dir, list_config_paths, list_output_dirs
|
from ..utils import change_stage, list_config_paths, list_output_dirs
|
||||||
from .data import create_preview_box
|
from .data import create_preview_box
|
||||||
|
|
||||||
|
|
||||||
|
@ -319,7 +319,13 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
finetuning_type.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False)
|
finetuning_type.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False)
|
||||||
output_dir.change(
|
output_dir.change(
|
||||||
list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], concurrency_limit=None
|
list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], concurrency_limit=None
|
||||||
).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None)
|
)
|
||||||
|
output_dir.input(
|
||||||
|
engine.runner.check_output_dir,
|
||||||
|
[lang, model_name, finetuning_type, output_dir],
|
||||||
|
list(input_elems) + [output_box],
|
||||||
|
concurrency_limit=None,
|
||||||
|
)
|
||||||
config_path.change(list_config_paths, [current_time], [config_path], queue=False)
|
config_path.change(list_config_paths, [current_time], [config_path], queue=False)
|
||||||
|
|
||||||
return elem_dict
|
return elem_dict
|
||||||
|
|
|
@ -5,11 +5,11 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
||||||
|
|
||||||
from transformers.trainer import TRAINING_ARGS_NAME
|
from transformers.trainer import TRAINING_ARGS_NAME
|
||||||
|
|
||||||
from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
|
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
|
||||||
from ..extras.misc import is_gpu_or_npu_available, torch_gc
|
from ..extras.misc import is_gpu_or_npu_available, torch_gc
|
||||||
from ..extras.packages import is_gradio_available
|
from ..extras.packages import is_gradio_available
|
||||||
from .common import DEFAULT_CACHE_DIR, get_save_dir, load_config
|
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
|
||||||
from .locales import ALERTS
|
from .locales import ALERTS, LOCALES
|
||||||
from .utils import abort_leaf_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
|
from .utils import abort_leaf_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
|
||||||
|
|
||||||
|
|
||||||
|
@ -276,6 +276,10 @@ class Runner:
|
||||||
else:
|
else:
|
||||||
self.do_train, self.running_data = do_train, data
|
self.do_train, self.running_data = do_train, data
|
||||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||||
|
|
||||||
|
os.makedirs(args["output_dir"], exist_ok=True)
|
||||||
|
save_args(os.path.join(args["output_dir"], LLAMABOARD_CONFIG), self._form_config_dict(data))
|
||||||
|
|
||||||
env = deepcopy(os.environ)
|
env = deepcopy(os.environ)
|
||||||
env["LLAMABOARD_ENABLED"] = "1"
|
env["LLAMABOARD_ENABLED"] = "1"
|
||||||
if args.get("deepspeed", None) is not None:
|
if args.get("deepspeed", None) is not None:
|
||||||
|
@ -284,6 +288,16 @@ class Runner:
|
||||||
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
|
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
|
||||||
yield from self.monitor()
|
yield from self.monitor()
|
||||||
|
|
||||||
|
def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||||
|
config_dict = {}
|
||||||
|
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"]
|
||||||
|
for elem, value in data.items():
|
||||||
|
elem_id = self.manager.get_id_by_elem(elem)
|
||||||
|
if elem_id not in skip_ids:
|
||||||
|
config_dict[elem_id] = value
|
||||||
|
|
||||||
|
return config_dict
|
||||||
|
|
||||||
def preview_train(self, data):
|
def preview_train(self, data):
|
||||||
yield from self._preview(data, do_train=True)
|
yield from self._preview(data, do_train=True)
|
||||||
|
|
||||||
|
@ -349,28 +363,24 @@ class Runner:
|
||||||
}
|
}
|
||||||
yield return_dict
|
yield return_dict
|
||||||
|
|
||||||
def save_args(self, data: dict):
|
def save_args(self, data):
|
||||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||||
error = self._initialize(data, do_train=True, from_preview=True)
|
error = self._initialize(data, do_train=True, from_preview=True)
|
||||||
if error:
|
if error:
|
||||||
gr.Warning(error)
|
gr.Warning(error)
|
||||||
return {output_box: error}
|
return {output_box: error}
|
||||||
|
|
||||||
config_dict: Dict[str, Any] = {}
|
|
||||||
lang = data[self.manager.get_elem_by_id("top.lang")]
|
lang = data[self.manager.get_elem_by_id("top.lang")]
|
||||||
config_path = data[self.manager.get_elem_by_id("train.config_path")]
|
config_path = data[self.manager.get_elem_by_id("train.config_path")]
|
||||||
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"]
|
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
|
||||||
for elem, value in data.items():
|
save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path)
|
||||||
elem_id = self.manager.get_id_by_elem(elem)
|
|
||||||
if elem_id not in skip_ids:
|
|
||||||
config_dict[elem_id] = value
|
|
||||||
|
|
||||||
save_path = save_args(config_path, config_dict)
|
save_args(save_path, self._form_config_dict(data))
|
||||||
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
|
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
|
||||||
|
|
||||||
def load_args(self, lang: str, config_path: str):
|
def load_args(self, lang: str, config_path: str):
|
||||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||||
config_dict = load_args(config_path)
|
config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path))
|
||||||
if config_dict is None:
|
if config_dict is None:
|
||||||
gr.Warning(ALERTS["err_config_not_found"][lang])
|
gr.Warning(ALERTS["err_config_not_found"][lang])
|
||||||
return {output_box: ALERTS["err_config_not_found"][lang]}
|
return {output_box: ALERTS["err_config_not_found"][lang]}
|
||||||
|
@ -380,3 +390,17 @@ class Runner:
|
||||||
output_dict[self.manager.get_elem_by_id(elem_id)] = value
|
output_dict[self.manager.get_elem_by_id(elem_id)] = value
|
||||||
|
|
||||||
return output_dict
|
return output_dict
|
||||||
|
|
||||||
|
def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str):
|
||||||
|
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||||
|
output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]}
|
||||||
|
if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
|
||||||
|
gr.Warning(ALERTS["warn_output_dir_exists"][lang])
|
||||||
|
output_dict[output_box] = ALERTS["warn_output_dir_exists"][lang]
|
||||||
|
|
||||||
|
output_dir = get_save_dir(model_name, finetuning_type, output_dir)
|
||||||
|
config_dict = load_args(os.path.join(output_dir, LLAMABOARD_CONFIG)) # load llamaboard config
|
||||||
|
for elem_id, value in config_dict.items():
|
||||||
|
output_dict[self.manager.get_elem_by_id(elem_id)] = value
|
||||||
|
|
||||||
|
return output_dict
|
||||||
|
|
|
@ -8,10 +8,10 @@ import psutil
|
||||||
from transformers.trainer_utils import get_last_checkpoint
|
from transformers.trainer_utils import get_last_checkpoint
|
||||||
from yaml import safe_dump, safe_load
|
from yaml import safe_dump, safe_load
|
||||||
|
|
||||||
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG, TRAINING_STAGES
|
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES
|
||||||
from ..extras.packages import is_gradio_available, is_matplotlib_available
|
from ..extras.packages import is_gradio_available, is_matplotlib_available
|
||||||
from ..extras.ploting import gen_loss_plot
|
from ..extras.ploting import gen_loss_plot
|
||||||
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_arg_save_path, get_save_dir
|
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir
|
||||||
from .locales import ALERTS
|
from .locales import ALERTS
|
||||||
|
|
||||||
|
|
||||||
|
@ -93,10 +93,10 @@ def save_cmd(args: Dict[str, Any]) -> str:
|
||||||
output_dir = args["output_dir"]
|
output_dir = args["output_dir"]
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
|
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
|
||||||
safe_dump(clean_cmd(args), f)
|
safe_dump(clean_cmd(args), f)
|
||||||
|
|
||||||
return os.path.join(output_dir, TRAINER_CONFIG)
|
return os.path.join(output_dir, TRAINING_ARGS)
|
||||||
|
|
||||||
|
|
||||||
def get_eval_results(path: os.PathLike) -> str:
|
def get_eval_results(path: os.PathLike) -> str:
|
||||||
|
@ -157,22 +157,19 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
|
||||||
Loads saved arguments.
|
Loads saved arguments.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with open(get_arg_save_path(config_path), "r", encoding="utf-8") as f:
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
return safe_load(f)
|
return safe_load(f)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
|
def save_args(config_path: str, config_dict: Dict[str, Any]):
|
||||||
r"""
|
r"""
|
||||||
Saves arguments.
|
Saves arguments.
|
||||||
"""
|
"""
|
||||||
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
with open(get_arg_save_path(config_path), "w", encoding="utf-8") as f:
|
|
||||||
safe_dump(config_dict, f)
|
safe_dump(config_dict, f)
|
||||||
|
|
||||||
return str(get_arg_save_path(config_path))
|
|
||||||
|
|
||||||
|
|
||||||
def list_config_paths(current_time: str) -> "gr.Dropdown":
|
def list_config_paths(current_time: str) -> "gr.Dropdown":
|
||||||
r"""
|
r"""
|
||||||
|
@ -181,13 +178,13 @@ def list_config_paths(current_time: str) -> "gr.Dropdown":
|
||||||
config_files = ["{}.yaml".format(current_time)]
|
config_files = ["{}.yaml".format(current_time)]
|
||||||
if os.path.isdir(DEFAULT_CONFIG_DIR):
|
if os.path.isdir(DEFAULT_CONFIG_DIR):
|
||||||
for file_name in os.listdir(DEFAULT_CONFIG_DIR):
|
for file_name in os.listdir(DEFAULT_CONFIG_DIR):
|
||||||
if file_name.endswith(".yaml"):
|
if file_name.endswith(".yaml") and file_name not in config_files:
|
||||||
config_files.append(file_name)
|
config_files.append(file_name)
|
||||||
|
|
||||||
return gr.Dropdown(choices=config_files)
|
return gr.Dropdown(choices=config_files)
|
||||||
|
|
||||||
|
|
||||||
def list_output_dirs(model_name: str, finetuning_type: str, current_time: str) -> "gr.Dropdown":
|
def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
|
||||||
r"""
|
r"""
|
||||||
Lists all the directories that can resume from.
|
Lists all the directories that can resume from.
|
||||||
"""
|
"""
|
||||||
|
@ -203,14 +200,6 @@ def list_output_dirs(model_name: str, finetuning_type: str, current_time: str) -
|
||||||
return gr.Dropdown(choices=output_dirs)
|
return gr.Dropdown(choices=output_dirs)
|
||||||
|
|
||||||
|
|
||||||
def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None:
|
|
||||||
r"""
|
|
||||||
Check if output dir exists.
|
|
||||||
"""
|
|
||||||
if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
|
|
||||||
gr.Warning(ALERTS["warn_output_dir_exists"][lang])
|
|
||||||
|
|
||||||
|
|
||||||
def create_ds_config() -> None:
|
def create_ds_config() -> None:
|
||||||
r"""
|
r"""
|
||||||
Creates deepspeed config.
|
Creates deepspeed config.
|
||||||
|
|
Loading…
Reference in New Issue