diff --git a/data/dataset_info.json b/data/dataset_info.json index 3eaf920e..5fd4fb1f 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -1,23 +1,28 @@ { "alpaca_en": { "file_name": "alpaca_data_en_52k.json", - "file_sha1": "607f94a7f581341e59685aef32f531095232cf23" + "file_sha1": "607f94a7f581341e59685aef32f531095232cf23", + "stage": "sft" }, "alpaca_zh": { "file_name": "alpaca_data_zh_51k.json", - "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311" + "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311", + "stage": "sft" }, "alpaca_gpt4_en": { "file_name": "alpaca_gpt4_data_en.json", - "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a" + "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a", + "stage": "sft" }, "alpaca_gpt4_zh": { "file_name": "alpaca_gpt4_data_zh.json", - "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845" + "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845", + "stage": "sft" }, "self_cognition": { "file_name": "self_cognition.json", - "file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67" + "file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67", + "stage": "sft" }, "oaast_sft": { "file_name": "oaast_sft.json", @@ -27,7 +32,8 @@ "query": "input", "response": "output", "history": "history" - } + }, + "stage": "sft" }, "oaast_sft_zh": { "file_name": "oaast_sft_zh.json", @@ -37,7 +43,8 @@ "query": "input", "response": "output", "history": "history" - } + }, + "stage": "sft" }, "sharegpt_zh": { "file_name": "sharegpt_zh_27k.json", @@ -47,7 +54,8 @@ "query": "input", "response": "output", "history": "history" - } + }, + "stage": "sft" }, "lima": { "file_name": "lima.json", @@ -57,7 +65,8 @@ "query": "input", "response": "output", "history": "history" - } + }, + "stage": "sft" }, "example": { "script_url": "example_dataset", @@ -66,25 +75,32 @@ "query": "input", "response": "output", "history": "history" - } + }, + "stage": "sft" }, "guanaco": { - "hf_hub_url": "JosephusCheung/GuanacoDataset" + "hf_hub_url": "JosephusCheung/GuanacoDataset", + "stage": "sft" }, "belle_0.5m": { - "hf_hub_url": "BelleGroup/train_0.5M_CN" + "hf_hub_url": "BelleGroup/train_0.5M_CN", + "stage": "sft" }, "belle_1m": { - "hf_hub_url": "BelleGroup/train_1M_CN" + "hf_hub_url": "BelleGroup/train_1M_CN", + "stage": "sft" }, "belle_2m": { - "hf_hub_url": "BelleGroup/train_2M_CN" + "hf_hub_url": "BelleGroup/train_2M_CN", + "stage": "sft" }, "belle_dialog": { - "hf_hub_url": "BelleGroup/generated_chat_0.4M" + "hf_hub_url": "BelleGroup/generated_chat_0.4M", + "stage": "sft" }, "belle_math": { - "hf_hub_url": "BelleGroup/school_math_0.25M" + "hf_hub_url": "BelleGroup/school_math_0.25M", + "stage": "sft" }, "belle_multiturn": { "script_url": "belle_multiturn", @@ -93,7 +109,8 @@ "query": "", "response": "output", "history": "history" - } + }, + "stage": "sft" }, "firefly": { "hf_hub_url": "YeungNLP/firefly-train-1.1M", @@ -102,13 +119,16 @@ "query": "", "response": "target", "history": "" - } + }, + "stage": "sft" }, "codealpaca": { - "hf_hub_url": "sahil2801/CodeAlpaca-20k" + "hf_hub_url": "sahil2801/CodeAlpaca-20k", + "stage": "sft" }, "alpaca_cot": { - "hf_hub_url": "QingyiSi/Alpaca-CoT" + "hf_hub_url": "QingyiSi/Alpaca-CoT", + "stage": "sft" }, "webqa": { "hf_hub_url": "suolyer/webqa", @@ -117,7 +137,8 @@ "query": "", "response": "output", "history": "" - } + }, + "stage": "sft" }, "ultra_chat": { "script_url": "ultra_chat", @@ -126,18 +147,22 @@ "query": "", "response": "output", "history": "history" - } + }, + "stage": "sft" }, "novel_tokens512_50k": { - "hf_hub_url": "zxbsmk/webnovel_cn" + "hf_hub_url": "zxbsmk/webnovel_cn", + "stage": "sft" }, "comparison_gpt4_en": { "file_name": "comparison_gpt4_data_en.json", - "file_sha1": "96fa18313544e22444fe20eead7754b17da452ae" + "file_sha1": "96fa18313544e22444fe20eead7754b17da452ae", + "stage": "rm" }, "comparison_gpt4_zh": { "file_name": "comparison_gpt4_data_zh.json", - "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd" + "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd", + "stage": "rm" }, "hh_rlhf_en": { "script_url": "hh_rlhf_en", @@ -146,7 +171,8 @@ "query": "", "response": "output", "history": "history" - } + }, + "stage": "rm" }, "oaast_rm": { "file_name": "oaast_rm.json", @@ -156,7 +182,8 @@ "query": "input", "response": "output", "history": "history" - } + }, + "stage": "rm" }, "oaast_rm_zh": { "file_name": "oaast_rm_zh.json", @@ -166,7 +193,8 @@ "query": "input", "response": "output", "history": "history" - } + }, + "stage": "rm" }, "wiki_demo": { "file_name": "wiki_demo.txt", @@ -176,7 +204,8 @@ "query": "", "response": "", "history": "" - } + }, + "stage": "pt" }, "refinedweb": { "hf_hub_url": "tiiuae/falcon-refinedweb", @@ -185,7 +214,8 @@ "query": "", "response": "", "history": "" - } + }, + "stage": "pt" }, "starcoder": { "hf_hub_url": "bigcode/starcoderdata", @@ -194,7 +224,8 @@ "query": "", "response": "", "history": "" - } + }, + "stage": "pt" }, "wikipedia_en": { "hf_hub_url": "olm/olm-wikipedia-20221220", @@ -203,7 +234,8 @@ "query": "", "response": "", "history": "" - } + }, + "stage": "pt" }, "wikipedia_zh": { "hf_hub_url": "pleisto/wikipedia-cn-20230720-filtered", @@ -212,6 +244,7 @@ "query": "", "response": "", "history": "" - } + }, + "stage": "pt" } } diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index cd22943f..f10aaaa3 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -18,6 +18,14 @@ STAGES = [ "Pre-Training" ] +DATASET_STAGE_MAP = { + "SFT": "sft", + "Pre-Training": "pt", + "Reward Modeling": "rm", + "PPO": "sft", + "DPO": "rm" +} + SUPPORTED_MODELS = { "LLaMA-7B": "huggyllama/llama-7b", "LLaMA-13B": "huggyllama/llama-13b", diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index 965a690b..98fface4 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -6,7 +6,7 @@ import gradio as gr from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME -from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS +from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP DEFAULT_CACHE_DIR = "cache" @@ -78,6 +78,11 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]: return {} -def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]: +def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]: dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) - return gr.update(value=[], choices=list(dataset_info.keys())) + if stage: + dataset_stage = DATASET_STAGE_MAP[stage] + dataset_info = {key: value for key, value in dataset_info.items() + if ("stage" not in value) or value["stage"] == dataset_stage} + + return gr.update(value=[], choices=list(dataset_info.keys())) \ No newline at end of file diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index aab512ee..7b69944c 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -22,7 +22,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic preview_box, preview_count, preview_samples, close_btn = create_preview_box() - dataset_dir.change(list_dataset, [dataset_dir], [dataset]) + training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset]) + dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset]) dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn]) data_preview_btn.click( get_preview,