Fix error hint
This commit is contained in:
parent
a51b8ec620
commit
a5a248d569
|
@@ -4,6 +4,9 @@ from typing import List, Literal, Optional
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
DATA_CONFIG = "dataset_info.json"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DatasetAttr:
|
class DatasetAttr:
|
||||||
|
|
||||||
|
@@ -130,11 +133,11 @@ class DataArguments:
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
|
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
|
||||||
try:
|
try:
|
||||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
|
||||||
dataset_info = json.load(f)
|
dataset_info = json.load(f)
|
||||||
except Exception:
|
except Exception as err:
|
||||||
if self.dataset is not None:
|
if self.dataset is not None:
|
||||||
raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
|
raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
|
||||||
dataset_info = None
|
dataset_info = None
|
||||||
|
|
||||||
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
|
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
|
||||||
|
@@ -147,7 +150,7 @@ class DataArguments:
|
||||||
self.dataset_list: List[DatasetAttr] = []
|
self.dataset_list: List[DatasetAttr] = []
|
||||||
for i, name in enumerate(dataset_names):
|
for i, name in enumerate(dataset_names):
|
||||||
if name not in dataset_info:
|
if name not in dataset_info:
|
||||||
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
|
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
|
||||||
|
|
||||||
if "hf_hub_url" in dataset_info[name]:
|
if "hf_hub_url" in dataset_info[name]:
|
||||||
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
||||||
|
|
|
@@ -11,14 +11,19 @@ from transformers.utils import (
|
||||||
ADAPTER_SAFE_WEIGHTS_NAME
|
ADAPTER_SAFE_WEIGHTS_NAME
|
||||||
)
|
)
|
||||||
|
|
||||||
from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
|
from llmtuner.extras.constants import (
|
||||||
|
DEFAULT_MODULE,
|
||||||
|
DEFAULT_TEMPLATE,
|
||||||
|
SUPPORTED_MODELS,
|
||||||
|
TRAINING_STAGES
|
||||||
|
)
|
||||||
|
from llmtuner.hparams.data_args import DATA_CONFIG
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CACHE_DIR = "cache"
|
DEFAULT_CACHE_DIR = "cache"
|
||||||
DEFAULT_DATA_DIR = "data"
|
DEFAULT_DATA_DIR = "data"
|
||||||
DEFAULT_SAVE_DIR = "saves"
|
DEFAULT_SAVE_DIR = "saves"
|
||||||
USER_CONFIG = "user.config"
|
USER_CONFIG = "user.config"
|
||||||
DATA_CONFIG = "dataset_info.json"
|
|
||||||
CKPT_NAMES = [
|
CKPT_NAMES = [
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
WEIGHTS_INDEX_NAME,
|
WEIGHTS_INDEX_NAME,
|
||||||
|
@@ -94,7 +99,7 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
|
||||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
print("Cannot open {} in {} due to {}.".format(DATA_CONFIG, dataset_dir, str(err)))
|
print("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue