This commit is contained in:
hiyouga 2024-01-19 21:44:32 +08:00
parent b6ec112beb
commit 12043aab9c
4 changed files with 18 additions and 11 deletions

View File

@ -1,4 +1,5 @@
import os
import inspect
from typing import TYPE_CHECKING, List, Literal, Union
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
@ -82,6 +83,11 @@ def load_single_dataset(
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
else:
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
kwargs = {"trust_remote_code": True}
else:
kwargs = {}
dataset = load_dataset(
path=data_path,
name=data_name,
@ -90,7 +96,8 @@ def load_single_dataset(
split=data_args.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
**kwargs
)
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True

View File

@ -32,8 +32,8 @@ class Template:
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
tools: str,
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000
) -> Tuple[List[int], List[int]]:
r"""

View File

@ -4,6 +4,7 @@ import os
import json
import torch
import numpy as np
import inspect
from tqdm import tqdm, trange
from typing import Any, Dict, List, Optional
@ -53,13 +54,18 @@ class Evaluator:
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
results = {}
for subject in pbar:
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
kwargs = {"trust_remote_code": True}
else:
kwargs = {}
dataset = load_dataset(
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
name=subject,
cache_dir=self.model_args.cache_dir,
download_mode=self.eval_args.download_mode,
token=self.model_args.hf_hub_token,
trust_remote_code=True
**kwargs
)
pbar.set_postfix_str(categorys[subject]["name"])
inputs, outputs, labels = [], [], []

View File

@ -27,8 +27,7 @@ class EvalTemplate:
self,
target_data: Dict[str, str],
support_set: "Dataset",
subject_name: str,
use_history: bool
subject_name: str
) -> List[Dict[str, str]]:
messages = []
for k in range(len(support_set)):
@ -39,12 +38,7 @@ class EvalTemplate:
prompt, response = self.parse_example(target_data)
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
if not use_history:
messages = [{"role": Role.USER, "content": "\n\n".join([message["content"] for message in messages[:-1]])}]
return messages