From 6bc8e9866d482c945dd98f4e9ab205a7d7270755 Mon Sep 17 00:00:00 2001
From: codemayq
Date: Sat, 12 Aug 2023 13:53:55 +0800
Subject: [PATCH] add SFT script preview in webui

---
 src/llmtuner/extras/constants.py     |   2 +
 src/llmtuner/webui/components/sft.py |  46 +++++++++++-
 src/llmtuner/webui/locales.py        |  16 ++++
 src/llmtuner/webui/runner.py         | 105 +++++++++++++++++++++------
 4 files changed, 145 insertions(+), 24 deletions(-)

diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py
index 8ee997bb..4e7101f7 100644
--- a/src/llmtuner/extras/constants.py
+++ b/src/llmtuner/extras/constants.py
@@ -1,5 +1,7 @@
 IGNORE_INDEX = -100
 
+SFT_SCRIPT_PREFIX = "CUDA_VISIBLE_DEVICES=0 python "
+
 LOG_FILE_NAME = "trainer_log.jsonl"
 
 VALUE_HEAD_FILE_NAME = "value_head.bin"
diff --git a/src/llmtuner/webui/components/sft.py b/src/llmtuner/webui/components/sft.py
index 05a6e530..e74ef5cf 100644
--- a/src/llmtuner/webui/components/sft.py
+++ b/src/llmtuner/webui/components/sft.py
@@ -61,11 +61,15 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
             resume_lora_training = gr.Checkbox(value=True, scale=1)
 
     with gr.Row():
+        preview_script_btn = gr.Button()
         start_btn = gr.Button()
         stop_btn = gr.Button()
 
     with gr.Row():
         with gr.Column(scale=3):
+            with gr.Box():
+                preview_script_box = gr.Textbox()
+
             with gr.Row():
                 output_dir = gr.Textbox()
 
@@ -78,6 +82,44 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
         with gr.Column(scale=1):
             loss_viewer = gr.Plot()
 
+    preview_script_btn.click(
+        runner.preview_sft_script,
+        [
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            top_elems["quantization_bit"],
+            top_elems["template"],
+            top_elems["source_prefix"],
+            dataset_dir,
+            dataset,
+            max_source_length,
+            max_target_length,
+            learning_rate,
+            num_train_epochs,
+            max_samples,
+            batch_size,
+            gradient_accumulation_steps,
+            lr_scheduler_type,
+            max_grad_norm,
+            val_size,
+            logging_steps,
+            save_steps,
+            warmup_steps,
+            compute_type,
+            padding_side,
+            lora_rank,
+            lora_dropout,
+            lora_target,
+            resume_lora_training,
+            output_dir
+        ],
+        [
+            preview_script_box
+        ]
+    )
+
     start_btn.click(
         runner.run_train,
         [
@@ -154,5 +196,7 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
         stop_btn=stop_btn,
         output_dir=output_dir,
         output_box=output_box,
-        loss_viewer=loss_viewer
+        loss_viewer=loss_viewer,
+        preview_script_btn=preview_script_btn,
+        preview_script_box=preview_script_box
     )
diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py
index 61491ece..c4845735 100644
--- a/src/llmtuner/webui/locales.py
+++ b/src/llmtuner/webui/locales.py
@@ -500,6 +500,22 @@ LOCALES = {
         "zh": {
             "value": "开始导出"
         }
+    },
+    "preview_script_btn": {
+        "en": {
+            "value": "Preview training script"
+        },
+        "zh": {
+            "value": "预览训练脚本命令"
+        }
+    },
+    "preview_script_box": {
+        "en": {
+            "label": "SFT Script Preview",
+        },
+        "zh": {
+            "label": "训练命令预览",
+        }
     }
 }
 
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 1ae92786..d0d12d14 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -8,7 +8,7 @@ from transformers.trainer import TRAINING_ARGS_NAME
 from typing import Generator, List, Tuple
 
 from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.constants import DEFAULT_MODULE
+from llmtuner.extras.constants import DEFAULT_MODULE, SFT_SCRIPT_PREFIX
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
 from llmtuner.tuner import run_exp
@@ -100,16 +100,44 @@ class Runner:
         if error:
             yield error, gr.update(visible=False)
             return
 
+        output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
+        args = self._build_args(batch_size, checkpoints, compute_type, dataset, dataset_dir, finetuning_type,
+                                gradient_accumulation_steps, learning_rate, logging_steps, lora_dropout, lora_rank,
+                                lora_target, lr_scheduler_type, max_grad_norm, max_samples, max_source_length,
+                                max_target_length, model_name, model_name_or_path, num_train_epochs, output_dir,
+                                padding_side, quantization_bit, resume_lora_training, save_steps, source_prefix,
+                                template, val_size, warmup_steps)
+
+        run_kwargs = dict(args=args, callbacks=[trainer_callback])
+        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
+        thread.start()
+
+        while thread.is_alive():
+            time.sleep(2)
+            if self.aborted:
+                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
+            else:
+                yield logger_handler.log, update_process_bar(trainer_callback)
+
+        if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
+            finish_info = ALERTS["info_finished"][lang]
+        else:
+            finish_info = ALERTS["err_failed"][lang]
+
+        yield self.finalize(lang, finish_info), gr.update(visible=False)
+
+    def _build_args(self, batch_size, checkpoints, compute_type, dataset, dataset_dir, finetuning_type,
+                    gradient_accumulation_steps, learning_rate, logging_steps, lora_dropout, lora_rank, lora_target,
+                    lr_scheduler_type, max_grad_norm, max_samples, max_source_length, max_target_length, model_name,
+                    model_name_or_path, num_train_epochs, output_dir, padding_side, quantization_bit,
+                    resume_lora_training, save_steps, source_prefix, template, val_size, warmup_steps):
         if checkpoints:
             checkpoint_dir = ",".join(
                 [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
             )
         else:
             checkpoint_dir = None
-
-        output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
-
         args = dict(
             stage="sft",
             model_name_or_path=model_name_or_path,
@@ -143,30 +171,12 @@ class Runner:
             resume_lora_training=resume_lora_training,
             output_dir=output_dir
         )
-
         if val_size > 1e-6:
             args["val_size"] = val_size
             args["evaluation_strategy"] = "steps"
             args["eval_steps"] = save_steps
             args["load_best_model_at_end"] = True
-
-        run_kwargs = dict(args=args, callbacks=[trainer_callback])
-        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
-        thread.start()
-
-        while thread.is_alive():
-            time.sleep(2)
-            if self.aborted:
-                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
-            else:
-                yield logger_handler.log, update_process_bar(trainer_callback)
-
-        if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
-            finish_info = ALERTS["info_finished"][lang]
-        else:
-            finish_info = ALERTS["err_failed"][lang]
-
-        yield self.finalize(lang, finish_info), gr.update(visible=False)
+        return args
 
     def run_eval(
         self,
@@ -240,3 +250,52 @@ class Runner:
             finish_info = ALERTS["err_failed"][lang]
 
         yield self.finalize(lang, finish_info), gr.update(visible=False)
+
+    def preview_sft_script(
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str,
+        dataset_dir: str,
+        dataset: List[str],
+        max_source_length: int,
+        max_target_length: int,
+        learning_rate: str,
+        num_train_epochs: str,
+        max_samples: str,
+        batch_size: int,
+        gradient_accumulation_steps: int,
+        lr_scheduler_type: str,
+        max_grad_norm: str,
+        val_size: float,
+        logging_steps: int,
+        save_steps: int,
+        warmup_steps: int,
+        compute_type: str,
+        padding_side: str,
+        lora_rank: int,
+        lora_dropout: float,
+        lora_target: str,
+        resume_lora_training: bool,
+        output_dir: str
+    ):
+        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
+        output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
+
+        args = self._build_args(batch_size, checkpoints, compute_type, dataset, dataset_dir, finetuning_type,
+                                gradient_accumulation_steps, learning_rate, logging_steps, lora_dropout, lora_rank,
+                                lora_target, lr_scheduler_type, max_grad_norm, max_samples, max_source_length,
+                                max_target_length, model_name, model_name_or_path, num_train_epochs, output_dir,
+                                padding_side, quantization_bit, resume_lora_training, save_steps, source_prefix,
+                                template, val_size, warmup_steps)
+        script_lines = [SFT_SCRIPT_PREFIX]
+        for param_key, param_value in args.items():
+            # skip None and other falsy values (False, 0, empty strings)
+            if param_value:
+                script_lines.append(" --" + param_key + " " + str(param_value) + " ")
+        script_str = "\\\n".join(script_lines)
+        return gr.update(value=script_str)
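
For reference, preview_sft_script joins SFT_SCRIPT_PREFIX and one " --<name> <value> " fragment per truthy entry of the args dict with a backslash-newline, so preview_script_box ends up showing a multi-line shell command roughly like the sketch below. The values are placeholders, the elided flags depend on what _build_args returns, and only argument names that appear in this diff are shown.

CUDA_VISIBLE_DEVICES=0 python \
 --stage sft \
 --model_name_or_path <model_name_or_path> \
 ... \
 --resume_lora_training True \
 --output_dir <save_dir>/<finetuning_type>/<output_dir>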