add llava to llamaboard

This commit is contained in:
hiyouga 2024-04-26 06:41:35 +08:00
parent e83e2fa897
commit cd3a960f81
13 changed files with 66 additions and 15 deletions

View File

@ -1,4 +1,5 @@
#!/bin/bash
# add `--visual_inputs True` to load MLLM
CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -28,6 +28,8 @@ LOG_FILE_NAME = "trainer_log.jsonl"
METHODS = ["full", "freeze", "lora"]
MLLM_LIST = ["LLaVA1.5"]
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
PEFT_METHODS = ["lora"]
@ -566,6 +568,19 @@ register_model_group(
)
register_model_group(
models={
"LLaVA1.5-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
},
"LLaVA1.5-13B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
},
},
template="vicuna",
)
register_model_group(
models={
"Mistral-7B-v0.1": {

View File

@ -79,6 +79,7 @@ class WebChatModel(ChatModel):
template=get("top.template"),
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
infer_backend=get("infer.infer_backend"),
)

View File

@ -9,6 +9,7 @@ from ..extras.constants import (
DATA_CONFIG,
DEFAULT_MODULE,
DEFAULT_TEMPLATE,
MLLM_LIST,
PEFT_METHODS,
STAGES_USE_PAIR_DATA,
SUPPORTED_MODELS,
@ -105,6 +106,10 @@ def get_template(model_name: str) -> str:
return "default"
def get_visual(model_name: str) -> bool:
return get_prefix(model_name) in MLLM_LIST
def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
if finetuning_type not in PEFT_METHODS:
return gr.Dropdown(value=[], choices=[], interactive=False)

View File

@ -17,7 +17,7 @@ if TYPE_CHECKING:
def create_chat_box(
engine: "Engine", visible: bool = False
) -> Tuple["gr.Column", "Component", "Component", Dict[str, "Component"]]:
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
with gr.Column(visible=visible) as chat_box:
chatbot = gr.Chatbot(show_copy_button=True)
messages = gr.State([])
@ -29,7 +29,7 @@ def create_chat_box(
system = gr.Textbox(show_label=False)
tools = gr.Textbox(show_label=False, lines=4)
with gr.Column():
with gr.Column() as image_box:
image = gr.Image(type="numpy")
query = gr.Textbox(show_label=False, lines=8)
@ -55,13 +55,14 @@ def create_chat_box(
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
return (
chat_box,
chatbot,
messages,
dict(
chat_box=chat_box,
role=role,
system=system,
tools=tools,
image_box=image_box,
image=image,
query=query,
submit_btn=submit_btn,

View File

@ -27,6 +27,7 @@ def save_model(
adapter_path: List[str],
finetuning_type: str,
template: str,
visual_inputs: bool,
export_size: int,
export_quantization_bit: int,
export_quantization_dataset: str,
@ -66,6 +67,7 @@ def save_model(
adapter_name_or_path=adapter_name_or_path,
finetuning_type=finetuning_type,
template=template,
visual_inputs=visual_inputs,
export_dir=export_dir,
export_hub_model_id=export_hub_model_id or None,
export_size=export_size,
@ -105,6 +107,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
engine.manager.get_elem_by_id("top.adapter_path"),
engine.manager.get_elem_by_id("top.finetuning_type"),
engine.manager.get_elem_by_id("top.template"),
engine.manager.get_elem_by_id("top.visual_inputs"),
export_size,
export_quantization_bit,
export_quantization_dataset,

View File

@ -28,15 +28,21 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems.update({infer_backend})
elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
chat_box, chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(chat_elems)
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box]
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]
)
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
lambda: ([], []), outputs=[chatbot, messages]
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
engine.manager.get_elem_by_id("top.visual_inputs").change(
lambda enabled: gr.Column(visible=enabled),
[engine.manager.get_elem_by_id("top.visual_inputs")],
[chat_elems["image_box"]],
)
return elem_dict

View File

@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict
from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
from ..common import get_model_path, get_template, list_adapters, save_config
from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
from ..utils import can_quantize
@ -30,14 +30,17 @@ def create_top() -> Dict[str, "Component"]:
with gr.Accordion(open=False) as advanced_tab:
with gr.Row():
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
template = gr.Dropdown(choices=list(templates.keys()), value="default")
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none")
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
visual_inputs = gr.Checkbox(scale=1)
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
get_model_path, [model_name], [model_path], queue=False
).then(get_template, [model_name], [template], queue=False) # do not save config since the below line will save
).then(get_template, [model_name], [template], queue=False).then(
get_visual, [model_name], [visual_inputs], queue=False
) # do not save config since the below line will save
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
@ -59,4 +62,5 @@ def create_top() -> Dict[str, "Component"]:
template=template,
rope_scaling=rope_scaling,
booster=booster,
visual_inputs=visual_inputs,
)

View File

@ -43,6 +43,7 @@ class Engine:
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
init_dict["infer.image_box"] = {"visible": False}
if user_config.get("last_model", None):
init_dict["top.model_name"] = {"value": user_config["last_model"]}

View File

@ -58,8 +58,8 @@ def create_web_demo() -> gr.Blocks:
lang = gr.Dropdown(choices=["en", "zh"])
engine.manager.add_elems("top", dict(lang=lang))
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
engine.manager.add_elems("infer", dict(chat_box=chat_box, **chat_elems))
_, _, chat_elems = create_chat_box(engine, visible=True)
engine.manager.add_elems("infer", chat_elems)
demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)

View File

@ -129,6 +129,17 @@ LOCALES = {
"label": "加速方式",
},
},
"visual_inputs": {
"en": {
"label": "Visual inputs",
},
"ru": {
"label": "визуальные входы",
},
"zh": {
"label": "图像输入",
},
},
"training_stage": {
"en": {
"label": "Stage",

View File

@ -60,4 +60,5 @@ class Manager:
self._id_to_elem["top.template"],
self._id_to_elem["top.rope_scaling"],
self._id_to_elem["top.booster"],
self._id_to_elem["top.visual_inputs"],
}

View File

@ -124,6 +124,7 @@ class Runner:
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
dataset_dir=get("train.dataset_dir"),
dataset=",".join(get("train.dataset")),
cutoff_len=get("train.cutoff_len"),
@ -224,6 +225,7 @@ class Runner:
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
dataset_dir=get("eval.dataset_dir"),
dataset=",".join(get("eval.dataset")),
cutoff_len=get("eval.cutoff_len"),