fix err hint

This commit is contained in:
hiyouga 2023-12-01 17:13:22 +08:00
parent a51b8ec620
commit a5a248d569
2 changed files with 15 additions and 7 deletions

View File

@ -4,6 +4,9 @@ from typing import List, Literal, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
DATA_CONFIG = "dataset_info.json"
@dataclass @dataclass
class DatasetAttr: class DatasetAttr:
@ -130,11 +133,11 @@ class DataArguments:
self.seed = seed self.seed = seed
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else [] dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
try: try:
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f) dataset_info = json.load(f)
except Exception: except Exception as err:
if self.dataset is not None: if self.dataset is not None:
raise ValueError("Cannot find dataset_info.json in `dataset_dir`.") raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
dataset_info = None dataset_info = None
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None] prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
@ -147,7 +150,7 @@ class DataArguments:
self.dataset_list: List[DatasetAttr] = [] self.dataset_list: List[DatasetAttr] = []
for i, name in enumerate(dataset_names): for i, name in enumerate(dataset_names):
if name not in dataset_info: if name not in dataset_info:
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
if "hf_hub_url" in dataset_info[name]: if "hf_hub_url" in dataset_info[name]:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])

View File

@ -11,14 +11,19 @@ from transformers.utils import (
ADAPTER_SAFE_WEIGHTS_NAME ADAPTER_SAFE_WEIGHTS_NAME
) )
from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES from llmtuner.extras.constants import (
DEFAULT_MODULE,
DEFAULT_TEMPLATE,
SUPPORTED_MODELS,
TRAINING_STAGES
)
from llmtuner.hparams.data_args import DATA_CONFIG
DEFAULT_CACHE_DIR = "cache" DEFAULT_CACHE_DIR = "cache"
DEFAULT_DATA_DIR = "data" DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves" DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user.config" USER_CONFIG = "user.config"
DATA_CONFIG = "dataset_info.json"
CKPT_NAMES = [ CKPT_NAMES = [
WEIGHTS_NAME, WEIGHTS_NAME,
WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
@ -94,7 +99,7 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
return json.load(f) return json.load(f)
except Exception as err: except Exception as err:
print("Cannot open {} in {} due to {}.".format(DATA_CONFIG, dataset_dir, str(err))) print("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
return {} return {}