LLaMA-Factory-Mirror/data/belle_multiturn/belle_multiturn.py

79 lines
2.8 KiB
Python
Raw Normal View History

import os
2023-06-16 20:01:16 +08:00
import json
import datasets
2024-03-20 20:09:06 +08:00
# Base URL used to build the dataset links; the HF_ENDPOINT environment
# variable overrides the default public endpoint.
_HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")

# Metadata handed to datasets.DatasetInfo by the builder.
_DESCRIPTION = "BELLE multiturn chat dataset."

_CITATION = """\
@article{belle2023exploring,
title={Exploring the Impact of Instruction Data Scaling on Large Language Models: An Empirical Study on Real-World Use Cases},
author={Yunjie Ji, Yong Deng, Yan Gong, Yiping Peng, Qiang Niu, Lei Zhang, Baochang Ma, Xiangang Li},
journal={arXiv preprint arXiv:2303.14742},
year={2023}
}
"""

# Dataset page on the configured endpoint.
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"

_LICENSE = "gpl-3.0"

# Direct download link of the raw data file, located under the dataset page.
_URL = f"{_HOMEPAGE}/resolve/main/multiturn_chat_0.8M.json"
2023-06-16 20:01:16 +08:00
class BelleMultiturn(datasets.GeneratorBasedBuilder):
    """Dataset builder that turns BELLE multi-turn rows into conversation lists.

    Each input line is a JSON object with an ``instruction`` string (the
    dialogue so far, with turns introduced by ``Human:`` / ``Assistant:``
    markers) and an ``output`` string (the final assistant reply). Every
    example is emitted as
    ``{"conversations": [{"from": "human" | "gpt", "value": str}, ...]}``
    in chronological order.
    """

    VERSION = datasets.Version("0.0.0")

    def _info(self):
        """Return the dataset metadata and the schema of the emitted examples."""
        features = datasets.Features({
            "conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
        })
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager):
        """Download the raw data file and expose it as a single TRAIN split."""
        file_path = dl_manager.download(_URL)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepath": file_path
                }
            )
        ]

    @staticmethod
    def _split_dialogue(instruction: str, output: str) -> list:
        """Split a flattened dialogue into alternating human/gpt turns.

        Args:
            instruction: Dialogue text whose turns are introduced by
                ``Human:`` and ``Assistant:`` markers; the last ``Human:``
                segment is the query answered by ``output``.
            output: The assistant reply to the last query.

        Returns:
            A list of ``{"from": ..., "value": ...}`` dicts, oldest turn first,
            always ending with the final human query and the gpt reply.
        """
        human_tag = "Human:"
        assistant_tag = "Assistant:"
        prompt = instruction.strip()
        response = output.strip()

        human_idx = prompt.rfind(human_tag)
        assist_idx = prompt.rfind(assistant_tag)

        # str.rfind returns -1 when the marker is absent; using that value
        # directly in slice arithmetic silently cuts characters, so every
        # bound is resolved explicitly.
        query_start = human_idx + len(human_tag) if human_idx != -1 else 0
        # The query ends at the trailing "Assistant:" cue only when that cue
        # really follows the last "Human:"; otherwise it runs to the end.
        query_end = assist_idx if assist_idx > human_idx else len(prompt)
        query = prompt[query_start:query_end].strip()
        history = prompt[:human_idx].strip() if human_idx != -1 else ""

        # Walk the history backwards, collecting (query, reply) pairs
        # newest-first; they are reversed once at the end instead of being
        # inserted at the front of the result list one by one.
        earlier_pairs = []
        while True:
            assist_idx = history.rfind(assistant_tag)
            if assist_idx == -1:
                break
            # Only a "Human:" located before this reply can be its query. A
            # reply that merely mentions "Human:" must not be split there.
            human_idx = history.rfind(human_tag, 0, assist_idx)
            if human_idx == -1:
                break
            old_query = history[human_idx + len(human_tag):assist_idx].strip()
            old_resp = history[assist_idx + len(assistant_tag):].strip()
            earlier_pairs.append((old_query, old_resp))
            history = history[:human_idx].strip()

        conversations = []
        for old_query, old_resp in reversed(earlier_pairs):
            conversations.append({"from": "human", "value": old_query})
            conversations.append({"from": "gpt", "value": old_resp})
        conversations.append({"from": "human", "value": query})
        conversations.append({"from": "gpt", "value": response})
        return conversations

    def _generate_examples(self, filepath: str):
        """Yield ``(key, example)`` pairs, one per non-blank line of the file.

        The file is read line by line and each line is parsed as its own JSON
        object. The key is the zero-based line index, so keys stay unique even
        when blank lines are skipped.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If a row lacks the ``instruction`` or ``output`` field.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            for key, row in enumerate(f):
                # A blank line (e.g. a trailing newline) carries no record and
                # would otherwise make json.loads fail.
                if not row.strip():
                    continue
                data = json.loads(row)
                conversations = self._split_dialogue(data["instruction"], data["output"])
                yield key, {"conversations": conversations}