add llava to llamaboard
This commit is contained in:
parent
e83e2fa897
commit
cd3a960f81
|
@ -1,4 +1,5 @@
|
|||
#!/bin/bash
|
||||
# add `--visual_inputs True` to load MLLM
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -28,6 +28,8 @@ LOG_FILE_NAME = "trainer_log.jsonl"
|
|||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
|
||||
MLLM_LIST = ["LLaVA1.5"]
|
||||
|
||||
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
|
||||
|
||||
PEFT_METHODS = ["lora"]
|
||||
|
@ -566,6 +568,19 @@ register_model_group(
|
|||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaVA1.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
|
||||
},
|
||||
"LLaVA1.5-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
|
||||
},
|
||||
},
|
||||
template="vicuna",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Mistral-7B-v0.1": {
|
||||
|
|
|
@ -79,6 +79,7 @@ class WebChatModel(ChatModel):
|
|||
template=get("top.template"),
|
||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||
use_unsloth=(get("top.booster") == "unsloth"),
|
||||
visual_inputs=get("top.visual_inputs"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
infer_backend=get("infer.infer_backend"),
|
||||
)
|
||||
|
|
|
@ -9,6 +9,7 @@ from ..extras.constants import (
|
|||
DATA_CONFIG,
|
||||
DEFAULT_MODULE,
|
||||
DEFAULT_TEMPLATE,
|
||||
MLLM_LIST,
|
||||
PEFT_METHODS,
|
||||
STAGES_USE_PAIR_DATA,
|
||||
SUPPORTED_MODELS,
|
||||
|
@ -105,6 +106,10 @@ def get_template(model_name: str) -> str:
|
|||
return "default"
|
||||
|
||||
|
||||
def get_visual(model_name: str) -> bool:
|
||||
return get_prefix(model_name) in MLLM_LIST
|
||||
|
||||
|
||||
def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
|
||||
if finetuning_type not in PEFT_METHODS:
|
||||
return gr.Dropdown(value=[], choices=[], interactive=False)
|
||||
|
|
|
@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
|||
|
||||
def create_chat_box(
|
||||
engine: "Engine", visible: bool = False
|
||||
) -> Tuple["gr.Column", "Component", "Component", Dict[str, "Component"]]:
|
||||
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
|
||||
with gr.Column(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot(show_copy_button=True)
|
||||
messages = gr.State([])
|
||||
|
@ -29,7 +29,7 @@ def create_chat_box(
|
|||
system = gr.Textbox(show_label=False)
|
||||
tools = gr.Textbox(show_label=False, lines=4)
|
||||
|
||||
with gr.Column():
|
||||
with gr.Column() as image_box:
|
||||
image = gr.Image(type="numpy")
|
||||
|
||||
query = gr.Textbox(show_label=False, lines=8)
|
||||
|
@ -55,13 +55,14 @@ def create_chat_box(
|
|||
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
|
||||
|
||||
return (
|
||||
chat_box,
|
||||
chatbot,
|
||||
messages,
|
||||
dict(
|
||||
chat_box=chat_box,
|
||||
role=role,
|
||||
system=system,
|
||||
tools=tools,
|
||||
image_box=image_box,
|
||||
image=image,
|
||||
query=query,
|
||||
submit_btn=submit_btn,
|
||||
|
|
|
@ -27,6 +27,7 @@ def save_model(
|
|||
adapter_path: List[str],
|
||||
finetuning_type: str,
|
||||
template: str,
|
||||
visual_inputs: bool,
|
||||
export_size: int,
|
||||
export_quantization_bit: int,
|
||||
export_quantization_dataset: str,
|
||||
|
@ -66,6 +67,7 @@ def save_model(
|
|||
adapter_name_or_path=adapter_name_or_path,
|
||||
finetuning_type=finetuning_type,
|
||||
template=template,
|
||||
visual_inputs=visual_inputs,
|
||||
export_dir=export_dir,
|
||||
export_hub_model_id=export_hub_model_id or None,
|
||||
export_size=export_size,
|
||||
|
@ -105,6 +107,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
engine.manager.get_elem_by_id("top.adapter_path"),
|
||||
engine.manager.get_elem_by_id("top.finetuning_type"),
|
||||
engine.manager.get_elem_by_id("top.template"),
|
||||
engine.manager.get_elem_by_id("top.visual_inputs"),
|
||||
export_size,
|
||||
export_quantization_bit,
|
||||
export_quantization_dataset,
|
||||
|
|
|
@ -28,15 +28,21 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
input_elems.update({infer_backend})
|
||||
elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
|
||||
|
||||
chat_box, chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
|
||||
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
|
||||
chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
|
||||
elem_dict.update(chat_elems)
|
||||
|
||||
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
|
||||
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box]
|
||||
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]
|
||||
)
|
||||
|
||||
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
|
||||
lambda: ([], []), outputs=[chatbot, messages]
|
||||
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])
|
||||
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
|
||||
|
||||
engine.manager.get_elem_by_id("top.visual_inputs").change(
|
||||
lambda enabled: gr.Column(visible=enabled),
|
||||
[engine.manager.get_elem_by_id("top.visual_inputs")],
|
||||
[chat_elems["image_box"]],
|
||||
)
|
||||
|
||||
return elem_dict
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict
|
|||
from ...data import templates
|
||||
from ...extras.constants import METHODS, SUPPORTED_MODELS
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import get_model_path, get_template, list_adapters, save_config
|
||||
from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
|
||||
from ..utils import can_quantize
|
||||
|
||||
|
||||
|
@ -30,14 +30,17 @@ def create_top() -> Dict[str, "Component"]:
|
|||
|
||||
with gr.Accordion(open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default")
|
||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
|
||||
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none")
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2)
|
||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
|
||||
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
|
||||
visual_inputs = gr.Checkbox(scale=1)
|
||||
|
||||
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
|
||||
get_model_path, [model_name], [model_path], queue=False
|
||||
).then(get_template, [model_name], [template], queue=False) # do not save config since the below line will save
|
||||
).then(get_template, [model_name], [template], queue=False).then(
|
||||
get_visual, [model_name], [visual_inputs], queue=False
|
||||
) # do not save config since the below line will save
|
||||
|
||||
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
||||
|
||||
|
@ -59,4 +62,5 @@ def create_top() -> Dict[str, "Component"]:
|
|||
template=template,
|
||||
rope_scaling=rope_scaling,
|
||||
booster=booster,
|
||||
visual_inputs=visual_inputs,
|
||||
)
|
||||
|
|
|
@ -43,6 +43,7 @@ class Engine:
|
|||
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
|
||||
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
|
||||
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
|
||||
init_dict["infer.image_box"] = {"visible": False}
|
||||
|
||||
if user_config.get("last_model", None):
|
||||
init_dict["top.model_name"] = {"value": user_config["last_model"]}
|
||||
|
|
|
@ -58,8 +58,8 @@ def create_web_demo() -> gr.Blocks:
|
|||
lang = gr.Dropdown(choices=["en", "zh"])
|
||||
engine.manager.add_elems("top", dict(lang=lang))
|
||||
|
||||
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
|
||||
engine.manager.add_elems("infer", dict(chat_box=chat_box, **chat_elems))
|
||||
_, _, chat_elems = create_chat_box(engine, visible=True)
|
||||
engine.manager.add_elems("infer", chat_elems)
|
||||
|
||||
demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
|
||||
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
|
||||
|
|
|
@ -129,6 +129,17 @@ LOCALES = {
|
|||
"label": "加速方式",
|
||||
},
|
||||
},
|
||||
"visual_inputs": {
|
||||
"en": {
|
||||
"label": "Visual inputs",
|
||||
},
|
||||
"ru": {
|
||||
"label": "визуальные входы",
|
||||
},
|
||||
"zh": {
|
||||
"label": "图像输入",
|
||||
},
|
||||
},
|
||||
"training_stage": {
|
||||
"en": {
|
||||
"label": "Stage",
|
||||
|
|
|
@ -60,4 +60,5 @@ class Manager:
|
|||
self._id_to_elem["top.template"],
|
||||
self._id_to_elem["top.rope_scaling"],
|
||||
self._id_to_elem["top.booster"],
|
||||
self._id_to_elem["top.visual_inputs"],
|
||||
}
|
||||
|
|
|
@ -124,6 +124,7 @@ class Runner:
|
|||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||
use_unsloth=(get("top.booster") == "unsloth"),
|
||||
visual_inputs=get("top.visual_inputs"),
|
||||
dataset_dir=get("train.dataset_dir"),
|
||||
dataset=",".join(get("train.dataset")),
|
||||
cutoff_len=get("train.cutoff_len"),
|
||||
|
@ -224,6 +225,7 @@ class Runner:
|
|||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||
use_unsloth=(get("top.booster") == "unsloth"),
|
||||
visual_inputs=get("top.visual_inputs"),
|
||||
dataset_dir=get("eval.dataset_dir"),
|
||||
dataset=",".join(get("eval.dataset")),
|
||||
cutoff_len=get("eval.cutoff_len"),
|
||||
|
|
Loading…
Reference in New Issue