From 5ddcecda50ccff93d51bebc9ac72c2a0dd483e9b Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 4 Apr 2024 00:55:31 +0800
Subject: [PATCH] fix bug in latest gradio

---
 requirements.txt                       |   2 +-
 scripts/llamafy_internlm2.py           | 114 -------------------------
 src/llmtuner/extras/misc.py            |   2 +-
 src/llmtuner/webui/components/eval.py  |   6 +-
 src/llmtuner/webui/components/train.py |  84 ++++++++----------
 src/llmtuner/webui/locales.py          |   5 ++
 src/llmtuner/webui/runner.py           |  85 ++++++++++++------
 src/llmtuner/webui/utils.py            |  17 ++--
 8 files changed, 111 insertions(+), 204 deletions(-)
 delete mode 100644 scripts/llamafy_internlm2.py

diff --git a/requirements.txt b/requirements.txt
index fe8ab35c..3928d28d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ datasets>=2.14.3
 accelerate>=0.27.2
 peft>=0.10.0
 trl>=0.8.1
-gradio>4.0.0,<=4.21.0
+gradio>=4.0.0
 scipy
 einops
 sentencepiece
diff --git a/scripts/llamafy_internlm2.py b/scripts/llamafy_internlm2.py
deleted file mode 100644
index b6b03e7d..00000000
--- a/scripts/llamafy_internlm2.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# coding=utf-8
-# Converts the InternLM2 model in the same format as LLaMA2.
-# Usage: python llamafy_internlm2.py --input_dir input --output_dir output
-# Warning: We have found that the converted model cannot infer correctly. It will be fixed later.
-
-import json
-import os
-from collections import OrderedDict
-from typing import Any, Dict, Optional
-
-import fire
-import torch
-from safetensors.torch import save_file
-from tqdm import tqdm
-from transformers.modeling_utils import (
-    SAFE_WEIGHTS_INDEX_NAME,
-    SAFE_WEIGHTS_NAME,
-    WEIGHTS_INDEX_NAME,
-    WEIGHTS_NAME,
-    shard_checkpoint,
-)
-
-
-CONFIG_NAME = "config.json"
-
-
-def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
-    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
-        internlm2_config_dict: Dict[str, Any] = json.load(f)
-
-    internlm2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
-    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
-        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
-            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
-            internlm2_state_dict.update(shard_weight)
-
-    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
-    for key, value in tqdm(internlm2_state_dict.items(), desc="Convert format"):
-        if "output" in key:
-            llama2_state_dict[key.replace("output", "lm_head")] = value
-        elif "tok_embeddings" in key:
-            llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value
-        elif "wqkv" in key:
-            num_q_heads = internlm2_config_dict["num_attention_heads"]
-            num_kv_heads = internlm2_config_dict["num_key_value_heads"]
-            q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
-            kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
-            llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
-            llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[
-                q_size : q_size + kv_size, ...
-            ]
-            llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size + kv_size :, ...]
-        elif "wo" in key:
-            llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
-        elif "attention_norm" in key:
-            llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
-        elif "ffn_norm" in key:
-            llama2_state_dict[key.replace("ffn_norm", "post_attention_layernorm")] = value
-        elif "w1" in key:
-            llama2_state_dict[key.replace("feed_forward.w1", "mlp.gate_proj")] = value
-        elif "w2" in key:
-            llama2_state_dict[key.replace("feed_forward.w2", "mlp.down_proj")] = value
-        elif "w3" in key:
-            llama2_state_dict[key.replace("feed_forward.w3", "mlp.up_proj")] = value
-        else:
-            llama2_state_dict[key] = value
-
-    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
-    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
-
-    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
-        if save_safetensors:
-            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
-        else:
-            torch.save(shard, os.path.join(output_dir, shard_file))
-
-    if index is None:
-        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
-    else:
-        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
-        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
-            json.dump(index, f, indent=2, sort_keys=True)
-        print("Model weights saved in {}".format(output_dir))
-
-
-def save_config(input_dir: str, output_dir: str):
-    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
-        llama2_config_dict: Dict[str, Any] = json.load(f)
-
-    llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
-    llama2_config_dict.pop("auto_map", None)
-    llama2_config_dict.pop("bias", None)
-    llama2_config_dict.pop("rope_scaling", None)
-    llama2_config_dict["model_type"] = "llama"
-
-    with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
-        json.dump(llama2_config_dict, f, indent=2)
-    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
-
-
-def llamafy_internlm2(
-    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
-):
-    try:
-        os.makedirs(output_dir, exist_ok=False)
-    except Exception as e:
-        raise print("Output dir already exists", e)
-
-    save_weight(input_dir, output_dir, shard_size, save_safetensors)
-    save_config(input_dir, output_dir)
-
-
-if __name__ == "__main__":
-    fire.Fire(llamafy_internlm2)
diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index a696b315..2093d7ea 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -66,7 +66,7 @@ def check_dependencies() -> None:
     require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
     require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
     require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
-    require_version("gradio>4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")
+    require_version("gradio>=4.0.0", "To fix: pip install gradio>=4.0.0")


 def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py
index a1dae98c..87611da5 100644
--- a/src/llmtuner/webui/components/eval.py
+++ b/src/llmtuner/webui/components/eval.py
@@ -21,8 +21,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
         dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_elems = create_preview_box(dataset_dir, dataset)

-    dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
-
     input_elems.update({dataset_dir, dataset})
     elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))

@@ -50,7 +48,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
         stop_btn = gr.Button(variant="stop")

     with gr.Row():
-        resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
+        resume_btn = gr.Checkbox(visible=False, interactive=False)
         process_bar = gr.Slider(visible=False, interactive=False)

     with gr.Row():
@@ -73,4 +71,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
     stop_btn.click(engine.runner.set_abort)
     resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)

+    dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
+
     return elem_dict
diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py
index 1c425d51..4f108db0 100644
--- a/src/llmtuner/webui/components/train.py
+++ b/src/llmtuner/webui/components/train.py
@@ -6,7 +6,6 @@ from transformers.trainer_utils import SchedulerType
 from ...extras.constants import TRAINING_STAGES
 from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
 from ..components.data import create_preview_box
-from ..utils import gen_plot


 if TYPE_CHECKING:
@@ -24,7 +23,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
            choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
        )
        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
-        dataset = gr.Dropdown(multiselect=True, scale=2, allow_custom_value=True)
+        dataset = gr.Dropdown(multiselect=True, scale=4, allow_custom_value=True)
        preview_elems = create_preview_box(dataset_dir, dataset)

    input_elems.update({training_stage, dataset_dir, dataset})
@@ -121,8 +120,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:

     with gr.Accordion(open=False) as freeze_tab:
         with gr.Row():
-            num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
-            name_module_trainable = gr.Textbox(value="all", scale=3)
+            num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1)
+            name_module_trainable = gr.Textbox(value="all")

     input_elems.update({num_layer_trainable, name_module_trainable})
     elem_dict.update(
@@ -140,8 +139,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             create_new_adapter = gr.Checkbox()

         with gr.Row():
-            use_rslora = gr.Checkbox(scale=1)
-            use_dora = gr.Checkbox(scale=1)
+            with gr.Column(scale=1):
+                use_rslora = gr.Checkbox()
+                use_dora = gr.Checkbox()
+
             lora_target = gr.Textbox(scale=2)
             additional_target = gr.Textbox(scale=2)

@@ -175,10 +176,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:

     with gr.Accordion(open=False) as rlhf_tab:
         with gr.Row():
-            dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
-            dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
-            orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
-            reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
+            dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
+            dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01)
+            orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
+            reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)

     input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
     elem_dict.update(
@@ -187,11 +188,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:

     with gr.Accordion(open=False) as galore_tab:
         with gr.Row():
-            use_galore = gr.Checkbox(scale=1)
-            galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2)
-            galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1, scale=2)
-            galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01, scale=2)
-            galore_target = gr.Textbox(value="all", scale=3)
+            use_galore = gr.Checkbox()
+            galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1)
+            galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1)
+            galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01)
+            galore_target = gr.Textbox(value="all")

     input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
     elem_dict.update(
@@ -228,29 +229,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         with gr.Column(scale=1):
             loss_viewer = gr.Plot()

-    input_elems.update({output_dir, config_path})
-    output_elems = [output_box, process_bar]
-
-    cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
-    arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
-    arg_load_btn.click(
-        engine.runner.load_args,
-        [engine.manager.get_elem_by_id("top.lang"), config_path],
-        list(input_elems),
-        concurrency_limit=None,
-    )
-    start_btn.click(engine.runner.run_train, input_elems, output_elems)
-    stop_btn.click(engine.runner.set_abort)
-    resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
-
-    dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
-    training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
-        list_adapters,
-        [engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
-        [reward_model],
-        queue=False,
-    ).then(autoset_packing, [training_stage], [packing], queue=False)
-
     elem_dict.update(
         dict(
             cmd_preview_btn=cmd_preview_btn,
@@ -267,15 +245,27 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         )
     )

-    output_box.change(
-        gen_plot,
-        [
-            engine.manager.get_elem_by_id("top.model_name"),
-            engine.manager.get_elem_by_id("top.finetuning_type"),
-            output_dir,
-        ],
-        loss_viewer,
-        queue=False,
+    input_elems.update({output_dir, config_path})
+    output_elems = [output_box, process_bar, loss_viewer]
+
+    cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
+    arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
+    arg_load_btn.click(
+        engine.runner.load_args,
+        [engine.manager.get_elem_by_id("top.lang"), config_path],
+        list(input_elems) + [output_box],
+        concurrency_limit=None,
     )
+    start_btn.click(engine.runner.run_train, input_elems, output_elems)
+    stop_btn.click(engine.runner.set_abort)
+    resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
+
+    dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
+    training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
+        list_adapters,
+        [engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
+        [reward_model],
+        queue=False,
+    ).then(autoset_packing, [training_stage], [packing], queue=False)

     return elem_dict
diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py
index b7319fd4..c3111e8f 100644
--- a/src/llmtuner/webui/locales.py
+++ b/src/llmtuner/webui/locales.py
@@ -1344,6 +1344,11 @@ ALERTS = {
         "ru": "Аргументы были сохранены по адресу: ",
         "zh": "训练参数已保存至:",
     },
+    "info_config_loaded": {
+        "en": "Arguments have been restored.",
+        "ru": "Аргументы были восстановлены.",
+        "zh": "训练参数已载入。",
+    },
     "info_loading": {
         "en": "Loading model...",
         "ru": "Загрузка модели...",
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index dae7daf8..2d3ef80f 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -2,7 +2,7 @@ import logging
 import os
 import time
 from threading import Thread
-from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator

 import gradio as gr
 import transformers
@@ -17,7 +17,7 @@ from ..extras.misc import get_device_count, torch_gc
 from ..train import run_exp
 from .common import get_module, get_save_dir, load_args, load_config, save_args
 from .locales import ALERTS
-from .utils import gen_cmd, get_eval_results, update_process_bar
+from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar


 if TYPE_CHECKING:
@@ -239,20 +239,22 @@ class Runner:

         return args

-    def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]:
+    def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, str], None, None]:
+        output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
         error = self._initialize(data, do_train, from_preview=True)
         if error:
             gr.Warning(error)
-            yield error, gr.Slider(visible=False)
+            yield {output_box: error}
         else:
             args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
-            yield gen_cmd(args), gr.Slider(visible=False)
+            yield {output_box: gen_cmd(args)}

-    def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]:
+    def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, Any], None, None]:
+        output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
         error = self._initialize(data, do_train, from_preview=False)
         if error:
             gr.Warning(error)
-            yield error, gr.Slider(visible=False)
+            yield {output_box: error}
         else:
             args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
             run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
@@ -261,54 +263,80 @@ class Runner:
             self.thread.start()
             yield from self.monitor()

-    def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
+    def preview_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
         yield from self._preview(data, do_train=True)

-    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
+    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
         yield from self._preview(data, do_train=False)

-    def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
+    def run_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
         yield from self._launch(data, do_train=True)

-    def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
+    def run_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
         yield from self._launch(data, do_train=False)

-    def monitor(self) -> Generator[Tuple[str, "gr.Slider"], None, None]:
+    def monitor(self) -> Generator[Dict[Component, Any], None, None]:
         get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
         self.running = True
+
         lang = get("top.lang")
-        output_dir = get_save_dir(
-            get("top.model_name"),
-            get("top.finetuning_type"),
-            get("{}.output_dir".format("train" if self.do_train else "eval")),
-        )
+        model_name = get("top.model_name")
+        finetuning_type = get("top.finetuning_type")
+        output_dir = get("{}.output_dir".format("train" if self.do_train else "eval"))
+        output_path = get_save_dir(model_name, finetuning_type, output_dir)
+
+        output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
+        process_bar = self.manager.get_elem_by_id("{}.process_bar".format("train" if self.do_train else "eval"))
+        loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None

         while self.thread is not None and self.thread.is_alive():
             if self.aborted:
-                yield ALERTS["info_aborting"][lang], gr.Slider(visible=False)
+                yield {
+                    output_box: ALERTS["info_aborting"][lang],
+                    process_bar: gr.Slider(visible=False),
+                }
             else:
-                yield self.logger_handler.log, update_process_bar(self.trainer_callback)
+                return_dict = {
+                    output_box: self.logger_handler.log,
+                    process_bar: update_process_bar(self.trainer_callback),
+                }
+                if self.do_train:
+                    plot = gen_plot(output_path)
+                    if plot is not None:
+                        return_dict[loss_viewer] = plot
+
+                yield return_dict

             time.sleep(2)

         if self.do_train:
-            if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
+            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
                 finish_info = ALERTS["info_finished"][lang]
             else:
                 finish_info = ALERTS["err_failed"][lang]
         else:
-            if os.path.exists(os.path.join(output_dir, "all_results.json")):
-                finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
+            if os.path.exists(os.path.join(output_path, "all_results.json")):
+                finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
             else:
                 finish_info = ALERTS["err_failed"][lang]

-        yield self._finalize(lang, finish_info), gr.Slider(visible=False)
+        return_dict = {
+            output_box: self._finalize(lang, finish_info),
+            process_bar: gr.Slider(visible=False),
+        }
+        if self.do_train:
+            plot = gen_plot(output_path)
+            if plot is not None:
+                return_dict[loss_viewer] = plot

-    def save_args(self, data: Dict[Component, Any]) -> Tuple[str, "gr.Slider"]:
+        yield return_dict
+
+    def save_args(self, data: Dict[Component, Any]) -> Dict[Component, str]:
+        output_box = self.manager.get_elem_by_id("train.output_box")
         error = self._initialize(data, do_train=True, from_preview=True)
         if error:
             gr.Warning(error)
-            return error, gr.Slider(visible=False)
+            return {output_box: error}

         config_dict: Dict[str, Any] = {}
         lang = data[self.manager.get_elem_by_id("top.lang")]
@@ -320,15 +348,16 @@ class Runner:
             config_dict[elem_id] = value

         save_path = save_args(config_path, config_dict)
-        return ALERTS["info_config_saved"][lang] + save_path, gr.Slider(visible=False)
+        return {output_box: ALERTS["info_config_saved"][lang] + save_path}

     def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
+        output_box = self.manager.get_elem_by_id("train.output_box")
         config_dict = load_args(config_path)
         if config_dict is None:
             gr.Warning(ALERTS["err_config_not_found"][lang])
-            return {self.manager.get_elem_by_id("top.lang"): lang}
+            return {output_box: ALERTS["err_config_not_found"][lang]}

-        output_dict: Dict["Component", Any] = {}
+        output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
         for elem_id, value in config_dict.items():
             output_dict[self.manager.get_elem_by_id(elem_id)] = value

diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py
index 275b6573..d96b1f6b 100644
--- a/src/llmtuner/webui/utils.py
+++ b/src/llmtuner/webui/utils.py
@@ -1,13 +1,12 @@
 import json
 import os
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, Optional

 import gradio as gr

 from ..extras.packages import is_matplotlib_available
 from ..extras.ploting import smooth
-from .common import get_save_dir
 from .locales import ALERTS


@@ -36,7 +35,7 @@ def get_time() -> str:

 def can_quantize(finetuning_type: str) -> "gr.Dropdown":
     if finetuning_type != "lora":
-        return gr.Dropdown(value="None", interactive=False)
+        return gr.Dropdown(value="none", interactive=False)
     else:
         return gr.Dropdown(interactive=True)

@@ -74,11 +73,9 @@ def get_eval_results(path: os.PathLike) -> str:
     return "```json\n{}\n```\n".format(result)


-def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
-    if not base_model:
-        return
-    log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
-    if not os.path.isfile(log_file):
+def gen_plot(output_path: str) -> Optional["matplotlib.figure.Figure"]:
+    log_file = os.path.join(output_path, "trainer_log.jsonl")
+    if not os.path.isfile(log_file) or not is_matplotlib_available():
         return

     plt.close("all")
@@ -88,13 +85,13 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplot
     steps, losses = [], []
     with open(log_file, "r", encoding="utf-8") as f:
         for line in f:
-            log_info = json.loads(line)
+            log_info: Dict[str, Any] = json.loads(line)
             if log_info.get("loss", None):
                 steps.append(log_info["current_steps"])
                 losses.append(log_info["loss"])

     if len(losses) == 0:
-        return None
+        return

     ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
     ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
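
Note on the pattern used above: the core of this patch moves the webui callbacks from positional tuple returns to Gradio's dict-style updates. A generator handler yields a dict keyed by component instances, lists every component it may touch in outputs, and any component omitted from a given yield keeps its current state; that is what lets output_elems grow to include loss_viewer while monitor() only supplies a figure when gen_plot() actually returns one. Below is a minimal, self-contained sketch of the same pattern, assuming gradio>=4.0.0 as pinned above; the component names and the fake progress loop are illustrative, not taken from this repository.

import time

import gradio as gr

with gr.Blocks() as demo:
    start_btn = gr.Button("Start")
    output_box = gr.Textbox(label="log")
    process_bar = gr.Slider(visible=False, interactive=False)

    def fake_monitor():
        # Dict-style streaming: components omitted from a yielded dict
        # keep their current state instead of being overwritten.
        for step in range(1, 4):
            yield {
                output_box: "step {}/3 running...".format(step),
                process_bar: gr.Slider(value=step * 33, visible=True),
            }
            time.sleep(1)
        # Final update: hide the bar and report completion.
        yield {
            output_box: "finished",
            process_bar: gr.Slider(visible=False),
        }

    # Every component that may appear as a dict key must be listed in outputs.
    start_btn.click(fake_monitor, outputs=[output_box, process_bar])

if __name__ == "__main__":
    demo.launch()

Compared with tuple returns, which must supply a value for every output on every yield, the dict form lets one callback drive a variable subset of its outputs; that is also why the fixed gr.Slider(visible=False) placeholder the old tuples carried could be dropped.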