llama board: add unsloth
commit 9a18a85639
parent 7aad0b889d
@@ -77,8 +77,8 @@ class WebChatModel(ChatModel):
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
-            flash_attn=get("top.flash_attn"),
-            shift_attn=get("top.shift_attn"),
+            flash_attn=(get("top.booster") == "flash_attn"),
+            use_unsloth=(get("top.booster") == "unsloth"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
         )
         super().__init__(args)
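The hunk above (and the matching Runner hunks further down) reduces the single booster radio value to the two keyword arguments flash_attn and use_unsloth. A minimal, standalone sketch of that reduction, assuming only the choices shown in the diff; the helper name booster_to_args and the returned dict are illustrative, not part of the project:

from typing import Any, Dict

def booster_to_args(booster: str) -> Dict[str, Any]:
    """Translate the webui's booster radio value into model-loading flags.

    Mirrors the diff: "flash_attn" enables FlashAttention-2, "unsloth" enables
    the unsloth backend, and "none" leaves both flags off.
    """
    return {
        "flash_attn": booster == "flash_attn",
        "use_unsloth": booster == "unsloth",
    }

# Usage check of the mapping shown in the diff.
assert booster_to_args("unsloth") == {"flash_attn": False, "use_unsloth": True}
assert booster_to_args("none") == {"flash_attn": False, "use_unsloth": False}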
@@ -28,10 +28,7 @@ def create_top() -> Dict[str, "Component"]:
         quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
         template = gr.Dropdown(choices=list(templates.keys()), value="default")
         rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
-
-        with gr.Column():
-            flash_attn = gr.Checkbox(value=False)
-            shift_attn = gr.Checkbox(value=False)
+        booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none")

     model_name.change(
         list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
@@ -64,6 +61,5 @@ def create_top() -> Dict[str, "Component"]:
         quantization_bit=quantization_bit,
         template=template,
         rope_scaling=rope_scaling,
-        flash_attn=flash_attn,
-        shift_attn=shift_attn
+        booster=booster
     )
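The two create_top hunks above fold the separate flash_attn / shift_attn checkboxes into one booster radio. A self-contained Gradio sketch of that kind of control wired to a change callback; the demo layout, labels, and describe helper are illustrative and not the project's UI:

import gradio as gr

def describe(booster: str) -> str:
    # Same reduction as in the diff: one radio value becomes two booleans.
    flash_attn = booster == "flash_attn"
    use_unsloth = booster == "unsloth"
    return f"flash_attn={flash_attn}, use_unsloth={use_unsloth}"

with gr.Blocks() as demo:
    booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none", label="Booster")
    status = gr.Textbox(label="Resolved flags")
    booster.change(describe, [booster], [status])

if __name__ == "__main__":
    demo.launch()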
@@ -85,20 +85,12 @@ LOCALES = {
             "label": "RoPE 插值方法"
         }
     },
-    "flash_attn": {
+    "booster": {
         "en": {
-            "label": "Use FlashAttention-2"
+            "label": "Booster"
         },
         "zh": {
-            "label": "使用 FlashAttention-2"
-        }
-    },
-    "shift_attn": {
-        "en": {
-            "label": "Use shift short attention (S^2-Attn)"
-        },
-        "zh": {
-            "label": "使用 shift short attention (S^2-Attn)"
+            "label": "加速方式"
         }
     },
     "training_stage": {
@@ -25,9 +25,8 @@ class Manager:
             self.all_elems["top"]["finetuning_type"],
             self.all_elems["top"]["quantization_bit"],
             self.all_elems["top"]["template"],
-            self.all_elems["top"]["flash_attn"],
-            self.all_elems["top"]["shift_attn"],
-            self.all_elems["top"]["rope_scaling"]
+            self.all_elems["top"]["rope_scaling"],
+            self.all_elems["top"]["booster"]
         }

     def list_elems(self) -> List["Component"]:
@@ -102,9 +102,9 @@ class Runner:
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
-            flash_attn=get("top.flash_attn"),
-            shift_attn=get("top.shift_attn"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            flash_attn=(get("top.booster") == "flash_attn"),
+            use_unsloth=(get("top.booster") == "unsloth"),
             dataset_dir=get("train.dataset_dir"),
             dataset=",".join(get("train.dataset")),
             cutoff_len=get("train.cutoff_len"),
@@ -171,9 +171,9 @@ class Runner:
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
-            flash_attn=get("top.flash_attn"),
-            shift_attn=get("top.shift_attn"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            flash_attn=(get("top.booster") == "flash_attn"),
+            use_unsloth=(get("top.booster") == "unsloth"),
             dataset_dir=get("eval.dataset_dir"),
             dataset=",".join(get("eval.dataset")),
             cutoff_len=get("eval.cutoff_len"),