Update aligner.py

This commit is contained in:
hoshi-hiyouga 2024-04-26 03:48:34 +08:00 committed by GitHub
parent f62cadb258
commit fcd09112d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 24 additions and 9 deletions

View File

@ -1,3 +1,4 @@
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
@ -13,8 +14,10 @@ if TYPE_CHECKING:
from .parser import DatasetAttr
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
@ -44,11 +47,18 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("")
outputs["images"].append([])
outputs["images"].append(
[os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]]
if dataset_attr.images
else []
)
return outputs
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
@ -84,7 +94,11 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
outputs["response"].append(aligned_messages[-1:])
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(examples[dataset_attr.images][i] if dataset_attr.images else [])
outputs["images"].append(
[os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]]
if dataset_attr.images
else []
)
return outputs
@ -97,12 +111,13 @@ def align_dataset(
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..."
tools: "..."
tools: "...",
images: [],
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
@ -115,7 +130,7 @@ def align_dataset(
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
"images": {"feature": {"_type": "Image"}, "_type": "Sequence"},
"images": [{"_type": "Image"}],
}
)
kwargs = {}