From a5a248d569f8bf97cb9be71221783d97c666583c Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 1 Dec 2023 17:13:22 +0800 Subject: [PATCH] fix err hint --- src/llmtuner/hparams/data_args.py | 11 +++++++---- src/llmtuner/webui/common.py | 11 ++++++++--- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index 0b74c3cb..cea89198 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -4,6 +4,9 @@ from typing import List, Literal, Optional from dataclasses import dataclass, field +DATA_CONFIG = "dataset_info.json" + + @dataclass class DatasetAttr: @@ -130,11 +133,11 @@ class DataArguments: self.seed = seed dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else [] try: - with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: + with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f: dataset_info = json.load(f) - except Exception: + except Exception as err: if self.dataset is not None: - raise ValueError("Cannot find dataset_info.json in `dataset_dir`.") + raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err))) dataset_info = None prompt_list = self.system_prompt.split("|") if self.system_prompt else [None] @@ -147,7 +150,7 @@ class DataArguments: self.dataset_list: List[DatasetAttr] = [] for i, name in enumerate(dataset_names): if name not in dataset_info: - raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) + raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG)) if "hf_hub_url" in dataset_info[name]: dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index 8788754a..40f10c6a 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -11,14 +11,19 @@ from transformers.utils import ( ADAPTER_SAFE_WEIGHTS_NAME ) -from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES +from llmtuner.extras.constants import ( + DEFAULT_MODULE, + DEFAULT_TEMPLATE, + SUPPORTED_MODELS, + TRAINING_STAGES +) +from llmtuner.hparams.data_args import DATA_CONFIG DEFAULT_CACHE_DIR = "cache" DEFAULT_DATA_DIR = "data" DEFAULT_SAVE_DIR = "saves" USER_CONFIG = "user.config" -DATA_CONFIG = "dataset_info.json" CKPT_NAMES = [ WEIGHTS_NAME, WEIGHTS_INDEX_NAME, @@ -94,7 +99,7 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: return json.load(f) except Exception as err: - print("Cannot open {} in {} due to {}.".format(DATA_CONFIG, dataset_dir, str(err))) + print("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err))) return {}