better llamaboard

* easily resume from checkpoint
* support full and freeze checkpoints
* faster ui
hiyouga 2024-05-29 23:55:38 +08:00
parent d0aa36b8ad
commit 8070871732
14 changed files with 303 additions and 193 deletions
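The common thread through the files below is a single dispatch rule for the new `checkpoint_path` component: LoRA checkpoints are adapters layered on the base model, while full and freeze checkpoints replace it. A minimal sketch of that rule, with `get_save_dir` passed in so the snippet stays self-contained (names follow the diff; the wiring is hypothetical):

```python
from typing import Callable, Dict, List, Union

PEFT_METHODS = {"lora"}

def apply_checkpoint_args(
    args: Dict[str, str],
    model_name: str,
    finetuning_type: str,
    checkpoint_path: Union[str, List[str]],
    get_save_dir: Callable[..., str],
) -> None:
    # LoRA: a multiselect list of adapters, merged into one comma-separated path.
    if finetuning_type in PEFT_METHODS:
        args["adapter_name_or_path"] = ",".join(
            get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoint_path
        )
    # full/freeze: a single checkpoint directory becomes the new base model.
    else:
        args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
```

The same block appears, nearly verbatim, in chatter.py, the export tab, and the runner below.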

View File

@ -1,4 +1,4 @@
# Level: api, webui > chat, eval, train > data, model > extras, hparams
# Level: api, webui > chat, eval, train > data, model > hparams > extras
from .cli import VERSION

View File

@ -2,6 +2,19 @@ from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
CHECKPOINT_NAMES = {
SAFE_ADAPTER_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
}
CHOICES = ["A", "B", "C", "D"]
@ -26,9 +39,9 @@ LAYERNORM_NAMES = {"norm", "ln"}
METHODS = ["full", "freeze", "lora"]
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
PEFT_METHODS = ["lora"]
PEFT_METHODS = {"lora"}
RUNNING_LOG = "running_log.txt"
@ -49,9 +62,9 @@ TRAINING_STAGES = {
"Pre-Training": "pt",
}
STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"]
STAGES_USE_PAIR_DATA = {"rm", "dpo"}
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
V_HEAD_WEIGHTS_NAME = "value_head.bin"
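Unioning the peft and transformers filenames means one membership test recognizes adapter checkpoints and full (possibly sharded) checkpoints alike. A sketch of how the set is consumed in parser.py and common.py below, assuming `CHECKPOINT_NAMES` from above is in scope:

```python
import os

def contains_checkpoint(directory: str) -> bool:
    # True if the directory holds any known weight file: adapter_model.*,
    # model.safetensors, pytorch_model.bin, or a sharded-weights index file.
    return any(os.path.isfile(os.path.join(directory, name)) for name in CHECKPOINT_NAMES)
```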

View File

@ -11,6 +11,7 @@ from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.versions import require_version
from ..extras.constants import CHECKPOINT_NAMES
from ..extras.logging import get_logger
from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments
@ -255,13 +256,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
and can_resume_from_checkpoint
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and any(
os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES
):
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
if last_checkpoint is not None:
training_args.resume_from_checkpoint = last_checkpoint
logger.info(
"Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
training_args.resume_from_checkpoint
)
)
logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
if (
finetuning_args.stage in ["rm", "ppo"]
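The resume branch distinguishes two non-empty states of `output_dir`: finished weights without a trainer checkpoint (abort, so nothing is silently overwritten) versus a `checkpoint-*` directory (resume). A condensed sketch of the logic above, assuming `CHECKPOINT_NAMES` is imported as shown:

```python
import os
from typing import Optional

from transformers.trainer_utils import get_last_checkpoint

def resolve_resume(output_dir: str) -> Optional[str]:
    last_checkpoint = get_last_checkpoint(output_dir)  # newest checkpoint-* subdir, or None
    if last_checkpoint is None and any(
        os.path.isfile(os.path.join(output_dir, name)) for name in CHECKPOINT_NAMES
    ):
        # Weights exist but there is nothing to resume from: refuse to clobber them.
        raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
    return last_checkpoint  # None means start from scratch
```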

View File

@ -6,6 +6,7 @@ from numpy.typing import NDArray
from ..chat import ChatModel
from ..data import Role
from ..extras.constants import PEFT_METHODS
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
from .common import get_save_dir
@ -44,13 +45,14 @@ class WebChatModel(ChatModel):
def load_model(self, data) -> Generator[str, None, None]:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang = get("top.lang")
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
error = ""
if self.loaded:
error = ALERTS["err_exists"][lang]
elif not get("top.model_name"):
elif not model_name:
error = ALERTS["err_no_model"][lang]
elif not get("top.model_path"):
elif not model_path:
error = ALERTS["err_no_path"][lang]
elif self.demo_mode:
error = ALERTS["err_demo"][lang]
@ -60,21 +62,10 @@ class WebChatModel(ChatModel):
yield error
return
if get("top.adapter_path"):
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None
yield ALERTS["info_loading"][lang]
args = dict(
model_name_or_path=get("top.model_path"),
adapter_name_or_path=adapter_name_or_path,
finetuning_type=get("top.finetuning_type"),
model_name_or_path=model_path,
finetuning_type=finetuning_type,
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"),
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@ -83,8 +74,16 @@ class WebChatModel(ChatModel):
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
infer_backend=get("infer.infer_backend"),
)
super().__init__(args)
if checkpoint_path:
if finetuning_type in PEFT_METHODS: # list
args["adapter_name_or_path"] = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
)
else: # str
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
super().__init__(args)
yield ALERTS["info_loaded"][lang]
def unload_model(self, data) -> Generator[str, None, None]:
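For illustration, the two branches of `load_model` produce argument sets along these lines (model and checkpoint names hypothetical, assuming `get_save_dir` joins `saves/<model>/<type>/<name>`):

```python
# LoRA: the base model stays; the selected adapters stack on top of it.
args_lora = dict(
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    finetuning_type="lora",
    adapter_name_or_path="saves/LLaMA2-7B/lora/train_a,saves/LLaMA2-7B/lora/train_b",
)

# full/freeze: the checkpoint directory replaces the base model path outright.
args_full = dict(
    model_name_or_path="saves/LLaMA2-7B/full/train_2024-05-29",
    finetuning_type="full",
)
```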

View File

@ -1,12 +1,12 @@
import json
import os
from collections import defaultdict
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
from yaml import safe_dump, safe_load
from ..extras.constants import (
CHECKPOINT_NAMES,
DATA_CONFIG,
DEFAULT_MODULE,
DEFAULT_TEMPLATE,
@ -29,7 +29,6 @@ if is_gradio_available():
logger = get_logger(__name__)
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
DEFAULT_CACHE_DIR = "cache"
DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data"
@ -38,19 +37,31 @@ USER_CONFIG = "user_config.yaml"
def get_save_dir(*paths: str) -> os.PathLike:
r"""
Gets the path to saved model checkpoints.
"""
paths = (path.replace(os.path.sep, "").replace(" ", "").strip() for path in paths)
return os.path.join(DEFAULT_SAVE_DIR, *paths)
def get_config_path() -> os.PathLike:
r"""
Gets the path to user config.
"""
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def get_save_path(config_path: str) -> os.PathLike:
def get_arg_save_path(config_path: str) -> os.PathLike:
r"""
Gets the path to saved arguments.
"""
return os.path.join(DEFAULT_CONFIG_DIR, config_path)
def load_config() -> Dict[str, Any]:
r"""
Loads the user config if it exists.
"""
try:
with open(get_config_path(), "r", encoding="utf-8") as f:
return safe_load(f)
@ -59,6 +70,9 @@ def load_config() -> Dict[str, Any]:
def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
r"""
Saves user config.
"""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
user_config = load_config()
user_config["lang"] = lang or user_config["lang"]
@ -69,23 +83,10 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
safe_dump(user_config, f)
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
try:
with open(get_save_path(config_path), "r", encoding="utf-8") as f:
return safe_load(f)
except Exception:
return None
def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
with open(get_save_path(config_path), "w", encoding="utf-8") as f:
safe_dump(config_dict, f)
return str(get_save_path(config_path))
def get_model_path(model_name: str) -> str:
def get_model_path(model_name: str) -> Optional[str]:
r"""
Gets the model path according to the model name.
"""
user_config = load_config()
path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, None)
@ -99,40 +100,71 @@ def get_model_path(model_name: str) -> str:
def get_prefix(model_name: str) -> str:
r"""
Gets the prefix of the model name to obtain the model family.
"""
return model_name.split("-")[0]
def get_model_info(model_name: str) -> Tuple[str, str, bool]:
r"""
Gets the necessary information about this model.
Returns:
model_path (str)
template (str)
visual (bool)
"""
return get_model_path(model_name), get_template(model_name), get_visual(model_name)
def get_module(model_name: str) -> str:
return DEFAULT_MODULE.get(get_prefix(model_name), "q_proj,v_proj")
r"""
Gets the LoRA modules of this model.
"""
return DEFAULT_MODULE.get(get_prefix(model_name), "all")
def get_template(model_name: str) -> str:
r"""
Gets the template name if the model is a chat model.
"""
if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
return DEFAULT_TEMPLATE[get_prefix(model_name)]
return "default"
def get_visual(model_name: str) -> bool:
r"""
Checks whether the model is a vision language model.
"""
return get_prefix(model_name) in VISION_MODELS
def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
if finetuning_type not in PEFT_METHODS:
return gr.Dropdown(value=[], choices=[], interactive=False)
adapters = []
if model_name and finetuning_type == "lora":
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
r"""
Lists all available checkpoints.
"""
checkpoints = []
if model_name:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for adapter in os.listdir(save_dir):
if os.path.isdir(os.path.join(save_dir, adapter)) and any(
os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES
for checkpoint in os.listdir(save_dir):
if os.path.isdir(os.path.join(save_dir, checkpoint)) and any(
os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES
):
adapters.append(adapter)
return gr.Dropdown(value=[], choices=adapters, interactive=True)
checkpoints.append(checkpoint)
if finetuning_type in PEFT_METHODS:
return gr.Dropdown(value=[], choices=checkpoints, multiselect=True)
else:
return gr.Dropdown(value=None, choices=checkpoints, multiselect=False)
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
r"""
Loads dataset_info.json.
"""
if dataset_dir == "ONLINE":
logger.info("dataset_dir is ONLINE, using online dataset.")
return {}
@ -145,12 +177,11 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
return {}
def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
r"""
Lists all available datasets in the dataset dir for the training stage.
"""
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
return gr.Dropdown(value=[], choices=datasets)
def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Button":
return gr.Button(value=(TRAINING_STAGES[training_stage] == "pt"))
return gr.Dropdown(choices=datasets)
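The dropdown returned by `list_checkpoints` encodes the same rule at the UI level: LoRA users may stack several adapters (multiselect, value reset to `[]`), while full/freeze users resume from exactly one checkpoint (single-select, value reset to `None`). Hypothetical calls:

```python
list_checkpoints("LLaMA2-7B", "lora")  # multiselect dropdown over saves/LLaMA2-7B/lora/*
list_checkpoints("LLaMA2-7B", "full")  # single-select dropdown over saves/LLaMA2-7B/full/*
```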

View File

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, list_dataset
from ..common import DEFAULT_DATA_DIR, list_datasets
from .data import create_preview_box
@ -74,6 +74,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
dataset.focus(list_datasets, [dataset_dir], [dataset], queue=False)
return elem_dict

View File

@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Dict, Generator, List
from typing import TYPE_CHECKING, Dict, Generator, List, Union
from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train.tuner import export_model
@ -24,8 +25,8 @@ def save_model(
lang: str,
model_name: str,
model_path: str,
adapter_path: List[str],
finetuning_type: str,
checkpoint_path: Union[str, List[str]],
template: str,
visual_inputs: bool,
export_size: int,
@ -45,9 +46,9 @@ def save_model(
error = ALERTS["err_no_export_dir"][lang]
elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
error = ALERTS["err_no_dataset"][lang]
elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
elif export_quantization_bit not in GPTQ_BITS and not checkpoint_path:
error = ALERTS["err_no_adapter"][lang]
elif export_quantization_bit in GPTQ_BITS and adapter_path:
elif export_quantization_bit in GPTQ_BITS and isinstance(checkpoint_path, list):
error = ALERTS["err_gptq_lora"][lang]
if error:
@ -55,16 +56,8 @@ def save_model(
yield error
return
if adapter_path:
adapter_name_or_path = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
)
else:
adapter_name_or_path = None
args = dict(
model_name_or_path=model_path,
adapter_name_or_path=adapter_name_or_path,
finetuning_type=finetuning_type,
template=template,
visual_inputs=visual_inputs,
@ -77,6 +70,14 @@ def save_model(
export_legacy_format=export_legacy_format,
)
if checkpoint_path:
if finetuning_type in PEFT_METHODS: # list
args["adapter_name_or_path"] = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
)
else: # str
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
yield ALERTS["info_exporting"][lang]
export_model(args)
torch_gc()
@ -86,7 +87,7 @@ def save_model(
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1)
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
export_legacy_format = gr.Checkbox()
@ -104,8 +105,8 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
engine.manager.get_elem_by_id("top.lang"),
engine.manager.get_elem_by_id("top.model_name"),
engine.manager.get_elem_by_id("top.model_path"),
engine.manager.get_elem_by_id("top.adapter_path"),
engine.manager.get_elem_by_id("top.finetuning_type"),
engine.manager.get_elem_by_id("top.checkpoint_path"),
engine.manager.get_elem_by_id("top.template"),
engine.manager.get_elem_by_id("top.visual_inputs"),
export_size,
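The guard order above matters: GPTQ export quantizes against a calibration dataset and cannot consume a stack of loose LoRA adapters, which is what the `isinstance(checkpoint_path, list)` test detects. A condensed sketch of the validation, assuming `GPTQ_BITS` matches the dropdown choices it replaces:

```python
from typing import List, Union

GPTQ_BITS = ["8", "4", "3", "2"]  # assumed from the ["none"] + GPTQ_BITS dropdown above

def export_error(bit: str, dataset: str, checkpoint_path: Union[str, List[str]]) -> str:
    if bit in GPTQ_BITS and not dataset:
        return "err_no_dataset"   # GPTQ needs a calibration dataset
    if bit not in GPTQ_BITS and not checkpoint_path:
        return "err_no_adapter"   # a plain export needs a checkpoint to merge or copy
    if bit in GPTQ_BITS and isinstance(checkpoint_path, list):
        return "err_gptq_lora"    # GPTQ cannot be combined with LoRA adapters
    return ""
```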

View File

@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict
from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
from ..common import get_model_info, list_checkpoints, save_config
from ..utils import can_quantize
@ -25,8 +25,7 @@ def create_top() -> Dict[str, "Component"]:
with gr.Row():
finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
refresh_btn = gr.Button(scale=1)
checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)
with gr.Accordion(open=False) as advanced_tab:
with gr.Row():
@ -36,27 +35,17 @@ def create_top() -> Dict[str, "Component"]:
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
visual_inputs = gr.Checkbox(scale=1)
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
get_model_path, [model_name], [model_path], queue=False
).then(get_template, [model_name], [template], queue=False).then(
get_visual, [model_name], [visual_inputs], queue=False
) # do not save config since the below line will save
model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False)
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
can_quantize, [finetuning_type], [quantization_bit], queue=False
)
refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False)
checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
return dict(
lang=lang,
model_name=model_name,
model_path=model_path,
finetuning_type=finetuning_type,
adapter_path=adapter_path,
refresh_btn=refresh_btn,
checkpoint_path=checkpoint_path,
advanced_tab=advanced_tab,
quantization_bit=quantization_bit,
template=template,
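Replacing the refresh button with a `.focus` listener is where the "faster ui" bullet lands: the save directory is scanned only when the user opens the checkpoint dropdown, not on every model change and not at page load. A minimal wiring sketch (gradio 4.x style, assuming `list_checkpoints` from common.py is importable):

```python
import gradio as gr

with gr.Blocks() as demo:
    model_name = gr.Dropdown(choices=["LLaMA2-7B"], label="Model name")
    finetuning_type = gr.Dropdown(choices=["full", "freeze", "lora"], value="lora")
    checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True)
    # Lazy: the directory walk inside list_checkpoints runs only on focus.
    checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
```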

View File

@ -5,8 +5,9 @@ from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ...extras.misc import get_device_count
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
from ..components.data import create_preview_box
from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
from ..utils import change_stage, check_output_dir, list_output_dirs
from .data import create_preview_box
if is_gradio_available():
@ -256,11 +257,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
with gr.Column(scale=3):
with gr.Row():
output_dir = gr.Textbox()
initial_dir = gr.Textbox(visible=False, interactive=False)
output_dir = gr.Dropdown(allow_custom_value=True)
config_path = gr.Textbox()
with gr.Row():
device_count = gr.Textbox(value=str(get_device_count()), interactive=False)
device_count = gr.Textbox(value=str(get_device_count() or 1), interactive=False)
ds_stage = gr.Dropdown(choices=["none", "2", "3"], value="none")
ds_offload = gr.Checkbox()
@ -282,6 +284,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
arg_load_btn=arg_load_btn,
start_btn=start_btn,
stop_btn=stop_btn,
initial_dir=initial_dir,
output_dir=output_dir,
config_path=config_path,
device_count=device_count,
@ -295,24 +298,24 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
output_elems = [output_box, progress_bar, loss_viewer]
lang = engine.manager.get_elem_by_id("top.lang")
model_name = engine.manager.get_elem_by_id("top.model_name")
finetuning_type = engine.manager.get_elem_by_id("top.finetuning_type")
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
arg_load_btn.click(
engine.runner.load_args,
[engine.manager.get_elem_by_id("top.lang"), config_path],
list(input_elems) + [output_box],
concurrency_limit=None,
engine.runner.load_args, [lang, config_path], list(input_elems) + [output_box], concurrency_limit=None
)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
list_adapters,
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
[reward_model],
queue=False,
).then(autoset_packing, [training_stage], [packing], queue=False)
training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False)
dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False)
reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False)
output_dir.change(
list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], concurrency_limit=None
).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None)
return elem_dict
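The hidden `initial_dir` textbox exists so the freshly generated default name survives every rescan: `list_output_dirs` seeds its choices with `initial_dir`, then appends each folder under the save dir that contains a trainer checkpoint. A hypothetical rendering (gradio may normalize choices to label/value pairs):

```python
list_output_dirs("LLaMA2-7B", "lora", "train_2024-05-29-23-55-38")
# -> gr.Dropdown(choices=["train_2024-05-29-23-55-38",    # the fresh default
#                         "train_2024-05-20-10-00-00"])   # a resumable earlier run
```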

View File

@ -1,11 +1,11 @@
from typing import TYPE_CHECKING, Any, Dict
from .chatter import WebChatModel
from .common import get_model_path, list_dataset, load_config
from .common import load_config
from .locales import LOCALES
from .manager import Manager
from .runner import Runner
from .utils import get_time, save_ds_config
from .utils import create_ds_config, get_time
if TYPE_CHECKING:
@ -20,7 +20,7 @@ class Engine:
self.runner = Runner(self.manager, demo_mode)
self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat))
if not demo_mode:
save_ds_config()
create_ds_config()
def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
r"""
@ -40,16 +40,15 @@ class Engine:
init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
if not self.pure_chat:
init_dict["train.dataset"] = {"choices": list_dataset().choices}
init_dict["eval.dataset"] = {"choices": list_dataset().choices}
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
init_dict["train.config_path"] = {"value": "{}.yaml".format(get_time())}
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
current_time = get_time()
init_dict["train.initial_dir"] = {"value": "train_{}".format(current_time)}
init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)}
init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)}
init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)}
init_dict["infer.image_box"] = {"visible": False}
if user_config.get("last_model", None):
init_dict["top.model_name"] = {"value": user_config["last_model"]}
init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
yield self._update_component(init_dict)

View File

@ -46,26 +46,15 @@ LOCALES = {
"label": "微调方法",
},
},
"adapter_path": {
"checkpoint_path": {
"en": {
"label": "Adapter path",
"label": "Checkpoint path",
},
"ru": {
"label": "Путь к адаптеру",
"label": "Путь контрольной точки",
},
"zh": {
"label": "适配器路径",
},
},
"refresh_btn": {
"en": {
"value": "Refresh adapters",
},
"ru": {
"value": "Обновить адаптеры",
},
"zh": {
"value": "刷新适配器",
"label": "检查点路径",
},
},
"advanced_tab": {
@ -1531,6 +1520,11 @@ ALERTS = {
"ru": "Среда CUDA не обнаружена.",
"zh": "未检测到 CUDA 环境。",
},
"warn_output_dir_exists": {
"en": "Output dir already exists, will resume training from here.",
"ru": "Выходной каталог уже существует, обучение будет продолжено отсюда.",
"zh": "输出目录已存在,将从该断点恢复训练。",
},
"info_aborting": {
"en": "Aborted, wait for terminating...",
"ru": "Прервано, ожидание завершения...",

View File

@ -55,7 +55,7 @@ class Manager:
self._id_to_elem["top.model_name"],
self._id_to_elem["top.model_path"],
self._id_to_elem["top.finetuning_type"],
self._id_to_elem["top.adapter_path"],
self._id_to_elem["top.checkpoint_path"],
self._id_to_elem["top.quantization_bit"],
self._id_to_elem["top.template"],
self._id_to_elem["top.rope_scaling"],

View File

@ -7,12 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
import psutil
from transformers.trainer import TRAINING_ARGS_NAME
from ..extras.constants import TRAINING_STAGES
from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc
from ..extras.packages import is_gradio_available
from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_args, load_config, save_args
from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config
from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd
from .utils import gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
if is_gradio_available():
@ -85,26 +85,16 @@ class Runner:
def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
if get("top.adapter_path"):
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None
args = dict(
stage=TRAINING_STAGES[get("train.training_stage")],
do_train=True,
model_name_or_path=get("top.model_path"),
adapter_name_or_path=adapter_name_or_path,
cache_dir=user_config.get("cache_dir", None),
preprocessing_num_workers=16,
finetuning_type=get("top.finetuning_type"),
finetuning_type=finetuning_type,
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
@ -134,13 +124,23 @@ class Runner:
report_to="all" if get("train.report_to") else "none",
use_galore=get("train.use_galore"),
use_badam=get("train.use_badam"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16"),
pure_bf16=(get("train.compute_type") == "pure_bf16"),
plot_loss=True,
ddp_timeout=180000000,
)
# checkpoints
if get("top.checkpoint_path"):
if finetuning_type in PEFT_METHODS: # list
args["adapter_name_or_path"] = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")]
)
else: # str
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
# freeze config
if args["finetuning_type"] == "freeze":
args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
@ -156,7 +156,7 @@ class Runner:
args["create_new_adapter"] = get("train.create_new_adapter")
args["use_rslora"] = get("train.use_rslora")
args["use_dora"] = get("train.use_dora")
args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
args["lora_target"] = get("train.lora_target") or get_module(model_name)
args["additional_target"] = get("train.additional_target") or None
if args["use_llama_pro"]:
@ -164,13 +164,14 @@ class Runner:
# rlhf config
if args["stage"] == "ppo":
args["reward_model"] = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("train.reward_model")
]
)
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
if finetuning_type in PEFT_METHODS:
args["reward_model"] = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in get("train.reward_model")]
)
else:
args["reward_model"] = get_save_dir(model_name, finetuning_type, get("train.reward_model"))
args["reward_model_type"] = "lora" if finetuning_type == "lora" else "full"
args["ppo_score_norm"] = get("train.ppo_score_norm")
args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
args["top_k"] = 0
@ -211,25 +212,15 @@ class Runner:
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
if get("top.adapter_path"):
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None
args = dict(
stage="sft",
model_name_or_path=get("top.model_path"),
adapter_name_or_path=adapter_name_or_path,
cache_dir=user_config.get("cache_dir", None),
preprocessing_num_workers=16,
finetuning_type=get("top.finetuning_type"),
finetuning_type=finetuning_type,
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
@ -245,7 +236,7 @@ class Runner:
max_new_tokens=get("eval.max_new_tokens"),
top_p=get("eval.top_p"),
temperature=get("eval.temperature"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
)
if get("eval.predict"):
@ -253,6 +244,14 @@ class Runner:
else:
args["do_eval"] = True
if get("top.checkpoint_path"):
if finetuning_type in PEFT_METHODS: # list
args["adapter_name_or_path"] = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")]
)
else: # str
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
return args
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
@ -296,9 +295,7 @@ class Runner:
self.running = True
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
lang = get("top.lang")
model_name = get("top.model_name")
finetuning_type = get("top.finetuning_type")
lang, model_name, finetuning_type = get("top.lang"), get("top.model_name"), get("top.finetuning_type")
output_dir = get("{}.output_dir".format("train" if self.do_train else "eval"))
output_path = get_save_dir(model_name, finetuning_type, output_dir)
@ -356,7 +353,7 @@ class Runner:
config_dict: Dict[str, Any] = {}
lang = data[self.manager.get_elem_by_id("top.lang")]
config_path = data[self.manager.get_elem_by_id("train.config_path")]
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"]
for elem, value in data.items():
elem_id = self.manager.get_id_by_elem(elem)
if elem_id not in skip_ids:
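Putting the runner changes together: `model_name` and `finetuning_type` are read once, the checkpoint dispatch and the PPO `reward_model` path follow the same LoRA-versus-full rule, and resuming a full fine-tune reduces to an args dict along these lines (paths hypothetical):

```python
args = dict(
    stage="sft",
    do_train=True,
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    finetuning_type="full",
    # Picking an existing run in the output_dir dropdown is enough to resume:
    # get_train_args finds its last checkpoint-* folder and continues from it.
    output_dir="saves/LLaMA2-7B/full/train_2024-05-20",
)
```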

View File

@ -3,12 +3,13 @@ import os
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from yaml import safe_dump
from transformers.trainer_utils import get_last_checkpoint
from yaml import safe_dump, safe_load
from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG, TRAINING_STAGES
from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot
from .common import DEFAULT_CACHE_DIR
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_arg_save_path, get_save_dir
from .locales import ALERTS
@ -17,13 +18,26 @@ if is_gradio_available():
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
if finetuning_type != "lora":
r"""
Checks whether quantization is available for this finetuning type.
"""
if finetuning_type not in PEFT_METHODS:
return gr.Dropdown(value="none", interactive=False)
else:
return gr.Dropdown(interactive=True)
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
r"""
Modifies the states after changing the training stage.
"""
return [], TRAINING_STAGES[training_stage] == "pt"
def check_json_schema(text: str, lang: str) -> None:
r"""
Checks if the json schema is valid.
"""
try:
tools = json.loads(text)
if tools:
@ -38,11 +52,17 @@ def check_json_schema(text: str, lang: str) -> None:
def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
r"""
Removes args whose value is None, False, or an empty string.
"""
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
def gen_cmd(args: Dict[str, Any]) -> str:
r"""
Generates the command text for previewing.
"""
cmd_lines = ["llamafactory-cli train "]
for k, v in clean_cmd(args).items():
cmd_lines.append(" --{} {} ".format(k, str(v)))
@ -52,17 +72,39 @@ def gen_cmd(args: Dict[str, Any]) -> str:
return cmd_text
def save_cmd(args: Dict[str, Any]) -> str:
r"""
Saves arguments to launch training.
"""
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
safe_dump(clean_cmd(args), f)
return os.path.join(output_dir, TRAINER_CONFIG)
def get_eval_results(path: os.PathLike) -> str:
r"""
Gets scores after evaluation.
"""
with open(path, "r", encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
return "```json\n{}\n```\n".format(result)
def get_time() -> str:
r"""
Gets current date and time.
"""
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
r"""
Gets training information for the monitor.
"""
running_log = ""
running_progress = gr.Slider(visible=False)
running_loss = None
@ -96,17 +138,56 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
return running_log, running_progress, running_loss
def save_cmd(args: Dict[str, Any]) -> str:
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
safe_dump(clean_cmd(args), f)
return os.path.join(output_dir, TRAINER_CONFIG)
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
r"""
Loads saved arguments.
"""
try:
with open(get_arg_save_path(config_path), "r", encoding="utf-8") as f:
return safe_load(f)
except Exception:
return None
def save_ds_config() -> None:
def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
r"""
Saves arguments.
"""
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
with open(get_arg_save_path(config_path), "w", encoding="utf-8") as f:
safe_dump(config_dict, f)
return str(get_arg_save_path(config_path))
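`load_args` and `save_args` moved here from common.py and round-trip the UI state through `config/<name>.yaml` via `get_arg_save_path`. A minimal usage sketch (keys illustrative):

```python
path = save_args("2024-05-29-23-55-38.yaml", {"top.model_name": "LLaMA2-7B", "train.learning_rate": "5e-5"})
restored = load_args("2024-05-29-23-55-38.yaml")
assert restored["top.model_name"] == "LLaMA2-7B"
```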
def list_output_dirs(model_name: str, finetuning_type: str, initial_dir: str) -> "gr.Dropdown":
r"""
Lists all the directories that training can resume from.
"""
output_dirs = [initial_dir]
if model_name:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for folder in os.listdir(save_dir):
output_dir = os.path.join(save_dir, folder)
if os.path.isdir(output_dir) and get_last_checkpoint(output_dir) is not None:
output_dirs.append(folder)
return gr.Dropdown(choices=output_dirs)
def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None:
r"""
Checks whether the output dir exists.
"""
if os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
gr.Warning(ALERTS["warn_output_dir_exists"][lang])
def create_ds_config() -> None:
r"""
Creates deepspeed config.
"""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
ds_config = {
"train_batch_size": "auto",