llama board: add unsloth

hiyouga 2023-12-23 00:35:53 +08:00
parent 7aad0b889d
commit 9a18a85639
5 changed files with 13 additions and 26 deletions

View File

@@ -77,8 +77,8 @@ class WebChatModel(ChatModel):
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
-            flash_attn=get("top.flash_attn"),
-            shift_attn=get("top.shift_attn"),
+            flash_attn=(get("top.booster") == "flash_attn"),
+            use_unsloth=(get("top.booster") == "unsloth"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
         )
         super().__init__(args)
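Both here and in Runner further below, the single booster radio value is fanned out into two boolean kwargs. A minimal standalone sketch of that mapping (the booster_kwargs helper is hypothetical, not part of this commit):

    def booster_kwargs(booster: str) -> dict:
        # "none" enables neither flag; the radio makes the options mutually exclusive
        return {
            "flash_attn": booster == "flash_attn",
            "use_unsloth": booster == "unsloth",
        }

    assert booster_kwargs("unsloth") == {"flash_attn": False, "use_unsloth": True}
    assert booster_kwargs("none") == {"flash_attn": False, "use_unsloth": False}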

View File

@@ -28,10 +28,7 @@ def create_top() -> Dict[str, "Component"]:
         quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
         template = gr.Dropdown(choices=list(templates.keys()), value="default")
         rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
-
-        with gr.Column():
-            flash_attn = gr.Checkbox(value=False)
-            shift_attn = gr.Checkbox(value=False)
+        booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none")
 
     model_name.change(
         list_adapters, [model_name, finetuning_type], [adapter_path], queue=False
@@ -64,6 +61,5 @@ def create_top() -> Dict[str, "Component"]:
         quantization_bit=quantization_bit,
         template=template,
         rope_scaling=rope_scaling,
-        flash_attn=flash_attn,
-        shift_attn=shift_attn
+        booster=booster
     )
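A runnable sketch of the UI change above, assuming only gradio (the Blocks wrapper and the label are illustrative; the label text comes from the locales change below):

    import gradio as gr

    with gr.Blocks() as demo:
        # One mutually exclusive selector replaces the two independent
        # checkboxes, so flash_attn and unsloth cannot be picked together.
        booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none", label="Booster")

    demo.launch()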

View File

@@ -85,20 +85,12 @@ LOCALES = {
             "label": "RoPE 插值方法"
         }
     },
-    "flash_attn": {
+    "booster": {
         "en": {
-            "label": "Use FlashAttention-2"
+            "label": "Booster"
         },
         "zh": {
-            "label": "使用 FlashAttention-2"
-        }
-    },
-    "shift_attn": {
-        "en": {
-            "label": "Use shift short attention (S^2-Attn)"
-        },
-        "zh": {
-            "label": "使用 shift short attention (S^2-Attn)"
+            "label": "加速方式"
         }
     },
     "training_stage": {

View File

@@ -25,9 +25,8 @@ class Manager:
             self.all_elems["top"]["finetuning_type"],
             self.all_elems["top"]["quantization_bit"],
             self.all_elems["top"]["template"],
-            self.all_elems["top"]["flash_attn"],
-            self.all_elems["top"]["shift_attn"],
-            self.all_elems["top"]["rope_scaling"]
+            self.all_elems["top"]["rope_scaling"],
+            self.all_elems["top"]["booster"]
         }
 
     def list_elems(self) -> List["Component"]:

View File

@@ -102,9 +102,9 @@ class Runner:
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
-            flash_attn=get("top.flash_attn"),
-            shift_attn=get("top.shift_attn"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            flash_attn=(get("top.booster") == "flash_attn"),
+            use_unsloth=(get("top.booster") == "unsloth"),
             dataset_dir=get("train.dataset_dir"),
             dataset=",".join(get("train.dataset")),
             cutoff_len=get("train.cutoff_len"),
@@ -171,9 +171,9 @@ class Runner:
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
-            flash_attn=get("top.flash_attn"),
-            shift_attn=get("top.shift_attn"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            flash_attn=(get("top.booster") == "flash_attn"),
+            use_unsloth=(get("top.booster") == "unsloth"),
             dataset_dir=get("eval.dataset_dir"),
             dataset=",".join(get("eval.dataset")),
             cutoff_len=get("eval.cutoff_len"),
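The training builder (first hunk) and the evaluation builder (second hunk) derive flash_attn and use_unsloth identically from top.booster, so the selected booster applies consistently to both run types; the booster_kwargs sketch above captures the shared mapping.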