fix bug in latest gradio

hiyouga 2024-04-04 00:55:31 +08:00
parent 7f6e412604
commit 5ddcecda50
8 changed files with 111 additions and 204 deletions

View File

@@ -4,7 +4,7 @@ datasets>=2.14.3
accelerate>=0.27.2
peft>=0.10.0
trl>=0.8.1
gradio>4.0.0,<=4.21.0
gradio>=4.0.0
scipy
einops
sentencepiece

View File

@@ -1,114 +0,0 @@
# coding=utf-8
# Converts the InternLM2 model in the same format as LLaMA2.
# Usage: python llamafy_internlm2.py --input_dir input --output_dir output
# Warning: We have found that the converted model cannot infer correctly. It will be fixed later.

import json
import os
from collections import OrderedDict
from typing import Any, Dict, Optional

import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)

CONFIG_NAME = "config.json"


def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        internlm2_config_dict: Dict[str, Any] = json.load(f)

    internlm2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
            internlm2_state_dict.update(shard_weight)

    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    for key, value in tqdm(internlm2_state_dict.items(), desc="Convert format"):
        if "output" in key:
            llama2_state_dict[key.replace("output", "lm_head")] = value
        elif "tok_embeddings" in key:
            llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value
        elif "wqkv" in key:
            num_q_heads = internlm2_config_dict["num_attention_heads"]
            num_kv_heads = internlm2_config_dict["num_key_value_heads"]
            q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
            kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
            llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
            llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[
                q_size : q_size + kv_size, ...
            ]
            llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size + kv_size :, ...]
        elif "wo" in key:
            llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
        elif "attention_norm" in key:
            llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
        elif "ffn_norm" in key:
            llama2_state_dict[key.replace("ffn_norm", "post_attention_layernorm")] = value
        elif "w1" in key:
            llama2_state_dict[key.replace("feed_forward.w1", "mlp.gate_proj")] = value
        elif "w2" in key:
            llama2_state_dict[key.replace("feed_forward.w2", "mlp.down_proj")] = value
        elif "w3" in key:
            llama2_state_dict[key.replace("feed_forward.w3", "mlp.up_proj")] = value
        else:
            llama2_state_dict[key] = value

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        print("Model weights saved in {}".format(output_dir))


def save_config(input_dir: str, output_dir: str):
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        llama2_config_dict: Dict[str, Any] = json.load(f)

    llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
    llama2_config_dict.pop("auto_map", None)
    llama2_config_dict.pop("bias", None)
    llama2_config_dict.pop("rope_scaling", None)
    llama2_config_dict["model_type"] = "llama"

    with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
        json.dump(llama2_config_dict, f, indent=2)
    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))


def llamafy_internlm2(
    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
):
    try:
        os.makedirs(output_dir, exist_ok=False)
    except Exception as e:
        raise print("Output dir already exists", e)

    save_weight(input_dir, output_dir, shard_size, save_safetensors)
    save_config(input_dir, output_dir)


if __name__ == "__main__":
    fire.Fire(llamafy_internlm2)
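The wqkv slicing in the deleted script is easier to follow with concrete numbers. A small illustration using hypothetical InternLM2-7B-style head counts (the values are assumptions for the example, not read from any real checkpoint):

```python
# Hypothetical shapes: 32 query heads, 8 key/value heads, head_dim 128,
# so wqkv stacks q, k and v rows along dim 0.
num_q_heads, num_kv_heads, head_dim = 32, 8, 128
wqkv_rows = (num_q_heads + 2 * num_kv_heads) * head_dim  # 6144

# same arithmetic as the script above
q_size = wqkv_rows // (num_q_heads + 2 * num_kv_heads) * num_q_heads    # 4096
kv_size = wqkv_rows // (num_q_heads + 2 * num_kv_heads) * num_kv_heads  # 1024

# rows [0:4096] -> q_proj, [4096:5120] -> k_proj, [5120:6144] -> v_proj
print(q_size, kv_size)  # 4096 1024
```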

View File

@@ -66,7 +66,7 @@ def check_dependencies() -> None:
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
require_version("gradio>4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")
require_version("gradio>=4.0.0", "To fix: pip install gradio>=4.0.0")
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
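For context, this check relies on transformers' require_version helper; a minimal sketch of how the relaxed pin behaves at startup (only the gradio specifier changes in this hunk):

```python
# Minimal sketch: require_version raises an ImportError that carries the hint
# string when an installed gradio does not satisfy the specifier.
from transformers.utils.versions import require_version

require_version("gradio>=4.0.0", "To fix: pip install gradio>=4.0.0")
```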

View File

@@ -21,8 +21,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
dataset = gr.Dropdown(multiselect=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
input_elems.update({dataset_dir, dataset})
elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
@@ -50,7 +48,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
stop_btn = gr.Button(variant="stop")
with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
resume_btn = gr.Checkbox(visible=False, interactive=False)
process_bar = gr.Slider(visible=False, interactive=False)
with gr.Row():
@@ -73,4 +71,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
return elem_dict
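The hidden resume_btn above acts as an internal trigger: its value is meant to be toggled programmatically elsewhere in the app, and its .change listener streams engine.runner.monitor output into the visible components. A minimal standalone sketch of the same pattern, assuming Gradio 4.x (component names here are illustrative, not the project's):

```python
import time

import gradio as gr


def monitor():
    # stream a few status updates, then a final message
    for step in range(3):
        yield "step {} running...".format(step)
        time.sleep(1)
    yield "finished"


with gr.Blocks() as demo:
    resume_btn = gr.Checkbox(visible=False, interactive=False)
    output_box = gr.Markdown()
    start_btn = gr.Button("start")
    # setting the checkbox from a callback fires its .change event below
    start_btn.click(lambda: True, outputs=[resume_btn])
    resume_btn.change(monitor, outputs=[output_box])

if __name__ == "__main__":
    demo.launch()
```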

View File

@@ -6,7 +6,6 @@ from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
from ..components.data import create_preview_box
from ..utils import gen_plot
if TYPE_CHECKING:
@@ -24,7 +23,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
dataset = gr.Dropdown(multiselect=True, scale=2, allow_custom_value=True)
dataset = gr.Dropdown(multiselect=True, scale=4, allow_custom_value=True)
preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({training_stage, dataset_dir, dataset})
@@ -121,8 +120,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as freeze_tab:
with gr.Row():
num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
name_module_trainable = gr.Textbox(value="all", scale=3)
num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1)
name_module_trainable = gr.Textbox(value="all")
input_elems.update({num_layer_trainable, name_module_trainable})
elem_dict.update(
@@ -140,8 +139,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
create_new_adapter = gr.Checkbox()
with gr.Row():
use_rslora = gr.Checkbox(scale=1)
use_dora = gr.Checkbox(scale=1)
with gr.Column(scale=1):
use_rslora = gr.Checkbox()
use_dora = gr.Checkbox()
lora_target = gr.Textbox(scale=2)
additional_target = gr.Textbox(scale=2)
@@ -175,10 +176,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as rlhf_tab:
with gr.Row():
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01)
orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
elem_dict.update(
@@ -187,11 +188,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as galore_tab:
with gr.Row():
use_galore = gr.Checkbox(scale=1)
galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2)
galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1, scale=2)
galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01, scale=2)
galore_target = gr.Textbox(value="all", scale=3)
use_galore = gr.Checkbox()
galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1)
galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1)
galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01)
galore_target = gr.Textbox(value="all")
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
elem_dict.update(
@@ -228,29 +229,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Column(scale=1):
loss_viewer = gr.Plot()
input_elems.update({output_dir, config_path})
output_elems = [output_box, process_bar]
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
arg_load_btn.click(
engine.runner.load_args,
[engine.manager.get_elem_by_id("top.lang"), config_path],
list(input_elems),
concurrency_limit=None,
)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
list_adapters,
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
[reward_model],
queue=False,
).then(autoset_packing, [training_stage], [packing], queue=False)
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
@@ -267,15 +245,27 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
)
output_box.change(
gen_plot,
[
engine.manager.get_elem_by_id("top.model_name"),
engine.manager.get_elem_by_id("top.finetuning_type"),
output_dir,
],
loss_viewer,
queue=False,
input_elems.update({output_dir, config_path})
output_elems = [output_box, process_bar, loss_viewer]
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
arg_load_btn.click(
engine.runner.load_args,
[engine.manager.get_elem_by_id("top.lang"), config_path],
list(input_elems) + [output_box],
concurrency_limit=None,
)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
list_adapters,
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
[reward_model],
queue=False,
).then(autoset_packing, [training_stage], [packing], queue=False)
return elem_dict
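One of the layout tweaks in this file drops the per-checkbox scale values and wraps the two checkboxes in a Column instead. A minimal sketch of the resulting arrangement, assuming Gradio 4.x Row/Column semantics (labels are illustrative):

```python
import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        # the two checkboxes share one scale-1 slot via the Column,
        # while each textbox keeps its own scale-2 slot in the Row
        with gr.Column(scale=1):
            use_rslora = gr.Checkbox(label="use_rslora")
            use_dora = gr.Checkbox(label="use_dora")
        lora_target = gr.Textbox(label="lora_target", scale=2)
        additional_target = gr.Textbox(label="additional_target", scale=2)
```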

View File

@@ -1344,6 +1344,11 @@ ALERTS = {
"ru": "Аргументы были сохранены по адресу: ",
"zh": "训练参数已保存至:",
},
"info_config_loaded": {
"en": "Arguments have been restored.",
"ru": "Аргументы были восстановлены.",
"zh": "训练参数已载入。",
},
"info_loading": {
"en": "Loading model...",
"ru": "Загрузка модели...",

View File

@@ -2,7 +2,7 @@ import logging
import os
import time
from threading import Thread
from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator
import gradio as gr
import transformers
@@ -17,7 +17,7 @@ from ..extras.misc import get_device_count, torch_gc
from ..train import run_exp
from .common import get_module, get_save_dir, load_args, load_config, save_args
from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, update_process_bar
from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
if TYPE_CHECKING:
@@ -239,20 +239,22 @@ class Runner:
return args
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]:
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, str], None, None]:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=True)
if error:
gr.Warning(error)
yield error, gr.Slider(visible=False)
yield {output_box: error}
else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
yield gen_cmd(args), gr.Slider(visible=False)
yield {output_box: gen_cmd(args)}
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]:
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, Any], None, None]:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=False)
if error:
gr.Warning(error)
yield error, gr.Slider(visible=False)
yield {output_box: error}
else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
@@ -261,54 +263,80 @@ class Runner:
self.thread.start()
yield from self.monitor()
def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
def preview_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
yield from self._preview(data, do_train=True)
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
yield from self._preview(data, do_train=False)
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
def run_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
yield from self._launch(data, do_train=True)
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
def run_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
yield from self._launch(data, do_train=False)
def monitor(self) -> Generator[Tuple[str, "gr.Slider"], None, None]:
def monitor(self) -> Generator[Dict[Component, Any], None, None]:
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
self.running = True
lang = get("top.lang")
output_dir = get_save_dir(
get("top.model_name"),
get("top.finetuning_type"),
get("{}.output_dir".format("train" if self.do_train else "eval")),
)
model_name = get("top.model_name")
finetuning_type = get("top.finetuning_type")
output_dir = get("{}.output_dir".format("train" if self.do_train else "eval"))
output_path = get_save_dir(model_name, finetuning_type, output_dir)
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
process_bar = self.manager.get_elem_by_id("{}.process_bar".format("train" if self.do_train else "eval"))
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
while self.thread is not None and self.thread.is_alive():
if self.aborted:
yield ALERTS["info_aborting"][lang], gr.Slider(visible=False)
yield {
output_box: ALERTS["info_aborting"][lang],
process_bar: gr.Slider(visible=False),
}
else:
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
return_dict = {
output_box: self.logger_handler.log,
process_bar: update_process_bar(self.trainer_callback),
}
if self.do_train:
plot = gen_plot(output_path)
if plot is not None:
return_dict[loss_viewer] = plot
yield return_dict
time.sleep(2)
if self.do_train:
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
finish_info = ALERTS["info_finished"][lang]
else:
finish_info = ALERTS["err_failed"][lang]
else:
if os.path.exists(os.path.join(output_dir, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
if os.path.exists(os.path.join(output_path, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]
yield self._finalize(lang, finish_info), gr.Slider(visible=False)
return_dict = {
output_box: self._finalize(lang, finish_info),
process_bar: gr.Slider(visible=False),
}
if self.do_train:
plot = gen_plot(output_path)
if plot is not None:
return_dict[loss_viewer] = plot
def save_args(self, data: Dict[Component, Any]) -> Tuple[str, "gr.Slider"]:
yield return_dict
def save_args(self, data: Dict[Component, Any]) -> Dict[Component, str]:
output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True)
if error:
gr.Warning(error)
return error, gr.Slider(visible=False)
return {output_box: error}
config_dict: Dict[str, Any] = {}
lang = data[self.manager.get_elem_by_id("top.lang")]
@@ -320,15 +348,16 @@ class Runner:
config_dict[elem_id] = value
save_path = save_args(config_path, config_dict)
return ALERTS["info_config_saved"][lang] + save_path, gr.Slider(visible=False)
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
output_box = self.manager.get_elem_by_id("train.output_box")
config_dict = load_args(config_path)
if config_dict is None:
gr.Warning(ALERTS["err_config_not_found"][lang])
return {self.manager.get_elem_by_id("top.lang"): lang}
return {output_box: ALERTS["err_config_not_found"][lang]}
output_dict: Dict["Component", Any] = {}
output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
for elem_id, value in config_dict.items():
output_dict[self.manager.get_elem_by_id(elem_id)] = value
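The signature changes in this file move the web UI callbacks from positional tuple yields to dict yields keyed by component, so a generator can update only the components it has news for (for example, attaching the loss plot only when a figure exists). A minimal standalone sketch of that Gradio 4.x pattern, with illustrative component names:

```python
import gradio as gr


def fake_monitor():
    # each yield is a dict keyed by component; components not mentioned
    # in a given yield are simply left unchanged
    for step in range(3):
        yield {output_box: "step {}".format(step), progress: gr.Slider(value=step, visible=True)}
    yield {output_box: "finished", progress: gr.Slider(visible=False)}


with gr.Blocks() as demo:
    output_box = gr.Markdown()
    progress = gr.Slider(visible=False, interactive=False)
    run_btn = gr.Button("run")
    # outputs must list every component the yielded dicts may touch
    run_btn.click(fake_monitor, outputs=[output_box, progress])

if __name__ == "__main__":
    demo.launch()
```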

View File

@@ -1,13 +1,12 @@
import json
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Optional
import gradio as gr
from ..extras.packages import is_matplotlib_available
from ..extras.ploting import smooth
from .common import get_save_dir
from .locales import ALERTS
@@ -36,7 +35,7 @@ def get_time() -> str:
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
if finetuning_type != "lora":
return gr.Dropdown(value="None", interactive=False)
return gr.Dropdown(value="none", interactive=False)
else:
return gr.Dropdown(interactive=True)
@@ -74,11 +73,9 @@ def get_eval_results(path: os.PathLike) -> str:
return "```json\n{}\n```\n".format(result)
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
if not base_model:
return
log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
if not os.path.isfile(log_file):
def gen_plot(output_path: str) -> Optional["matplotlib.figure.Figure"]:
log_file = os.path.join(output_path, "trainer_log.jsonl")
if not os.path.isfile(log_file) or not is_matplotlib_available():
return
plt.close("all")
@@ -88,13 +85,13 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplot
steps, losses = [], []
with open(log_file, "r", encoding="utf-8") as f:
for line in f:
log_info = json.loads(line)
log_info: Dict[str, Any] = json.loads(line)
if log_info.get("loss", None):
steps.append(log_info["current_steps"])
losses.append(log_info["loss"])
if len(losses) == 0:
return None
return
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
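The data gen_plot reads comes from trainer_log.jsonl, one JSON object per line; only the "current_steps" and "loss" keys matter for the plot. A hypothetical pair of records (values and the extra "total_steps" key are made up for illustration) run through the same filtering logic:

```python
import json

# two made-up log lines in the shape gen_plot expects
lines = [
    '{"current_steps": 10, "total_steps": 100, "loss": 1.42}',
    '{"current_steps": 20, "total_steps": 100, "loss": 1.18}',
]

steps, losses = [], []
for line in lines:
    log_info = json.loads(line)
    if log_info.get("loss", None):
        steps.append(log_info["current_steps"])
        losses.append(log_info["loss"])

print(steps, losses)  # [10, 20] [1.42, 1.18]
```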