From 12d8a8633f1d8db8eb72223f69c074d98af16e01 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 18 Jul 2023 15:35:17 +0800 Subject: [PATCH] update webUI, fix #179 --- src/llmtuner/tuner/ppo/trainer.py | 2 +- src/llmtuner/webui/chat.py | 17 ++- src/llmtuner/webui/components/eval.py | 25 ++-- src/llmtuner/webui/components/infer.py | 12 +- src/llmtuner/webui/components/sft.py | 46 ++++--- src/llmtuner/webui/components/top.py | 25 +++- src/llmtuner/webui/locales.py | 174 +++++++++++++------------ src/llmtuner/webui/runner.py | 83 ++++++++---- src/llmtuner/webui/utils.py | 17 ++- 9 files changed, 247 insertions(+), 154 deletions(-) diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 7a05dadb..0d84af6e 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -108,7 +108,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): replace_model(unwrapped_model, target="reward") with torch.no_grad(): _, _, values = self.model(**self.prepare_model_inputs(queries, responses)) - rewards = [reward for reward in values[-1].to(torch.float32)] # use float32 type + rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type replace_model(unwrapped_model, target="default") # Run PPO step diff --git a/src/llmtuner/webui/chat.py b/src/llmtuner/webui/chat.py index 71018c31..c889eca5 100644 --- a/src/llmtuner/webui/chat.py +++ b/src/llmtuner/webui/chat.py @@ -17,8 +17,14 @@ class WebChatModel(ChatModel): self.generating_args = GeneratingArguments() def load_model( - self, lang: str, model_name: str, checkpoints: list, - finetuning_type: str, template: str, quantization_bit: str + self, + lang: str, + model_name: str, + checkpoints: List[str], + finetuning_type: str, + quantization_bit: str, + template: str, + source_prefix: str ): if self.model is not None: yield ALERTS["err_exists"][lang] @@ -43,10 +49,11 @@ class WebChatModel(ChatModel): yield ALERTS["info_loading"][lang] args = dict( model_name_or_path=model_name_or_path, - finetuning_type=finetuning_type, - prompt_template=template, checkpoint_dir=checkpoint_dir, - quantization_bit=int(quantization_bit) if quantization_bit else None + finetuning_type=finetuning_type, + quantization_bit=int(quantization_bit) if quantization_bit else None, + prompt_template=template, + source_prefix=source_prefix ) super().__init__(*get_infer_args(args)) diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py index 67407cbc..82d7f658 100644 --- a/src/llmtuner/webui/components/eval.py +++ b/src/llmtuner/webui/components/eval.py @@ -10,8 +10,8 @@ from llmtuner.webui.utils import can_preview, get_preview def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]: with gr.Row(): - dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2) - dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4) + dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) + dataset = gr.Dropdown(multiselect=True, scale=4) preview_btn = gr.Button(interactive=False, scale=1) preview_box, preview_count, preview_samples, close_btn = create_preview_box() @@ -21,9 +21,8 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box]) with gr.Row(): - max_samples = gr.Textbox(value="100000", interactive=True) - batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1, interactive=True) - quantization_bit = gr.Dropdown([8, 4]) + max_samples = gr.Textbox(value="100000") + batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1) predict = gr.Checkbox(value=True) with gr.Row(): @@ -35,9 +34,18 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str start_btn.click( runner.run_eval, [ - top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"], - top_elems["finetuning_type"], top_elems["template"], - dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict + top_elems["lang"], + top_elems["model_name"], + top_elems["checkpoints"], + top_elems["finetuning_type"], + top_elems["quantization_bit"], + top_elems["template"], + top_elems["source_prefix"], + dataset_dir, + dataset, + max_samples, + batch_size, + predict ], [output_box] ) @@ -52,7 +60,6 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str close_btn=close_btn, max_samples=max_samples, batch_size=batch_size, - quantization_bit=quantization_bit, predict=predict, start_btn=start_btn, stop_btn=stop_btn, diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py index 831dbea5..5508a163 100644 --- a/src/llmtuner/webui/components/infer.py +++ b/src/llmtuner/webui/components/infer.py @@ -11,7 +11,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]: with gr.Row(): load_btn = gr.Button() unload_btn = gr.Button() - quantization_bit = gr.Dropdown([8, 4]) info_box = gr.Markdown() @@ -21,9 +20,13 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]: load_btn.click( chat_model.load_model, [ - top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"], - top_elems["finetuning_type"], top_elems["template"], - quantization_bit + top_elems["lang"], + top_elems["model_name"], + top_elems["checkpoints"], + top_elems["finetuning_type"], + top_elems["quantization_bit"], + top_elems["template"], + top_elems["source_prefix"] ], [info_box] ).then( @@ -39,7 +42,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]: ) return dict( - quantization_bit=quantization_bit, info_box=info_box, load_btn=load_btn, unload_btn=unload_btn, diff --git a/src/llmtuner/webui/components/sft.py b/src/llmtuner/webui/components/sft.py index d6de0c5d..b0e785da 100644 --- a/src/llmtuner/webui/components/sft.py +++ b/src/llmtuner/webui/components/sft.py @@ -12,8 +12,8 @@ from llmtuner.webui.utils import can_preview, get_preview, gen_plot def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]: with gr.Row(): - dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2) - dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4) + dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) + dataset = gr.Dropdown(multiselect=True, scale=4) preview_btn = gr.Button(interactive=False, scale=1) preview_box, preview_count, preview_samples, close_btn = create_preview_box() @@ -23,22 +23,21 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box]) with gr.Row(): - learning_rate = gr.Textbox(value="5e-5", interactive=True) - num_train_epochs = gr.Textbox(value="3.0", interactive=True) - max_samples = gr.Textbox(value="100000", interactive=True) - quantization_bit = gr.Dropdown([8, 4]) + learning_rate = gr.Textbox(value="5e-5") + num_train_epochs = gr.Textbox(value="3.0") + max_samples = gr.Textbox(value="100000") with gr.Row(): - batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1, interactive=True) - gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1, interactive=True) + batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1) + gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1) lr_scheduler_type = gr.Dropdown( - value="cosine", choices=[scheduler.value for scheduler in SchedulerType], interactive=True + value="cosine", choices=[scheduler.value for scheduler in SchedulerType] ) fp16 = gr.Checkbox(value=True) with gr.Row(): - logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5, interactive=True) - save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10, interactive=True) + logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5) + save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10) with gr.Row(): start_btn = gr.Button() @@ -55,11 +54,25 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, start_btn.click( runner.run_train, [ - top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"], - top_elems["finetuning_type"], top_elems["template"], - dataset, dataset_dir, learning_rate, num_train_epochs, max_samples, - fp16, quantization_bit, batch_size, gradient_accumulation_steps, - lr_scheduler_type, logging_steps, save_steps, output_dir + top_elems["lang"], + top_elems["model_name"], + top_elems["checkpoints"], + top_elems["finetuning_type"], + top_elems["quantization_bit"], + top_elems["template"], + top_elems["source_prefix"], + dataset_dir, + dataset, + learning_rate, + num_train_epochs, + max_samples, + batch_size, + gradient_accumulation_steps, + lr_scheduler_type, + fp16, + logging_steps, + save_steps, + output_dir ], [output_box] ) @@ -79,7 +92,6 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, learning_rate=learning_rate, num_train_epochs=num_train_epochs, max_samples=max_samples, - quantization_bit=quantization_bit, batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps, lr_scheduler_type=lr_scheduler_type, diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py index 8efb1d8d..a8f7e542 100644 --- a/src/llmtuner/webui/components/top.py +++ b/src/llmtuner/webui/components/top.py @@ -6,29 +6,40 @@ from gradio.components import Component from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS from llmtuner.extras.template import templates from llmtuner.webui.common import list_checkpoint, get_model_path, save_config +from llmtuner.webui.utils import can_quantize def create_top() -> Dict[str, Component]: available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] with gr.Row(): - lang = gr.Dropdown(choices=["en", "zh"], value="en", interactive=True, scale=1) + lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1) model_name = gr.Dropdown(choices=available_models, scale=3) model_path = gr.Textbox(scale=3) with gr.Row(): - finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1) - template = gr.Dropdown(value="default", choices=list(templates.keys()), interactive=True, scale=1) - checkpoints = gr.Dropdown(multiselect=True, interactive=True, scale=4) + finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1) + checkpoints = gr.Dropdown(multiselect=True, scale=5) refresh_btn = gr.Button(scale=1) + with gr.Row(): + quantization_bit = gr.Dropdown([8, 4], scale=1) + template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=2) + source_prefix = gr.Textbox(scale=4) + model_name.change( list_checkpoint, [model_name, finetuning_type], [checkpoints] ).then( get_model_path, [model_name], [model_path] ) # do not save config since the below line will save model_path.change(save_config, [model_name, model_path]) - finetuning_type.change(list_checkpoint, [model_name, finetuning_type], [checkpoints]) + + finetuning_type.change( + list_checkpoint, [model_name, finetuning_type], [checkpoints] + ).then( + can_quantize, [finetuning_type], [quantization_bit] + ) + refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints]) return dict( @@ -38,5 +49,7 @@ def create_top() -> Dict[str, Component]: finetuning_type=finetuning_type, template=template, checkpoints=checkpoints, - refresh_btn=refresh_btn + refresh_btn=refresh_btn, + quantization_bit=quantization_bit, + source_prefix=source_prefix ) diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py index 2bc9b3d0..817fd316 100644 --- a/src/llmtuner/webui/locales.py +++ b/src/llmtuner/webui/locales.py @@ -25,6 +25,14 @@ LOCALES = { "info": "本地模型的文件路径或 Hugging Face 的模型标识符。" } }, + "finetuning_type": { + "en": { + "label": "Finetuning method" + }, + "zh": { + "label": "微调方法" + } + }, "checkpoints": { "en": { "label": "Checkpoints" @@ -33,14 +41,6 @@ LOCALES = { "label": "模型断点" } }, - "template": { - "en": { - "label": "Prompt template" - }, - "zh": { - "label": "提示模板" - } - }, "refresh_btn": { "en": { "value": "Refresh checkpoints" @@ -49,6 +49,36 @@ LOCALES = { "value": "刷新断点" } }, + "quantization_bit": { + "en": { + "label": "Quantization bit (optional)", + "info": "Enable 4/8-bit model quantization." + }, + "zh": { + "label": "量化等级(非必填)", + "info": "启用 4/8 比特模型量化。" + } + }, + "template": { + "en": { + "label": "Prompt template", + "info": "The template used in constructing prompts." + }, + "zh": { + "label": "提示模板", + "info": "构建提示词时使用的模板" + } + }, + "source_prefix": { + "en": { + "label": "Source prefix (optional)", + "info": "A sequence used as the prefix of each samples." + }, + "zh": { + "label": "前缀序列(非必填)", + "info": "作为每个输入样本前缀的序列" + } + }, "dataset_dir": { "en": { "label": "Data dir", @@ -99,68 +129,6 @@ LOCALES = { "value": "关闭" } }, - "max_samples": { - "en": { - "label": "Max samples", - "info": "Maximum samples per dataset." - }, - "zh": { - "label": "最大样本数", - "info": "每个数据集最多使用的样本数。" - } - }, - "batch_size": { - "en": { - "label": "Batch size", - "info": "Number of samples to process per GPU." - }, - "zh":{ - "label": "批处理大小", - "info": "每块 GPU 上处理的样本数量。" - } - }, - "quantization_bit": { - "en": { - "label": "Quantization bit", - "info": "Enable 4/8-bit model quantization." - }, - "zh": { - "label": "量化", - "info": "启用 4/8 比特模型量化。" - } - }, - "start_btn": { - "en": { - "value": "Start" - }, - "zh": { - "value": "开始" - } - }, - "stop_btn": { - "en": { - "value": "Abort" - }, - "zh": { - "value": "中断" - } - }, - "output_box": { - "en": { - "value": "Ready." - }, - "zh": { - "value": "准备就绪。" - } - }, - "finetuning_type": { - "en": { - "label": "Finetuning method" - }, - "zh": { - "label": "微调方法" - } - }, "learning_rate": { "en": { "label": "Learning rate", @@ -181,6 +149,26 @@ LOCALES = { "info": "需要执行的训练总轮数。" } }, + "max_samples": { + "en": { + "label": "Max samples", + "info": "Maximum samples per dataset." + }, + "zh": { + "label": "最大样本数", + "info": "每个数据集最多使用的样本数。" + } + }, + "batch_size": { + "en": { + "label": "Batch size", + "info": "Number of samples to process per GPU." + }, + "zh":{ + "label": "批处理大小", + "info": "每块 GPU 上处理的样本数量。" + } + }, "gradient_accumulation_steps": { "en": { "label": "Gradient accumulation", @@ -231,6 +219,22 @@ LOCALES = { "info": "每两次断点保存间的更新步数。" } }, + "start_btn": { + "en": { + "value": "Start" + }, + "zh": { + "value": "开始" + } + }, + "stop_btn": { + "en": { + "value": "Abort" + }, + "zh": { + "value": "中断" + } + }, "output_dir": { "en": { "label": "Checkpoint name", @@ -241,6 +245,14 @@ LOCALES = { "info": "保存模型断点的文件夹名称。" } }, + "output_box": { + "en": { + "value": "Ready." + }, + "zh": { + "value": "准备就绪。" + } + }, "loss_viewer": { "en": { "label": "Loss" @@ -257,14 +269,6 @@ LOCALES = { "label": "保存预测结果" } }, - "info_box": { - "en": { - "value": "Model unloaded, please load a model first." - }, - "zh": { - "value": "模型未加载,请先加载模型。" - } - }, "load_btn": { "en": { "value": "Load model" @@ -281,6 +285,14 @@ LOCALES = { "value": "卸载模型" } }, + "info_box": { + "en": { + "value": "Model unloaded, please load a model first." + }, + "zh": { + "value": "模型未加载,请先加载模型。" + } + }, "query": { "en": { "placeholder": "Input..." @@ -305,12 +317,12 @@ LOCALES = { "value": "清空历史" } }, - "max_new_tokens": { + "max_length": { "en": { - "label": "Maximum new tokens" + "label": "Maximum length" }, "zh": { - "label": "最大生成长度" + "label": "最大长度" } }, "top_p": { diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 45d8b340..408a46cf 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -3,7 +3,7 @@ import os import threading import time import transformers -from typing import Optional, Tuple +from typing import List, Optional, Tuple from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated @@ -59,10 +59,26 @@ class Runner: return finish_info if finish_info is not None else ALERTS["info_finished"][lang] def run_train( - self, lang, model_name, checkpoints, finetuning_type, template, - dataset, dataset_dir, learning_rate, num_train_epochs, max_samples, - fp16, quantization_bit, batch_size, gradient_accumulation_steps, - lr_scheduler_type, logging_steps, save_steps, output_dir + self, + lang: str, + model_name: str, + checkpoints: List[str], + finetuning_type: str, + quantization_bit: str, + template: str, + source_prefix: str, + dataset_dir: str, + dataset: List[str], + learning_rate: str, + num_train_epochs: str, + max_samples: str, + batch_size: int, + gradient_accumulation_steps: int, + lr_scheduler_type: str, + fp16: bool, + logging_steps: int, + save_steps: int, + output_dir: str ): model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset) if error: @@ -79,24 +95,25 @@ class Runner: args = dict( model_name_or_path=model_name_or_path, do_train=True, - finetuning_type=finetuning_type, - lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj", - prompt_template=template, - dataset=",".join(dataset), - dataset_dir=dataset_dir, - max_samples=int(max_samples), - output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir), - checkpoint_dir=checkpoint_dir, overwrite_cache=True, + lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj", + checkpoint_dir=checkpoint_dir, + finetuning_type=finetuning_type, + quantization_bit=int(quantization_bit) if quantization_bit else None, + prompt_template=template, + source_prefix=source_prefix, + dataset_dir=dataset_dir, + dataset=",".join(dataset), + learning_rate=float(learning_rate), + num_train_epochs=float(num_train_epochs), + max_samples=int(max_samples), per_device_train_batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps, lr_scheduler_type=lr_scheduler_type, + fp16=fp16, logging_steps=logging_steps, save_steps=save_steps, - learning_rate=float(learning_rate), - num_train_epochs=float(num_train_epochs), - fp16=fp16, - quantization_bit=int(quantization_bit) if quantization_bit else None + output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir) ) model_args, data_args, training_args, finetuning_args, _ = get_train_args(args) @@ -120,8 +137,19 @@ class Runner: yield self.finalize(lang) def run_eval( - self, lang, model_name, checkpoints, finetuning_type, template, - dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict + self, + lang: str, + model_name: str, + checkpoints: List[str], + finetuning_type: str, + quantization_bit: str, + template: str, + source_prefix: str, + dataset_dir: str, + dataset: List[str], + max_samples: str, + batch_size: int, + predict: bool ): model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset) if error: @@ -140,17 +168,18 @@ class Runner: args = dict( model_name_or_path=model_name_or_path, do_eval=True, - finetuning_type=finetuning_type, - prompt_template=template, - dataset=",".join(dataset), - dataset_dir=dataset_dir, - max_samples=int(max_samples), - output_dir=output_dir, - checkpoint_dir=checkpoint_dir, overwrite_cache=True, predict_with_generate=True, + checkpoint_dir=checkpoint_dir, + finetuning_type=finetuning_type, + quantization_bit=int(quantization_bit) if quantization_bit else None, + prompt_template=template, + source_prefix=source_prefix, + dataset_dir=dataset_dir, + dataset=",".join(dataset), + max_samples=int(max_samples), per_device_eval_batch_size=batch_size, - quantization_bit=int(quantization_bit) if quantization_bit else None + output_dir=output_dir ) if predict: diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index 506ea4b0..a5b5640f 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -3,7 +3,7 @@ import json import gradio as gr import matplotlib.figure import matplotlib.pyplot as plt -from typing import Tuple +from typing import Any, Dict, Tuple from datetime import datetime from llmtuner.extras.ploting import smooth @@ -23,7 +23,7 @@ def get_time() -> str: return datetime.now().strftime('%Y-%m-%d-%H-%M-%S') -def can_preview(dataset_dir: str, dataset: list) -> dict: +def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]: with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: dataset_info = json.load(f) if ( @@ -36,7 +36,7 @@ def can_preview(dataset_dir: str, dataset: list) -> dict: return gr.update(interactive=False) -def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]: +def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]: with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: dataset_info = json.load(f) data_file = dataset_info[dataset[0]]["file_name"] @@ -45,6 +45,13 @@ def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]: return len(data), data[:2], gr.update(visible=True) +def can_quantize(finetuning_type: str) -> Dict[str, Any]: + if finetuning_type != "lora": + return gr.update(value="", interactive=False) + else: + return gr.update(interactive=True) + + def get_eval_results(path: os.PathLike) -> str: with open(path, "r", encoding="utf-8") as f: result = json.dumps(json.load(f), indent=4) @@ -66,6 +73,10 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl if log_info.get("loss", None): steps.append(log_info["current_steps"]) losses.append(log_info["loss"]) + + if len(losses) == 0: + return None + ax.plot(steps, losses, alpha=0.4, label="original") ax.plot(steps, smooth(losses), label="smoothed") ax.legend()