SlimOrca aligner

This commit is contained in:
Mark Mueller 2024-02-08 08:28:32 -08:00
parent d0daaa01f9
commit 36d7a75966
1 changed files with 15 additions and 11 deletions

View File

@ -53,28 +53,32 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
if len(messages) == 0: if len(messages) == 0:
continue continue
n_sys = 0
prompt = [] prompt = []
response = [] response = []
for turn_idx, message in enumerate(messages): for turn_idx, message in enumerate(messages):
if turn_idx % 2 == 0: accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag, dataset_attr.assistant_tag, dataset_attr.function_tag]
accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag]
else:
accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag]
if message[dataset_attr.role_tag] not in accept_tags: if message[dataset_attr.role_tag] == "system":
outputs["system"].append(message[dataset_attr.content_tag])
n_sys += 1
elif message[dataset_attr.role_tag] not in accept_tags:
print("sytem attr", dataset_attr.system)
print("accepted tags", accept_tags)
raise ValueError("Invalid role tag in {}.".format(messages)) raise ValueError("Invalid role tag in {}.".format(messages))
else:
prompt.append( prompt.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
) )
last_message = prompt.pop(-1) last_message = prompt.pop(-1)
response.append(last_message) response.append(last_message)
outputs["prompt"].append(prompt) outputs["prompt"].append(prompt)
outputs["response"].append(response) outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") if n_sys == 0:
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
assert n_sys <= 1
return outputs return outputs