From 36d7a7596688d40524ca91f662acb391e7ce1817 Mon Sep 17 00:00:00 2001 From: Mark Mueller Date: Thu, 8 Feb 2024 08:28:32 -0800 Subject: [PATCH] SlimOrca aligner --- src/llmtuner/data/aligner.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/llmtuner/data/aligner.py b/src/llmtuner/data/aligner.py index 8144141c..5140f9d8 100644 --- a/src/llmtuner/data/aligner.py +++ b/src/llmtuner/data/aligner.py @@ -53,28 +53,32 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr" if len(messages) == 0: continue + n_sys = 0 prompt = [] response = [] for turn_idx, message in enumerate(messages): - if turn_idx % 2 == 0: - accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag] - else: - accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag] + accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag, dataset_attr.assistant_tag, dataset_attr.function_tag] - if message[dataset_attr.role_tag] not in accept_tags: + if message[dataset_attr.role_tag] == "system": + outputs["system"].append(message[dataset_attr.content_tag]) + n_sys += 1 + elif message[dataset_attr.role_tag] not in accept_tags: + print("sytem attr", dataset_attr.system) + print("accepted tags", accept_tags) raise ValueError("Invalid role tag in {}.".format(messages)) - - prompt.append( - {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} - ) + else: + prompt.append( + {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} + ) last_message = prompt.pop(-1) response.append(last_message) outputs["prompt"].append(prompt) outputs["response"].append(response) - outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") + if n_sys == 0: + outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") - + assert n_sys <= 1 return outputs