diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py index 76f07a42..a07c7059 100644 --- a/src/llmtuner/extras/callbacks.py +++ b/src/llmtuner/extras/callbacks.py @@ -5,7 +5,7 @@ import signal import time from concurrent.futures import ThreadPoolExecutor from datetime import timedelta -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional import transformers from transformers import TrainerCallback @@ -38,8 +38,20 @@ class FixValueHeadModelCallback(TrainerCallback): class LogCallback(TrainerCallback): def __init__(self, output_dir: str) -> None: + r""" + Initializes a callback for logging training and evaluation status. + """ + """ Progress """ + self.start_time = 0 + self.cur_steps = 0 + self.max_steps = 0 + self.elapsed_time = "" + self.remaining_time = "" + self.thread_pool: Optional["ThreadPoolExecutor"] = None + """ Status """ self.aborted = False self.do_train = False + """ Web UI """ self.webui_mode = bool(int(os.environ.get("LLAMABOARD_ENABLED", "0"))) if self.webui_mode: signal.signal(signal.SIGABRT, self._set_abort) @@ -66,6 +78,19 @@ class LogCallback(TrainerCallback): self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) self.remaining_time = str(timedelta(seconds=int(remaining_time))) + def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None: + with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f: + f.write(json.dumps(logs) + "\n") + + def _create_thread_pool(self, output_dir: str) -> None: + os.makedirs(output_dir, exist_ok=True) + self.thread_pool = ThreadPoolExecutor(max_workers=1) + + def _close_thread_pool(self) -> None: + if self.thread_pool is not None: + self.thread_pool.shutdown(wait=True) + self.thread_pool = None + def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" Event called at the beginning of training. @@ -73,8 +98,7 @@ class LogCallback(TrainerCallback): if args.should_save: self.do_train = True self._reset(max_steps=state.max_steps) - os.makedirs(args.output_dir, exist_ok=True) - self.thread_pool = ThreadPoolExecutor(max_workers=1) + self._create_thread_pool(output_dir=args.output_dir) if ( args.should_save @@ -84,6 +108,12 @@ class LogCallback(TrainerCallback): logger.warning("Previous trainer log in this folder will be deleted.") os.remove(os.path.join(args.output_dir, TRAINER_LOG)) + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of training. + """ + self._close_thread_pool() + def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" Event called at the end of an substep during gradient accumulation. @@ -103,31 +133,19 @@ class LogCallback(TrainerCallback): control.should_epoch_stop = True control.should_training_stop = True - def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" - Event called at the end of training. + Event called after an evaluation phase. """ - self.thread_pool.shutdown(wait=True) - self.thread_pool = None + self._close_thread_pool() - def on_prediction_step( - self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs - ): + def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" - Event called after a prediction step. + Event called after a successful prediction. """ - eval_dataloader = kwargs.pop("eval_dataloader", None) - if args.should_save and has_length(eval_dataloader) and not self.do_train: - if self.max_steps == 0: - self.max_steps = len(eval_dataloader) + self._close_thread_pool() - self._timing(cur_steps=self.cur_steps + 1) - - def _write_log(self, output_dir: str, logs: Dict[str, Any]): - with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f: - f.write(json.dumps(logs) + "\n") - - def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None: + def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" Event called after logging the last logs. """ @@ -158,3 +176,26 @@ class LogCallback(TrainerCallback): if self.thread_pool is not None: self.thread_pool.submit(self._write_log, args.output_dir, logs) + + def on_prediction_step( + self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs + ): + r""" + Event called after a prediction step. + """ + eval_dataloader = kwargs.pop("eval_dataloader", None) + if args.should_save and has_length(eval_dataloader) and not self.do_train: + if self.max_steps == 0: + self._reset(max_steps=len(eval_dataloader)) + self._create_thread_pool(output_dir=args.output_dir) + + self._timing(cur_steps=self.cur_steps + 1) + if self.cur_steps % 5 == 0 and self.thread_pool is not None: + logs = dict( + current_steps=self.cur_steps, + total_steps=self.max_steps, + percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, + elapsed_time=self.elapsed_time, + remaining_time=self.remaining_time, + ) + self.thread_pool.submit(self._write_log, args.output_dir, logs) diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index a33e3db7..d569f1fa 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -17,6 +17,7 @@ from ..extras.constants import ( TRAINING_STAGES, DownloadSource, ) +from ..extras.logging import get_logger from ..extras.misc import use_modelscope from ..extras.packages import is_gradio_available @@ -25,6 +26,9 @@ if is_gradio_available(): import gradio as gr +logger = get_logger(__name__) + + ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME} DEFAULT_CACHE_DIR = "cache" DEFAULT_CONFIG_DIR = "config" @@ -128,11 +132,15 @@ def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown": def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: + if dataset_dir == "ONLINE": + logger.info("dataset_dir is ONLINE, using online dataset.") + return {} + try: with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: return json.load(f) except Exception as err: - print("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err))) + logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err))) return {} diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py index 3910a746..222f9314 100644 --- a/src/llmtuner/webui/components/eval.py +++ b/src/llmtuner/webui/components/eval.py @@ -21,16 +21,16 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) - dataset = gr.Dropdown(multiselect=True, scale=4) + dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4) preview_elems = create_preview_box(dataset_dir, dataset) input_elems.update({dataset_dir, dataset}) elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems)) with gr.Row(): - cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1) + cutoff_len = gr.Slider(value=1024, minimum=4, maximum=65536, step=1) max_samples = gr.Textbox(value="100000") - batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1) + batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1) predict = gr.Checkbox(value=True) input_elems.update({cutoff_len, max_samples, batch_size, predict}) @@ -48,30 +48,27 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): cmd_preview_btn = gr.Button() start_btn = gr.Button(variant="primary") - stop_btn = gr.Button(variant="stop") with gr.Row(): resume_btn = gr.Checkbox(visible=False, interactive=False) - process_bar = gr.Slider(visible=False, interactive=False) + progress_bar = gr.Slider(visible=False, interactive=False) with gr.Row(): output_box = gr.Markdown() - output_elems = [output_box, process_bar] + output_elems = [output_box, progress_bar] elem_dict.update( dict( cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, - stop_btn=stop_btn, resume_btn=resume_btn, - process_bar=process_bar, + progress_bar=progress_bar, output_box=output_box, ) ) cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None) start_btn.click(engine.runner.run_eval, input_elems, output_elems) - stop_btn.click(engine.runner.set_abort) resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None) dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False) diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index c709b916..857c56ac 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -27,7 +27,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1 ) dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1) - dataset = gr.Dropdown(multiselect=True, scale=4) + dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4) preview_elems = create_preview_box(dataset_dir, dataset) input_elems.update({training_stage, dataset_dir, dataset}) @@ -52,7 +52,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: ) with gr.Row(): - cutoff_len = gr.Slider(value=1024, minimum=4, maximum=16384, step=1) + cutoff_len = gr.Slider(value=1024, minimum=4, maximum=65536, step=1) batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1) gradient_accumulation_steps = gr.Slider(value=8, minimum=1, maximum=1024, step=1) val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001) diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index b04c9b00..59515a62 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -299,12 +299,12 @@ class Runner: progress_bar: gr.Slider(visible=False), } else: - running_log, running_progress, running_loss = get_trainer_info(output_path) + running_log, running_progress, running_loss = get_trainer_info(output_path, self.do_train) return_dict = { output_box: running_log, progress_bar: running_progress, } - if self.do_train and running_loss is not None: + if running_loss is not None: return_dict[loss_viewer] = running_loss yield return_dict diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index c8729d36..1f2b0591 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -63,7 +63,7 @@ def get_time() -> str: return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S") -def get_trainer_info(output_path: os.PathLike) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]: +def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]: running_log = "" running_progress = gr.Slider(visible=False) running_loss = None @@ -91,7 +91,7 @@ def get_trainer_info(output_path: os.PathLike) -> Tuple[str, "gr.Slider", Option ) running_progress = gr.Slider(label=label, value=percentage, visible=True) - if is_matplotlib_available(): + if do_train and is_matplotlib_available(): running_loss = gr.Plot(gen_loss_plot(trainer_log)) return running_log, running_progress, running_loss