diff --git a/requirements.txt b/requirements.txt
index 1d36fd33..2eb25a2c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 torch>=1.13.1
-transformers>=4.30.0
+transformers>=4.31.0
 datasets>=2.12.0
 accelerate>=0.21.0
 peft>=0.4.0
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index fa0abb93..0570b33c 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-check_min_version("4.30.0")
+check_min_version("4.31.0")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
diff --git a/src/llmtuner/webui/chat.py b/src/llmtuner/webui/chat.py
index 1499c171..b6f3774e 100644
--- a/src/llmtuner/webui/chat.py
+++ b/src/llmtuner/webui/chat.py
@@ -26,7 +26,10 @@ class WebChatModel(ChatModel):
         finetuning_type: str,
         quantization_bit: str,
         template: str,
-        system_prompt: str
+        system_prompt: str,
+        flash_attn: bool,
+        shift_attn: bool,
+        rope_scaling: str
     ):
         if self.model is not None:
             yield ALERTS["err_exists"][lang]
@@ -42,9 +45,7 @@ class WebChatModel(ChatModel):
             return
 
         if checkpoints:
-            checkpoint_dir = ",".join(
-                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
-            )
+            checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
         else:
             checkpoint_dir = None
 
@@ -55,7 +56,10 @@ class WebChatModel(ChatModel):
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit and quantization_bit != "None" else None,
             template=template,
-            system_prompt=system_prompt
+            system_prompt=system_prompt,
+            flash_attn=flash_attn,
+            shift_attn=shift_attn,
+            rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None
         )
 
         super().__init__(args)
diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py
index 95d7b613..ef8d2adc 100644
--- a/src/llmtuner/webui/common.py
+++ b/src/llmtuner/webui/common.py
@@ -3,8 +3,14 @@ import os
 from typing import Any, Dict, Optional
 
 import gradio as gr
-from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
-from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from transformers.utils import (
+    WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    ADAPTER_WEIGHTS_NAME,
+    ADAPTER_SAFE_WEIGHTS_NAME
+)
 
 from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
 
@@ -14,6 +20,14 @@ DEFAULT_DATA_DIR = "data"
 DEFAULT_SAVE_DIR = "saves"
 USER_CONFIG = "user.config"
 DATA_CONFIG = "dataset_info.json"
+CKPT_NAMES = [
+    WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    ADAPTER_WEIGHTS_NAME,
+    ADAPTER_SAFE_WEIGHTS_NAME
+]
 
 
 def get_save_dir(*args) -> os.PathLike:
@@ -61,10 +75,7 @@ def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
         for checkpoint in os.listdir(save_dir):
             if (
                 os.path.isdir(os.path.join(save_dir, checkpoint))
-                and any([
-                    os.path.isfile(os.path.join(save_dir, checkpoint, name))
-                    for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
-                ])
+                and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
             ):
                 checkpoints.append(checkpoint)
     return gr.update(value=[], choices=checkpoints)
@@ -75,6 +86,7 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
         with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
             return json.load(f)
     except:
+        print("Cannot find {} in {}.".format(DATA_CONFIG, dataset_dir))
         return {}
 
 
diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py
index d37fe746..089d02e5 100644
--- a/src/llmtuner/webui/components/eval.py
+++ b/src/llmtuner/webui/components/eval.py
@@ -57,6 +57,9 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
             top_elems["quantization_bit"],
             top_elems["template"],
             top_elems["system_prompt"],
+            top_elems["flash_attn"],
+            top_elems["shift_attn"],
+            top_elems["rope_scaling"],
             dataset_dir,
             dataset,
             cutoff_len,
diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py
index 489ccf2e..49505b67 100644
--- a/src/llmtuner/webui/components/infer.py
+++ b/src/llmtuner/webui/components/infer.py
@@ -28,7 +28,10 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
             top_elems["finetuning_type"],
             top_elems["quantization_bit"],
             top_elems["template"],
-            top_elems["system_prompt"]
+            top_elems["system_prompt"],
+            top_elems["flash_attn"],
+            top_elems["shift_attn"],
+            top_elems["rope_scaling"]
         ],
         [info_box]
     ).then(
diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py
index 62c1f9c9..e675009d 100644
--- a/src/llmtuner/webui/components/top.py
+++ b/src/llmtuner/webui/components/top.py
@@ -26,10 +26,16 @@ def create_top() -> Dict[str, "Component"]:
 
     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
         with gr.Row():
-            quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
+            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
             template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
             system_prompt = gr.Textbox(scale=2)
 
+    with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab:
+        with gr.Row():
+            flash_attn = gr.Checkbox(value=False)
+            shift_attn = gr.Checkbox(value=False)
+            rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none")
+
     lang.change(save_config, [lang, model_name, model_path])
 
     model_name.change(
@@ -62,5 +68,9 @@ def create_top() -> Dict[str, "Component"]:
         advanced_tab=advanced_tab,
         quantization_bit=quantization_bit,
         template=template,
-        system_prompt=system_prompt
+        system_prompt=system_prompt,
+        llama_tab=llama_tab,
+        flash_attn=flash_attn,
+        shift_attn=shift_attn,
+        rope_scaling=rope_scaling
     )
diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py
index d12f8c9f..050a44da 100644
--- a/src/llmtuner/webui/components/train.py
+++ b/src/llmtuner/webui/components/train.py
@@ -55,8 +55,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
             logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
             save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
             warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
-            flash_attn = gr.Checkbox(value=False)
-            rope_scaling = gr.Checkbox(value=False)
 
     with gr.Accordion(label="LoRA config", open=False) as lora_tab:
         with gr.Row():
@@ -67,8 +65,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
 
     with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
         with gr.Row():
-            dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
-            reward_model = gr.Dropdown(scale=2)
+            dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
+            reward_model = gr.Dropdown(scale=3)
             refresh_btn = gr.Button(scale=1)
 
     refresh_btn.click(
@@ -105,6 +103,9 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
             top_elems["quantization_bit"],
             top_elems["template"],
             top_elems["system_prompt"],
+            top_elems["flash_attn"],
+            top_elems["shift_attn"],
+            top_elems["rope_scaling"],
             training_stage,
             dataset_dir,
             dataset,
@@ -121,8 +122,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
             logging_steps,
             save_steps,
             warmup_steps,
-            flash_attn,
-            rope_scaling,
             lora_rank,
             lora_dropout,
             lora_target,
@@ -167,8 +166,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
         logging_steps=logging_steps,
         save_steps=save_steps,
         warmup_steps=warmup_steps,
-        flash_attn=flash_attn,
-        rope_scaling=rope_scaling,
         lora_tab=lora_tab,
         lora_rank=lora_rank,
         lora_dropout=lora_dropout,
diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py
index 5ac3cd2e..93005e52 100644
--- a/src/llmtuner/webui/locales.py
+++ b/src/llmtuner/webui/locales.py
@@ -59,12 +59,12 @@ LOCALES = {
     },
     "quantization_bit": {
         "en": {
-            "label": "Quantization bit (optional)",
-            "info": "Enable 4/8-bit model quantization."
+            "label": "Quantization bit",
+            "info": "Enable 4/8-bit model quantization (QLoRA)."
         },
         "zh": {
-            "label": "量化等级(非必填)",
-            "info": "启用 4/8 比特模型量化。"
+            "label": "量化等级",
+            "info": "启用 4/8 比特模型量化(QLoRA)。"
         }
     },
     "template": {
@@ -87,6 +87,38 @@ LOCALES = {
             "info": "默认使用的系统提示词"
         }
     },
+    "llama_tab": {
+        "en": {
+            "label": "Model configurations (LLaMA only)"
+        },
+        "zh": {
+            "label": "模型设置(仅LLaMA)"
+        }
+    },
+    "flash_attn": {
+        "en": {
+            "label": "Use FlashAttention-2"
+        },
+        "zh": {
+            "label": "使用 FlashAttention-2"
+        }
+    },
+    "shift_attn": {
+        "en": {
+            "label": "Use shift short attention (S^2-Attn)"
+        },
+        "zh": {
+            "label": "使用 shift short attention (S^2-Attn)"
+        }
+    },
+    "rope_scaling": {
+        "en": {
+            "label": "RoPE scaling"
+        },
+        "zh": {
+            "label": "RoPE 插值方法"
+        }
+    },
     "training_stage": {
         "en": {
             "label": "Stage",
@@ -277,22 +309,6 @@ LOCALES = {
             "info": "学习率预热采用的步数。"
         }
     },
-    "flash_attn": {
-        "en": {
-            "label": "Use FlashAttention-2"
-        },
-        "zh": {
-            "label": "使用 FlashAttention-2"
-        }
-    },
-    "rope_scaling": {
-        "en": {
-            "label": "Use RoPE scaling"
-        },
-        "zh": {
-            "label": "使用 RoPE 插值"
-        }
-    },
     "lora_tab": {
         "en": {
             "label": "LoRA configurations"
@@ -362,11 +378,11 @@ LOCALES = {
     "reward_model": {
         "en": {
             "label": "Reward model",
-            "info": "Checkpoint of the reward model for PPO training."
+            "info": "Checkpoint of the reward model for PPO training. (Needs to refresh checkpoints)"
         },
         "zh": {
             "label": "奖励模型",
-            "info": "PPO 训练中奖励模型的断点路径。"
+            "info": "PPO 训练中奖励模型的断点路径。(需要刷新断点)"
         }
     },
     "cmd_preview_btn": {
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index e2945158..3cc32332 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -70,6 +70,9 @@ class Runner:
         quantization_bit: str,
         template: str,
         system_prompt: str,
+        flash_attn: bool,
+        shift_attn: bool,
+        rope_scaling: str,
         training_stage: str,
         dataset_dir: str,
         dataset: List[str],
@@ -86,8 +89,6 @@ class Runner:
         logging_steps: int,
         save_steps: int,
         warmup_steps: int,
-        flash_attn: bool,
-        rope_scaling: bool,
         lora_rank: int,
         lora_dropout: float,
         lora_target: str,
@@ -97,9 +98,7 @@ class Runner:
         output_dir: str
     ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
-            checkpoint_dir = ",".join(
-                [get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
-            )
+            checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
         else:
             checkpoint_dir = None
 
@@ -119,6 +118,9 @@ class Runner:
             quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
             template=template,
             system_prompt=system_prompt,
+            flash_attn=flash_attn,
+            shift_attn=shift_attn,
+            rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
             cutoff_len=cutoff_len,
@@ -132,8 +134,6 @@ class Runner:
             logging_steps=logging_steps,
             save_steps=save_steps,
             warmup_steps=warmup_steps,
-            flash_attn=flash_attn,
-            rope_scaling="linear" if rope_scaling else None,
             lora_rank=lora_rank,
             lora_dropout=lora_dropout,
             lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
@@ -168,6 +168,9 @@ class Runner:
         quantization_bit: str,
         template: str,
         system_prompt: str,
+        flash_attn: bool,
+        shift_attn: bool,
+        rope_scaling: str,
         dataset_dir: str,
         dataset: List[str],
         cutoff_len: int,
@@ -179,9 +182,7 @@ class Runner:
         temperature: float
     ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
-            checkpoint_dir = ",".join(
-                [get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
-            )
+            checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
             output_dir = get_save_dir(model_name, finetuning_type, "eval_" + "_".join(checkpoints))
         else:
             checkpoint_dir = None
@@ -202,6 +203,9 @@ class Runner:
             quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
             template=template,
             system_prompt=system_prompt,
+            flash_attn=flash_attn,
+            shift_attn=shift_attn,
+            rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
             cutoff_len=cutoff_len,
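Note on the new UI options: the dropdowns hold strings ("none", "8", "4", "linear", "dynamic"), and the patch normalizes them to Python values right before they reach the model loader, mirroring the expressions in runner.py and chat.py above. A minimal sketch of that conversion; the helper function name is hypothetical and not part of the patch:

from typing import Any, Dict


def normalize_llama_options(flash_attn: bool, shift_attn: bool, rope_scaling: str, quantization_bit: str) -> Dict[str, Any]:
    """Convert raw Gradio widget values into keyword arguments for the loader."""
    return dict(
        flash_attn=flash_attn,
        shift_attn=shift_attn,
        # only real scaling methods are forwarded; the "none" choice becomes None
        rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
        # the quantization dropdown holds strings, so cast before use
        quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
    )


# Example: normalize_llama_options(True, False, "none", "4")
# -> {'flash_attn': True, 'shift_attn': False, 'rope_scaling': None, 'quantization_bit': 4}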