From 12043aab9c9ea2acaa3ff80232ce0451ccaa557d Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 19 Jan 2024 21:44:32 +0800
Subject: [PATCH] fix #2249

---
 src/llmtuner/data/loader.py    | 9 ++++++++-
 src/llmtuner/data/template.py  | 4 ++--
 src/llmtuner/eval/evaluator.py | 8 +++++++-
 src/llmtuner/eval/template.py  | 8 +-------
 4 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py
index 87f42558..2ea00932 100644
--- a/src/llmtuner/data/loader.py
+++ b/src/llmtuner/data/loader.py
@@ -1,4 +1,5 @@
 import os
+import inspect
 from typing import TYPE_CHECKING, List, Literal, Union
 
 from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
@@ -82,6 +83,11 @@ def load_single_dataset(
         except ImportError:
             raise ImportError("Please install modelscope via `pip install modelscope -U`")
     else:
+        if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
+            kwargs = {"trust_remote_code": True}
+        else:
+            kwargs = {}
+
         dataset = load_dataset(
             path=data_path,
             name=data_name,
@@ -90,7 +96,8 @@
             split=data_args.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
+            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            **kwargs
         )
 
     if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index 4243807b..253edbc5 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -32,8 +32,8 @@ class Template:
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
-        system: str,
-        tools: str,
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
         cutoff_len: Optional[int] = 1_000_000
     ) -> Tuple[List[int], List[int]]:
         r"""
diff --git a/src/llmtuner/eval/evaluator.py b/src/llmtuner/eval/evaluator.py
index 1cb55b38..08bdb08c 100644
--- a/src/llmtuner/eval/evaluator.py
+++ b/src/llmtuner/eval/evaluator.py
@@ -4,6 +4,7 @@
 import os
 import json
 import torch
 import numpy as np
+import inspect
 from tqdm import tqdm, trange
 from typing import Any, Dict, List, Optional
@@ -53,13 +54,18 @@ class Evaluator:
         pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
         results = {}
         for subject in pbar:
+            if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
+                kwargs = {"trust_remote_code": True}
+            else:
+                kwargs = {}
+
             dataset = load_dataset(
                 path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
                 name=subject,
                 cache_dir=self.model_args.cache_dir,
                 download_mode=self.eval_args.download_mode,
                 token=self.model_args.hf_hub_token,
-                trust_remote_code=True
+                **kwargs
             )
             pbar.set_postfix_str(categorys[subject]["name"])
             inputs, outputs, labels = [], [], []
diff --git a/src/llmtuner/eval/template.py b/src/llmtuner/eval/template.py
index 5514e5d5..2c697510 100644
--- a/src/llmtuner/eval/template.py
+++ b/src/llmtuner/eval/template.py
@@ -27,8 +27,7 @@ class EvalTemplate:
         self,
         target_data: Dict[str, str],
         support_set: "Dataset",
-        subject_name: str,
-        use_history: bool
+        subject_name: str
     ) -> List[Dict[str, str]]:
         messages = []
         for k in range(len(support_set)):
@@ -39,12 +38,7 @@
         prompt, response = self.parse_example(target_data)
         messages.append({"role": Role.USER, "content": prompt})
         messages.append({"role": Role.ASSISTANT, "content": response})
-
         messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
-
-        if not use_history:
-            messages = [{"role": Role.USER, "content": "\n\n".join([message["content"] for message in messages[:-1]])}]
-
         return messages
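
The compatibility shim this patch introduces can be exercised on its own. Below is a minimal sketch, assuming only that the datasets library is installed; the helper name load_dataset_compat and the "imdb" dataset are illustrative and not part of the patch. It uses the same inspect.signature probe to forward trust_remote_code=True to datasets.load_dataset only when the installed release accepts that keyword (per the patch's inline comment, the keyword appeared in datasets 2.16.0):

    import inspect

    from datasets import load_dataset

    def load_dataset_compat(path: str, **kwargs):
        # Forward trust_remote_code only when this datasets version
        # actually declares the keyword in load_dataset's signature.
        if "trust_remote_code" in inspect.signature(load_dataset).parameters:
            kwargs["trust_remote_code"] = True
        return load_dataset(path, **kwargs)

    # Illustrative usage; any hub dataset name works here.
    dataset = load_dataset_compat("imdb", split="train")

Probing the signature instead of comparing version strings keeps a single code path working on both older and newer datasets releases, which is why the patch applies the identical pattern in both loader.py and evaluator.py.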