diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py index 3c5c8e18..08027e38 100644 --- a/src/llmtuner/webui/chatter.py +++ b/src/llmtuner/webui/chatter.py @@ -77,8 +77,8 @@ class WebChatModel(ChatModel): finetuning_type=get("top.finetuning_type"), quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, template=get("top.template"), - flash_attn=get("top.flash_attn"), - shift_attn=get("top.shift_attn"), + flash_attn=(get("top.booster") == "flash_attn"), + use_unsloth=(get("top.booster") == "unsloth"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None ) super().__init__(args) diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py index df78f2f5..74441ab2 100644 --- a/src/llmtuner/webui/components/top.py +++ b/src/llmtuner/webui/components/top.py @@ -28,10 +28,7 @@ def create_top() -> Dict[str, "Component"]: quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none") template = gr.Dropdown(choices=list(templates.keys()), value="default") rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none") - - with gr.Column(): - flash_attn = gr.Checkbox(value=False) - shift_attn = gr.Checkbox(value=False) + booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none") model_name.change( list_adapters, [model_name, finetuning_type], [adapter_path], queue=False @@ -64,6 +61,5 @@ def create_top() -> Dict[str, "Component"]: quantization_bit=quantization_bit, template=template, rope_scaling=rope_scaling, - flash_attn=flash_attn, - shift_attn=shift_attn + booster=booster ) diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py index 888ed4a9..3d01a0f2 100644 --- a/src/llmtuner/webui/locales.py +++ b/src/llmtuner/webui/locales.py @@ -85,20 +85,12 @@ LOCALES = { "label": "RoPE 插值方法" } }, - "flash_attn": { + "booster": { "en": { - "label": "Use FlashAttention-2" + "label": "Booster" }, "zh": { - "label": "使用 FlashAttention-2" - } - }, - "shift_attn": { - "en": { - "label": "Use shift short attention (S^2-Attn)" - }, - "zh": { - "label": "使用 shift short attention (S^2-Attn)" + "label": "加速方式" } }, "training_stage": { diff --git a/src/llmtuner/webui/manager.py b/src/llmtuner/webui/manager.py index 833afceb..18767f0d 100644 --- a/src/llmtuner/webui/manager.py +++ b/src/llmtuner/webui/manager.py @@ -25,9 +25,8 @@ class Manager: self.all_elems["top"]["finetuning_type"], self.all_elems["top"]["quantization_bit"], self.all_elems["top"]["template"], - self.all_elems["top"]["flash_attn"], - self.all_elems["top"]["shift_attn"], - self.all_elems["top"]["rope_scaling"] + self.all_elems["top"]["rope_scaling"], + self.all_elems["top"]["booster"] } def list_elems(self) -> List["Component"]: diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 2e96e504..483d709d 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -102,9 +102,9 @@ class Runner: finetuning_type=get("top.finetuning_type"), quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, template=get("top.template"), - flash_attn=get("top.flash_attn"), - shift_attn=get("top.shift_attn"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, + flash_attn=(get("top.booster") == "flash_attn"), + use_unsloth=(get("top.booster") == "unsloth"), dataset_dir=get("train.dataset_dir"), dataset=",".join(get("train.dataset")), cutoff_len=get("train.cutoff_len"), @@ -171,9 +171,9 @@ class Runner: finetuning_type=get("top.finetuning_type"), quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, template=get("top.template"), - flash_attn=get("top.flash_attn"), - shift_attn=get("top.shift_attn"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, + flash_attn=(get("top.booster") == "flash_attn"), + use_unsloth=(get("top.booster") == "unsloth"), dataset_dir=get("eval.dataset_dir"), dataset=",".join(get("eval.dataset")), cutoff_len=get("eval.cutoff_len"),