add dataset stage and filter dataset when stage chosen in webui

This commit is contained in:
codemayq 2023-08-23 18:54:23 +08:00
parent 1c702ad538
commit c0e4d1e81b
4 changed files with 84 additions and 37 deletions

View File

@ -1,23 +1,28 @@
{ {
"alpaca_en": { "alpaca_en": {
"file_name": "alpaca_data_en_52k.json", "file_name": "alpaca_data_en_52k.json",
"file_sha1": "607f94a7f581341e59685aef32f531095232cf23" "file_sha1": "607f94a7f581341e59685aef32f531095232cf23",
"stage": "sft"
}, },
"alpaca_zh": { "alpaca_zh": {
"file_name": "alpaca_data_zh_51k.json", "file_name": "alpaca_data_zh_51k.json",
"file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311" "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311",
"stage": "sft"
}, },
"alpaca_gpt4_en": { "alpaca_gpt4_en": {
"file_name": "alpaca_gpt4_data_en.json", "file_name": "alpaca_gpt4_data_en.json",
"file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a" "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a",
"stage": "sft"
}, },
"alpaca_gpt4_zh": { "alpaca_gpt4_zh": {
"file_name": "alpaca_gpt4_data_zh.json", "file_name": "alpaca_gpt4_data_zh.json",
"file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845" "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845",
"stage": "sft"
}, },
"self_cognition": { "self_cognition": {
"file_name": "self_cognition.json", "file_name": "self_cognition.json",
"file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67" "file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67",
"stage": "sft"
}, },
"oaast_sft": { "oaast_sft": {
"file_name": "oaast_sft.json", "file_name": "oaast_sft.json",
@ -27,7 +32,8 @@
"query": "input", "query": "input",
"response": "output", "response": "output",
"history": "history" "history": "history"
} },
"stage": "sft"
}, },
"oaast_sft_zh": { "oaast_sft_zh": {
"file_name": "oaast_sft_zh.json", "file_name": "oaast_sft_zh.json",
@ -37,7 +43,8 @@
"query": "input", "query": "input",
"response": "output", "response": "output",
"history": "history" "history": "history"
} },
"stage": "sft"
}, },
"sharegpt_zh": { "sharegpt_zh": {
"file_name": "sharegpt_zh_27k.json", "file_name": "sharegpt_zh_27k.json",
@ -47,7 +54,8 @@
"query": "input", "query": "input",
"response": "output", "response": "output",
"history": "history" "history": "history"
} },
"stage": "sft"
}, },
"lima": { "lima": {
"file_name": "lima.json", "file_name": "lima.json",
@ -57,7 +65,8 @@
"query": "input", "query": "input",
"response": "output", "response": "output",
"history": "history" "history": "history"
} },
"stage": "sft"
}, },
"example": { "example": {
"script_url": "example_dataset", "script_url": "example_dataset",
@ -66,25 +75,32 @@
"query": "input", "query": "input",
"response": "output", "response": "output",
"history": "history" "history": "history"
} },
"stage": "sft"
}, },
"guanaco": { "guanaco": {
"hf_hub_url": "JosephusCheung/GuanacoDataset" "hf_hub_url": "JosephusCheung/GuanacoDataset",
"stage": "sft"
}, },
"belle_0.5m": { "belle_0.5m": {
"hf_hub_url": "BelleGroup/train_0.5M_CN" "hf_hub_url": "BelleGroup/train_0.5M_CN",
"stage": "sft"
}, },
"belle_1m": { "belle_1m": {
"hf_hub_url": "BelleGroup/train_1M_CN" "hf_hub_url": "BelleGroup/train_1M_CN",
"stage": "sft"
}, },
"belle_2m": { "belle_2m": {
"hf_hub_url": "BelleGroup/train_2M_CN" "hf_hub_url": "BelleGroup/train_2M_CN",
"stage": "sft"
}, },
"belle_dialog": { "belle_dialog": {
"hf_hub_url": "BelleGroup/generated_chat_0.4M" "hf_hub_url": "BelleGroup/generated_chat_0.4M",
"stage": "sft"
}, },
"belle_math": { "belle_math": {
"hf_hub_url": "BelleGroup/school_math_0.25M" "hf_hub_url": "BelleGroup/school_math_0.25M",
"stage": "sft"
}, },
"belle_multiturn": { "belle_multiturn": {
"script_url": "belle_multiturn", "script_url": "belle_multiturn",
@ -93,7 +109,8 @@
"query": "", "query": "",
"response": "output", "response": "output",
"history": "history" "history": "history"
} },
"stage": "sft"
}, },
"firefly": { "firefly": {
"hf_hub_url": "YeungNLP/firefly-train-1.1M", "hf_hub_url": "YeungNLP/firefly-train-1.1M",
@ -102,13 +119,16 @@
"query": "", "query": "",
"response": "target", "response": "target",
"history": "" "history": ""
} },
"stage": "sft"
}, },
"codealpaca": { "codealpaca": {
"hf_hub_url": "sahil2801/CodeAlpaca-20k" "hf_hub_url": "sahil2801/CodeAlpaca-20k",
"stage": "sft"
}, },
"alpaca_cot": { "alpaca_cot": {
"hf_hub_url": "QingyiSi/Alpaca-CoT" "hf_hub_url": "QingyiSi/Alpaca-CoT",
"stage": "sft"
}, },
"webqa": { "webqa": {
"hf_hub_url": "suolyer/webqa", "hf_hub_url": "suolyer/webqa",
@ -117,7 +137,8 @@
"query": "", "query": "",
"response": "output", "response": "output",
"history": "" "history": ""
} },
"stage": "sft"
}, },
"ultra_chat": { "ultra_chat": {
"script_url": "ultra_chat", "script_url": "ultra_chat",
@ -126,18 +147,22 @@
"query": "", "query": "",
"response": "output", "response": "output",
"history": "history" "history": "history"
} },
"stage": "sft"
}, },
"novel_tokens512_50k": { "novel_tokens512_50k": {
"hf_hub_url": "zxbsmk/webnovel_cn" "hf_hub_url": "zxbsmk/webnovel_cn",
"stage": "sft"
}, },
"comparison_gpt4_en": { "comparison_gpt4_en": {
"file_name": "comparison_gpt4_data_en.json", "file_name": "comparison_gpt4_data_en.json",
"file_sha1": "96fa18313544e22444fe20eead7754b17da452ae" "file_sha1": "96fa18313544e22444fe20eead7754b17da452ae",
"stage": "rm"
}, },
"comparison_gpt4_zh": { "comparison_gpt4_zh": {
"file_name": "comparison_gpt4_data_zh.json", "file_name": "comparison_gpt4_data_zh.json",
"file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd" "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd",
"stage": "rm"
}, },
"hh_rlhf_en": { "hh_rlhf_en": {
"script_url": "hh_rlhf_en", "script_url": "hh_rlhf_en",
@ -146,7 +171,8 @@
"query": "", "query": "",
"response": "output", "response": "output",
"history": "history" "history": "history"
} },
"stage": "rm"
}, },
"oaast_rm": { "oaast_rm": {
"file_name": "oaast_rm.json", "file_name": "oaast_rm.json",
@ -156,7 +182,8 @@
"query": "input", "query": "input",
"response": "output", "response": "output",
"history": "history" "history": "history"
} },
"stage": "rm"
}, },
"oaast_rm_zh": { "oaast_rm_zh": {
"file_name": "oaast_rm_zh.json", "file_name": "oaast_rm_zh.json",
@ -166,7 +193,8 @@
"query": "input", "query": "input",
"response": "output", "response": "output",
"history": "history" "history": "history"
} },
"stage": "rm"
}, },
"wiki_demo": { "wiki_demo": {
"file_name": "wiki_demo.txt", "file_name": "wiki_demo.txt",
@ -176,7 +204,8 @@
"query": "", "query": "",
"response": "", "response": "",
"history": "" "history": ""
} },
"stage": "pt"
}, },
"refinedweb": { "refinedweb": {
"hf_hub_url": "tiiuae/falcon-refinedweb", "hf_hub_url": "tiiuae/falcon-refinedweb",
@ -185,7 +214,8 @@
"query": "", "query": "",
"response": "", "response": "",
"history": "" "history": ""
} },
"stage": "pt"
}, },
"starcoder": { "starcoder": {
"hf_hub_url": "bigcode/starcoderdata", "hf_hub_url": "bigcode/starcoderdata",
@ -194,7 +224,8 @@
"query": "", "query": "",
"response": "", "response": "",
"history": "" "history": ""
} },
"stage": "pt"
}, },
"wikipedia_en": { "wikipedia_en": {
"hf_hub_url": "olm/olm-wikipedia-20221220", "hf_hub_url": "olm/olm-wikipedia-20221220",
@ -203,7 +234,8 @@
"query": "", "query": "",
"response": "", "response": "",
"history": "" "history": ""
} },
"stage": "pt"
}, },
"wikipedia_zh": { "wikipedia_zh": {
"hf_hub_url": "pleisto/wikipedia-cn-20230720-filtered", "hf_hub_url": "pleisto/wikipedia-cn-20230720-filtered",
@ -212,6 +244,7 @@
"query": "", "query": "",
"response": "", "response": "",
"history": "" "history": ""
} },
"stage": "pt"
} }
} }

View File

@ -18,6 +18,14 @@ STAGES = [
"Pre-Training" "Pre-Training"
] ]
DATASET_STAGE_MAP = {
"SFT": "sft",
"Pre-Training": "pt",
"Reward Modeling": "rm",
"PPO": "sft",
"DPO": "rm"
}
SUPPORTED_MODELS = { SUPPORTED_MODELS = {
"LLaMA-7B": "huggyllama/llama-7b", "LLaMA-7B": "huggyllama/llama-7b",
"LLaMA-13B": "huggyllama/llama-13b", "LLaMA-13B": "huggyllama/llama-13b",

View File

@ -6,7 +6,7 @@ import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP
DEFAULT_CACHE_DIR = "cache" DEFAULT_CACHE_DIR = "cache"
@ -78,6 +78,11 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
return {} return {}
def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]: def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]:
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
return gr.update(value=[], choices=list(dataset_info.keys())) if stage:
dataset_stage = DATASET_STAGE_MAP[stage]
dataset_info = {key: value for key, value in dataset_info.items()
if ("stage" not in value) or value["stage"] == dataset_stage}
return gr.update(value=[], choices=list(dataset_info.keys()))

View File

@ -22,7 +22,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
preview_box, preview_count, preview_samples, close_btn = create_preview_box() preview_box, preview_count, preview_samples, close_btn = create_preview_box()
dataset_dir.change(list_dataset, [dataset_dir], [dataset]) training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset])
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn]) dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
data_preview_btn.click( data_preview_btn.click(
get_preview, get_preview,