1. change the task name format

2. delete split param in data_args.py
This commit is contained in:
codingma 2024-07-15 09:55:33 +08:00
parent 99ab7a8c1c
commit 645211dc01
3 changed files with 8 additions and 10 deletions

View File

@ -6,8 +6,7 @@ adapter_name_or_path: saves/llama3-8b/lora/sft
finetuning_type: lora
### dataset
task: mmlu
split: test
task: mmlu_test
template: fewshot
lang: en
n_shot: 5

View File

@ -73,8 +73,11 @@ class Evaluator:
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
def eval(self) -> None:
task = self.eval_args.task.split("_")[0]
split = self.eval_args.task.split("_")[1]
mapping = cached_file(
path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
path_or_repo_id=os.path.join(self.eval_args.task_dir, task),
filename="mapping.json",
cache_dir=self.model_args.cache_dir,
token=self.model_args.hf_hub_token,
@ -88,7 +91,7 @@ class Evaluator:
results = {}
for subject in pbar:
dataset = load_dataset(
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
path=os.path.join(self.eval_args.task_dir, task),
name=subject,
cache_dir=self.model_args.cache_dir,
download_mode=self.eval_args.download_mode,
@ -97,12 +100,12 @@ class Evaluator:
)
pbar.set_postfix_str(categorys[subject]["name"])
inputs, outputs, labels = [], [], []
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False):
support_set = (
dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
)
messages = self.eval_template.format_example(
target_data=dataset[self.data_args.split][i],
target_data=dataset[split][i],
support_set=support_set,
subject_name=categorys[subject]["name"],
)

View File

@ -41,10 +41,6 @@ class DataArguments:
default="data",
metadata={"help": "Path to the folder containing the datasets."},
)
split: str = field(
default="train",
metadata={"help": "Which dataset split to use for training and evaluation."},
)
cutoff_len: int = field(
default=1024,
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},