diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py
index ffdc8827..cd22943f 100644
--- a/src/llmtuner/extras/constants.py
+++ b/src/llmtuner/extras/constants.py
@@ -10,7 +10,13 @@ LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
 
 METHODS = ["full", "freeze", "lora"]
 
-STAGES = ["Supervised Finetuning", "Reward Modeling", "PPO", "DPO", "Pretraining"]
+STAGES = [
+    "SFT",
+    "Reward Modeling",
+    "PPO",
+    "DPO",
+    "Pre-Training"
+]
 
 SUPPORTED_MODELS = {
     "LLaMA-7B": "huggyllama/llama-7b",
@@ -23,6 +29,10 @@ SUPPORTED_MODELS = {
     "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
     "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
     "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
+    "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
+    "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
+    "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
+    "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
     "BLOOM-560M": "bigscience/bloom-560m",
     "BLOOM-3B": "bigscience/bloom-3b",
     "BLOOM-7B1": "bigscience/bloom-7b1",
@@ -41,12 +51,13 @@
     "Qwen-7B": "Qwen/Qwen-7B",
     "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
     "XVERSE-13B": "xverse/XVERSE-13B",
-    "ChatGLM2-6B": "THUDM/chatglm2-6b"
+    "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
 }
 
 DEFAULT_MODULE = {
     "LLaMA": "q_proj,v_proj",
     "LLaMA2": "q_proj,v_proj",
+    "ChineseLLaMA2": "q_proj,v_proj",
     "BLOOM": "query_key_value",
     "BLOOMZ": "query_key_value",
     "Falcon": "query_key_value",
@@ -59,28 +70,9 @@
 
 DEFAULT_TEMPLATE = {
     "LLaMA2": "llama2",
+    "ChineseLLaMA2": "llama2_zh",
     "Baichuan": "baichuan",
     "InternLM": "intern",
     "Qwen": "chatml",
     "ChatGLM2": "chatglm2"
 }
-
-# huggingface model name prefix 2 template
-DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL = {
-    "Llama-2": "llama2",
-    "chinese-alpaca-2": "llama2_zh",
-    "alpaca-7b-wdiff": "alpaca",
-    "vicuna": "vicuna",
-    "BELLE": "belle",
-    "Chinese-LLaMA-2": "linly",
-    "BiLLa": "billa",
-    "Ziya": "ziya",
-    "aquilachat": "aquila",
-    "internlm": "intern",
-    "aquilachat": "aquila",
-    "internlm": "intern",
-    "Baichuan":"baichuan",
-    "starchat":"starchat",
-    "Qwen":"chatml",
-    "chatglm2":"chatglm2"
-}
\ No newline at end of file
diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index b0a7365c..b3e0e4b1 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -95,7 +95,6 @@ def prepare_model_for_training(
     use_gradient_checkpointing: Optional[bool] = True,
     layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
 ) -> "PreTrainedModel":
-
     for name, param in model.named_parameters():
         if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
             param.data = param.data.to(torch.float32)
@@ -112,9 +111,6 @@
         model.config.use_cache = False # turn off when gradient checkpointing is enabled
 
     if finetuning_type != "full" and hasattr(model, output_layer_name):
-        if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"):
-            model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728)
-
         output_layer: torch.nn.Linear = getattr(model, output_layer_name)
         input_dtype = output_layer.weight.dtype
 
diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py
index e8d228a4..bb2886d5 100644
--- a/src/llmtuner/extras/template.py
+++ b/src/llmtuner/extras/template.py
@@ -273,8 +273,8 @@ register_template(
 
 
 r"""
-Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
-          https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
+Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
+          https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
 """
 register_template(
     name="llama2_zh",
diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py
index 728fbd6b..965a690b 100644
--- a/src/llmtuner/webui/common.py
+++ b/src/llmtuner/webui/common.py
@@ -6,7 +6,7 @@ import gradio as gr
 from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
 from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
 
-from llmtuner.extras.constants import SUPPORTED_MODELS, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL, DEFAULT_TEMPLATE
+from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS
 
 DEFAULT_CACHE_DIR = "cache"
 
@@ -48,20 +48,10 @@ def get_model_path(model_name: str) -> str:
     return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
 
 
-def get_template(
-    model_name: str,
-) -> str:
-    if model_name == "Custom":
-        model_name_or_path = get_model_path(model_name)
-        # get last dir
-        basename = os.path.basename(model_name_or_path)
-        # prefix match
-        for k, v in DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL.items():
-            if basename.startswith(k):
-                return v
-        return "default"
-
-    return DEFAULT_TEMPLATE.get(model_name.split("-")[0], "default")
+def get_template(model_name: str) -> str:
+    if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
+        return DEFAULT_TEMPLATE[model_name.split("-")[0]]
+    return "default"
 
 
 def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
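
The rewritten get_template drops the basename prefix-matching table in favor of a naming convention: only models whose display name ends in "Chat" receive their family's chat template, which is presumably also why "ChatGLM2-6B" is renamed "ChatGLM2-6B-Chat" in SUPPORTED_MODELS above. A minimal standalone sketch of the expected behavior; the abbreviated DEFAULT_TEMPLATE dict here is illustrative, not the full mapping from constants.py:

    # Illustrative, abbreviated copy of DEFAULT_TEMPLATE from constants.py.
    DEFAULT_TEMPLATE = {"LLaMA2": "llama2", "ChineseLLaMA2": "llama2_zh"}

    def get_template(model_name: str) -> str:
        # Chat-tuned checkpoints pick up their family's template; base models
        # and unknown families fall back to the default template.
        if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
            return DEFAULT_TEMPLATE[model_name.split("-")[0]]
        return "default"

    assert get_template("ChineseLLaMA2-7B-Chat") == "llama2_zh"
    assert get_template("LLaMA2-13B-Chat") == "llama2"
    assert get_template("ChineseLLaMA2-7B") == "default"  # base model, no chat template
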
diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py
index 76ead7ee..7f3c6faa 100644
--- a/src/llmtuner/webui/components/top.py
+++ b/src/llmtuner/webui/components/top.py
@@ -4,7 +4,7 @@ import gradio as gr
 
 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
-from llmtuner.webui.common import list_checkpoint, get_model_path, save_config, get_template
+from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config
 from llmtuner.webui.utils import can_quantize
 
 if TYPE_CHECKING:
@@ -36,10 +36,11 @@ def create_top() -> Dict[str, "Component"]:
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
     ).then(
         get_model_path, [model_name], [model_path]
+    ).then(
+        get_template, [model_name], [template]
     ) # do not save config since the below line will save
 
     model_path.change(save_config, [lang, model_name, model_path])
-    model_path.change(get_template, [model_name], [template])
 
     finetuning_type.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py
index 76f8ff94..6aaeecbb 100644
--- a/src/llmtuner/webui/components/train.py
+++ b/src/llmtuner/webui/components/train.py
@@ -3,10 +3,10 @@ from transformers.trainer_utils import SchedulerType
 
 import gradio as gr
 
+from llmtuner.extras.constants import STAGES
 from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
 from llmtuner.webui.components.data import create_preview_box
 from llmtuner.webui.utils import can_preview, get_preview, gen_plot
-from llmtuner.extras.constants import STAGES
 
 if TYPE_CHECKING:
     from gradio.components import Component
@@ -15,9 +15,7 @@ if TYPE_CHECKING:
 
 def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
     with gr.Row():
-        stage = gr.Dropdown(choices=STAGES,
-                            value="Supervised Finetuning", scale=2)
-
+        training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2)
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)
         data_preview_btn = gr.Button(interactive=False, scale=1)
@@ -104,7 +102,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
             top_elems["quantization_bit"],
             top_elems["template"],
             top_elems["source_prefix"],
-            stage,
+            training_stage,
             dataset_dir,
             dataset,
             max_source_length,
@@ -145,7 +143,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
     )
 
     return dict(
-        stage=stage,
+        training_stage=training_stage,
        dataset_dir=dataset_dir,
         dataset=dataset,
         data_preview_btn=data_preview_btn,
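
Two wiring details above are worth noting. In top.py, the template refresh is chained onto the model-name change event with .then() instead of hanging off model_path.change, so it fires after the checkpoint list and model path have been updated. In train.py, the dropdown default is now STAGES[0] ("SFT"), so relabeling the stages in constants.py can no longer desynchronize from a hard-coded default string. A minimal, self-contained Gradio sketch of the .then() chaining pattern, with hypothetical callbacks standing in for the repo's functions:

    import gradio as gr

    def resolve_path(name: str) -> str:
        # Hypothetical stand-in for get_model_path.
        return {"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf"}.get(name, "")

    def resolve_template(name: str) -> str:
        # Hypothetical stand-in for get_template.
        return "llama2" if name.endswith("Chat") else "default"

    with gr.Blocks() as demo:
        model_name = gr.Textbox(label="Model name")
        model_path = gr.Textbox(label="Model path")
        template = gr.Textbox(label="Template")
        # .then() queues the second callback to run after the first finishes,
        # so the template is resolved only once the path lookup has completed.
        model_name.change(resolve_path, [model_name], [model_path]).then(
            resolve_template, [model_name], [template]
        )

    demo.launch()
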
Finetuning", scale=2) - + training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2) dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) dataset = gr.Dropdown(multiselect=True, scale=4) data_preview_btn = gr.Button(interactive=False, scale=1) @@ -104,7 +102,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic top_elems["quantization_bit"], top_elems["template"], top_elems["source_prefix"], - stage, + training_stage, dataset_dir, dataset, max_source_length, @@ -145,7 +143,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic ) return dict( - stage=stage, + training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn, diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py index d95fde29..7a58c4c7 100644 --- a/src/llmtuner/webui/locales.py +++ b/src/llmtuner/webui/locales.py @@ -87,6 +87,16 @@ LOCALES = { "info": "默认使用的系统提示词" } }, + "training_stage": { + "en": { + "label": "Stage", + "info": "The stage to perform in training." + }, + "zh": { + "label": "训练阶段", + "info": "目前采用的训练方式。" + } + }, "dataset_dir": { "en": { "label": "Data dir", @@ -343,16 +353,6 @@ LOCALES = { "label": "RLHF 参数设置" } }, - "rlhf_method": { - "en": { - "label": "RLHF method", - "info": "The RLHF algorithm to adopt." - }, - "zh": { - "label": "RLHF 方法", - "info": "RLHF 阶段使用的算法。" - } - }, "dpo_beta": { "en": { "label": "DPO beta", @@ -546,15 +546,7 @@ LOCALES = { "zh": { "value": "开始导出" } - }, - "stage": { - "en": { - "label": "train stage" - }, - "zh": { - "label": "训练阶段" - } - }, + } } diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 07811258..49fed19b 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -8,11 +8,11 @@ from transformers.trainer import TRAINING_ARGS_NAME from typing import Any, Dict, Generator, List, Tuple from llmtuner.extras.callbacks import LogCallback -from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL +from llmtuner.extras.constants import DEFAULT_MODULE from llmtuner.extras.logging import LoggerHandler from llmtuner.extras.misc import torch_gc from llmtuner.tuner import run_exp -from llmtuner.webui.common import get_model_path, get_save_dir, get_template +from llmtuner.webui.common import get_model_path, get_save_dir from llmtuner.webui.locales import ALERTS from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar @@ -70,7 +70,7 @@ class Runner: quantization_bit: str, template: str, source_prefix: str, - stage: str, + training_stage: str, dataset_dir: str, dataset: List[str], max_source_length: int, @@ -138,21 +138,21 @@ class Runner: ) args[compute_type] = True - if stage == "Pretraining": - args["stage"] = "pt" - if stage == "Reward Modeling": + if training_stage == "Reward Modeling": args["stage"] = "rm" args["resume_lora_training"] = False - elif stage == "PPO": + elif training_stage == "PPO": args["stage"] = "ppo" args["resume_lora_training"] = False args["reward_model"] = reward_model args["padding_side"] = "left" val_size = 0 - elif stage == "DPO": + elif training_stage == "DPO": args["stage"] = "dpo" args["resume_lora_training"] = False args["dpo_beta"] = dpo_beta + elif training_stage == "Pre-Training": + args["stage"] = "pt" if val_size > 1e-6: args["val_size"] = val_size