diff --git a/src/llmtuner/webui/components/data.py b/src/llmtuner/webui/components/data.py index ab6b5de4..8e2e04bf 100644 --- a/src/llmtuner/webui/components/data.py +++ b/src/llmtuner/webui/components/data.py @@ -1,6 +1,6 @@ import json import os -from typing import TYPE_CHECKING, Dict, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Tuple import gradio as gr @@ -32,39 +32,36 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button": if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]: return gr.Button(interactive=False) - local_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]) - if (os.path.isfile(local_path) - or (os.path.isdir(local_path) and len(os.listdir(local_path)) != 0)): + data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]) + if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)): return gr.Button(interactive=True) else: return gr.Button(interactive=False) -def load_single_data(data_file_path): - with open(os.path.join(data_file_path), "r", encoding="utf-8") as f: - if data_file_path.endswith(".json"): - data = json.load(f) - elif data_file_path.endswith(".jsonl"): - data = [json.loads(line) for line in f] +def _load_data_file(file_path: str) -> List[Any]: + with open(file_path, "r", encoding="utf-8") as f: + if file_path.endswith(".json"): + return json.load(f) + elif file_path.endswith(".jsonl"): + return [json.loads(line) for line in f] else: - data = [line for line in f] # noqa: C416 - return data + return list(f) def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]: with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: dataset_info = json.load(f) - data_file: str = dataset_info[dataset[0]]["file_name"] - local_path = os.path.join(dataset_dir, data_file) - if os.path.isdir(local_path): - data = [] - for file_name in os.listdir(local_path): - data.extend(load_single_data(os.path.join(local_path, file_name))) + data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]) + if os.path.isfile(data_path): + data = _load_data_file(data_path) else: - data = load_single_data(local_path) + data = [] + for file_name in os.listdir(data_path): + data.extend(_load_data_file(os.path.join(data_path, file_name))) - return len(data), data[PAGE_SIZE * page_index: PAGE_SIZE * (page_index + 1)], gr.Column(visible=True) + return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True) def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]: