add llava to llamaboard
This commit is contained in:
parent
e83e2fa897
commit
cd3a960f81
|
@ -1,4 +1,5 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
# add `--visual_inputs True` to load MLLM
|
||||||
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
|
CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
|
||||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
|
|
@ -28,6 +28,8 @@ LOG_FILE_NAME = "trainer_log.jsonl"
|
||||||
|
|
||||||
METHODS = ["full", "freeze", "lora"]
|
METHODS = ["full", "freeze", "lora"]
|
||||||
|
|
||||||
|
MLLM_LIST = ["LLaVA1.5"]
|
||||||
|
|
||||||
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
|
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
|
||||||
|
|
||||||
PEFT_METHODS = ["lora"]
|
PEFT_METHODS = ["lora"]
|
||||||
|
@ -566,6 +568,19 @@ register_model_group(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"LLaVA1.5-7B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
|
||||||
|
},
|
||||||
|
"LLaVA1.5-13B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="vicuna",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Mistral-7B-v0.1": {
|
"Mistral-7B-v0.1": {
|
||||||
|
|
|
@ -79,6 +79,7 @@ class WebChatModel(ChatModel):
|
||||||
template=get("top.template"),
|
template=get("top.template"),
|
||||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||||
use_unsloth=(get("top.booster") == "unsloth"),
|
use_unsloth=(get("top.booster") == "unsloth"),
|
||||||
|
visual_inputs=get("top.visual_inputs"),
|
||||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||||
infer_backend=get("infer.infer_backend"),
|
infer_backend=get("infer.infer_backend"),
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,6 +9,7 @@ from ..extras.constants import (
|
||||||
DATA_CONFIG,
|
DATA_CONFIG,
|
||||||
DEFAULT_MODULE,
|
DEFAULT_MODULE,
|
||||||
DEFAULT_TEMPLATE,
|
DEFAULT_TEMPLATE,
|
||||||
|
MLLM_LIST,
|
||||||
PEFT_METHODS,
|
PEFT_METHODS,
|
||||||
STAGES_USE_PAIR_DATA,
|
STAGES_USE_PAIR_DATA,
|
||||||
SUPPORTED_MODELS,
|
SUPPORTED_MODELS,
|
||||||
|
@ -105,6 +106,10 @@ def get_template(model_name: str) -> str:
|
||||||
return "default"
|
return "default"
|
||||||
|
|
||||||
|
|
||||||
|
def get_visual(model_name: str) -> bool:
|
||||||
|
return get_prefix(model_name) in MLLM_LIST
|
||||||
|
|
||||||
|
|
||||||
def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
|
def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
|
||||||
if finetuning_type not in PEFT_METHODS:
|
if finetuning_type not in PEFT_METHODS:
|
||||||
return gr.Dropdown(value=[], choices=[], interactive=False)
|
return gr.Dropdown(value=[], choices=[], interactive=False)
|
||||||
|
|
|
@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
def create_chat_box(
|
def create_chat_box(
|
||||||
engine: "Engine", visible: bool = False
|
engine: "Engine", visible: bool = False
|
||||||
) -> Tuple["gr.Column", "Component", "Component", Dict[str, "Component"]]:
|
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
|
||||||
with gr.Column(visible=visible) as chat_box:
|
with gr.Column(visible=visible) as chat_box:
|
||||||
chatbot = gr.Chatbot(show_copy_button=True)
|
chatbot = gr.Chatbot(show_copy_button=True)
|
||||||
messages = gr.State([])
|
messages = gr.State([])
|
||||||
|
@ -29,7 +29,7 @@ def create_chat_box(
|
||||||
system = gr.Textbox(show_label=False)
|
system = gr.Textbox(show_label=False)
|
||||||
tools = gr.Textbox(show_label=False, lines=4)
|
tools = gr.Textbox(show_label=False, lines=4)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column() as image_box:
|
||||||
image = gr.Image(type="numpy")
|
image = gr.Image(type="numpy")
|
||||||
|
|
||||||
query = gr.Textbox(show_label=False, lines=8)
|
query = gr.Textbox(show_label=False, lines=8)
|
||||||
|
@ -55,13 +55,14 @@ def create_chat_box(
|
||||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
|
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
chat_box,
|
|
||||||
chatbot,
|
chatbot,
|
||||||
messages,
|
messages,
|
||||||
dict(
|
dict(
|
||||||
|
chat_box=chat_box,
|
||||||
role=role,
|
role=role,
|
||||||
system=system,
|
system=system,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
image_box=image_box,
|
||||||
image=image,
|
image=image,
|
||||||
query=query,
|
query=query,
|
||||||
submit_btn=submit_btn,
|
submit_btn=submit_btn,
|
||||||
|
|
|
@ -27,6 +27,7 @@ def save_model(
|
||||||
adapter_path: List[str],
|
adapter_path: List[str],
|
||||||
finetuning_type: str,
|
finetuning_type: str,
|
||||||
template: str,
|
template: str,
|
||||||
|
visual_inputs: bool,
|
||||||
export_size: int,
|
export_size: int,
|
||||||
export_quantization_bit: int,
|
export_quantization_bit: int,
|
||||||
export_quantization_dataset: str,
|
export_quantization_dataset: str,
|
||||||
|
@ -66,6 +67,7 @@ def save_model(
|
||||||
adapter_name_or_path=adapter_name_or_path,
|
adapter_name_or_path=adapter_name_or_path,
|
||||||
finetuning_type=finetuning_type,
|
finetuning_type=finetuning_type,
|
||||||
template=template,
|
template=template,
|
||||||
|
visual_inputs=visual_inputs,
|
||||||
export_dir=export_dir,
|
export_dir=export_dir,
|
||||||
export_hub_model_id=export_hub_model_id or None,
|
export_hub_model_id=export_hub_model_id or None,
|
||||||
export_size=export_size,
|
export_size=export_size,
|
||||||
|
@ -105,6 +107,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
engine.manager.get_elem_by_id("top.adapter_path"),
|
engine.manager.get_elem_by_id("top.adapter_path"),
|
||||||
engine.manager.get_elem_by_id("top.finetuning_type"),
|
engine.manager.get_elem_by_id("top.finetuning_type"),
|
||||||
engine.manager.get_elem_by_id("top.template"),
|
engine.manager.get_elem_by_id("top.template"),
|
||||||
|
engine.manager.get_elem_by_id("top.visual_inputs"),
|
||||||
export_size,
|
export_size,
|
||||||
export_quantization_bit,
|
export_quantization_bit,
|
||||||
export_quantization_dataset,
|
export_quantization_dataset,
|
||||||
|
|
|
@ -28,15 +28,21 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
input_elems.update({infer_backend})
|
input_elems.update({infer_backend})
|
||||||
elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
|
elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
|
||||||
|
|
||||||
chat_box, chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
|
chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
|
||||||
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
|
elem_dict.update(chat_elems)
|
||||||
|
|
||||||
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
|
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
|
||||||
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box]
|
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]
|
||||||
)
|
)
|
||||||
|
|
||||||
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
|
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
|
||||||
lambda: ([], []), outputs=[chatbot, messages]
|
lambda: ([], []), outputs=[chatbot, messages]
|
||||||
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])
|
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
|
||||||
|
|
||||||
|
engine.manager.get_elem_by_id("top.visual_inputs").change(
|
||||||
|
lambda enabled: gr.Column(visible=enabled),
|
||||||
|
[engine.manager.get_elem_by_id("top.visual_inputs")],
|
||||||
|
[chat_elems["image_box"]],
|
||||||
|
)
|
||||||
|
|
||||||
return elem_dict
|
return elem_dict
|
||||||
|
|
|
@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict
|
||||||
from ...data import templates
|
from ...data import templates
|
||||||
from ...extras.constants import METHODS, SUPPORTED_MODELS
|
from ...extras.constants import METHODS, SUPPORTED_MODELS
|
||||||
from ...extras.packages import is_gradio_available
|
from ...extras.packages import is_gradio_available
|
||||||
from ..common import get_model_path, get_template, list_adapters, save_config
|
from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
|
||||||
from ..utils import can_quantize
|
from ..utils import can_quantize
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,14 +30,17 @@ def create_top() -> Dict[str, "Component"]:
|
||||||
|
|
||||||
with gr.Accordion(open=False) as advanced_tab:
|
with gr.Accordion(open=False) as advanced_tab:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
|
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
|
||||||
template = gr.Dropdown(choices=list(templates.keys()), value="default")
|
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2)
|
||||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
|
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
|
||||||
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none")
|
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
|
||||||
|
visual_inputs = gr.Checkbox(scale=1)
|
||||||
|
|
||||||
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
|
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
|
||||||
get_model_path, [model_name], [model_path], queue=False
|
get_model_path, [model_name], [model_path], queue=False
|
||||||
).then(get_template, [model_name], [template], queue=False) # do not save config since the below line will save
|
).then(get_template, [model_name], [template], queue=False).then(
|
||||||
|
get_visual, [model_name], [visual_inputs], queue=False
|
||||||
|
) # do not save config since the below line will save
|
||||||
|
|
||||||
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
||||||
|
|
||||||
|
@ -59,4 +62,5 @@ def create_top() -> Dict[str, "Component"]:
|
||||||
template=template,
|
template=template,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
booster=booster,
|
booster=booster,
|
||||||
|
visual_inputs=visual_inputs,
|
||||||
)
|
)
|
||||||
|
|
|
@ -43,6 +43,7 @@ class Engine:
|
||||||
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
|
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
|
||||||
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
|
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
|
||||||
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
|
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
|
||||||
|
init_dict["infer.image_box"] = {"visible": False}
|
||||||
|
|
||||||
if user_config.get("last_model", None):
|
if user_config.get("last_model", None):
|
||||||
init_dict["top.model_name"] = {"value": user_config["last_model"]}
|
init_dict["top.model_name"] = {"value": user_config["last_model"]}
|
||||||
|
|
|
@ -58,8 +58,8 @@ def create_web_demo() -> gr.Blocks:
|
||||||
lang = gr.Dropdown(choices=["en", "zh"])
|
lang = gr.Dropdown(choices=["en", "zh"])
|
||||||
engine.manager.add_elems("top", dict(lang=lang))
|
engine.manager.add_elems("top", dict(lang=lang))
|
||||||
|
|
||||||
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
|
_, _, chat_elems = create_chat_box(engine, visible=True)
|
||||||
engine.manager.add_elems("infer", dict(chat_box=chat_box, **chat_elems))
|
engine.manager.add_elems("infer", chat_elems)
|
||||||
|
|
||||||
demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
|
demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
|
||||||
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
|
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
|
||||||
|
|
|
@ -129,6 +129,17 @@ LOCALES = {
|
||||||
"label": "加速方式",
|
"label": "加速方式",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"visual_inputs": {
|
||||||
|
"en": {
|
||||||
|
"label": "Visual inputs",
|
||||||
|
},
|
||||||
|
"ru": {
|
||||||
|
"label": "визуальные входы",
|
||||||
|
},
|
||||||
|
"zh": {
|
||||||
|
"label": "图像输入",
|
||||||
|
},
|
||||||
|
},
|
||||||
"training_stage": {
|
"training_stage": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Stage",
|
"label": "Stage",
|
||||||
|
|
|
@ -60,4 +60,5 @@ class Manager:
|
||||||
self._id_to_elem["top.template"],
|
self._id_to_elem["top.template"],
|
||||||
self._id_to_elem["top.rope_scaling"],
|
self._id_to_elem["top.rope_scaling"],
|
||||||
self._id_to_elem["top.booster"],
|
self._id_to_elem["top.booster"],
|
||||||
|
self._id_to_elem["top.visual_inputs"],
|
||||||
}
|
}
|
||||||
|
|
|
@ -124,6 +124,7 @@ class Runner:
|
||||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||||
use_unsloth=(get("top.booster") == "unsloth"),
|
use_unsloth=(get("top.booster") == "unsloth"),
|
||||||
|
visual_inputs=get("top.visual_inputs"),
|
||||||
dataset_dir=get("train.dataset_dir"),
|
dataset_dir=get("train.dataset_dir"),
|
||||||
dataset=",".join(get("train.dataset")),
|
dataset=",".join(get("train.dataset")),
|
||||||
cutoff_len=get("train.cutoff_len"),
|
cutoff_len=get("train.cutoff_len"),
|
||||||
|
@ -224,6 +225,7 @@ class Runner:
|
||||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||||
use_unsloth=(get("top.booster") == "unsloth"),
|
use_unsloth=(get("top.booster") == "unsloth"),
|
||||||
|
visual_inputs=get("top.visual_inputs"),
|
||||||
dataset_dir=get("eval.dataset_dir"),
|
dataset_dir=get("eval.dataset_dir"),
|
||||||
dataset=",".join(get("eval.dataset")),
|
dataset=",".join(get("eval.dataset")),
|
||||||
cutoff_len=get("eval.cutoff_len"),
|
cutoff_len=get("eval.cutoff_len"),
|
||||||
|
|
Loading…
Reference in New Issue