diff --git a/examples/train_lora/llama3_lora_eval.yaml b/examples/train_lora/llama3_lora_eval.yaml
index 6fcfd6ef..ad7f938b 100644
--- a/examples/train_lora/llama3_lora_eval.yaml
+++ b/examples/train_lora/llama3_lora_eval.yaml
@@ -6,8 +6,7 @@ adapter_name_or_path: saves/llama3-8b/lora/sft
 finetuning_type: lora
 
 ### dataset
-task: mmlu
-split: test
+task: mmlu_test
 template: fewshot
 lang: en
 n_shot: 5
diff --git a/src/llamafactory/eval/evaluator.py b/src/llamafactory/eval/evaluator.py
index b5269906..c5661997 100644
--- a/src/llamafactory/eval/evaluator.py
+++ b/src/llamafactory/eval/evaluator.py
@@ -73,8 +73,11 @@ class Evaluator:
         return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
 
     def eval(self) -> None:
+        task = self.eval_args.task.split("_")[0]
+        split = self.eval_args.task.split("_")[1]
+
         mapping = cached_file(
-            path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
+            path_or_repo_id=os.path.join(self.eval_args.task_dir, task),
             filename="mapping.json",
             cache_dir=self.model_args.cache_dir,
             token=self.model_args.hf_hub_token,
@@ -88,7 +91,7 @@ class Evaluator:
         results = {}
         for subject in pbar:
             dataset = load_dataset(
-                path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
+                path=os.path.join(self.eval_args.task_dir, task),
                 name=subject,
                 cache_dir=self.model_args.cache_dir,
                 download_mode=self.eval_args.download_mode,
@@ -97,12 +100,12 @@ class Evaluator:
             )
             pbar.set_postfix_str(categorys[subject]["name"])
             inputs, outputs, labels = [], [], []
-            for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
+            for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False):
                 support_set = (
                     dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
                 )
                 messages = self.eval_template.format_example(
-                    target_data=dataset[self.data_args.split][i],
+                    target_data=dataset[split][i],
                     support_set=support_set,
                     subject_name=categorys[subject]["name"],
                 )
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index f483099d..9ae15d2d 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -41,10 +41,6 @@ class DataArguments:
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
-    split: str = field(
-        default="train",
-        metadata={"help": "Which dataset split to use for training and evaluation."},
-    )
     cutoff_len: int = field(
         default=1024,
         metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
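
Note (not part of the patch): a minimal sketch of the task-name convention this change introduces. The evaluation task argument now encodes the split as a suffix (e.g. "mmlu_test"), and the evaluator recovers the dataset folder and split by splitting on "_"; the parse_eval_task helper below and the "ceval_validation" example are illustrative assumptions, not code from the repository.

# Illustrative only: mirrors the parsing done at the top of Evaluator.eval() in this patch.
def parse_eval_task(task_arg: str) -> tuple[str, str]:
    """Split a "<task>_<split>" string into the dataset folder name and the split name."""
    parts = task_arg.split("_")
    return parts[0], parts[1]  # task folder names such as "mmlu" contain no underscore

assert parse_eval_task("mmlu_test") == ("mmlu", "test")
assert parse_eval_task("ceval_validation") == ("ceval", "validation")  # assumed task name

This is also why DataArguments.split can be removed: the split is no longer a separate argument but part of the task string in the eval config.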