improve aligner

This commit is contained in:
hiyouga 2024-02-10 16:39:19 +08:00
parent 388b705a8d
commit 7d2dc83c5e
11 changed files with 127 additions and 112 deletions

View File

@ -174,6 +174,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca) - [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa) - [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)

View File

@ -174,6 +174,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca) - [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa) - [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)

View File

@ -11,7 +11,7 @@ If you are using a custom dataset, please provide your dataset definition in the
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)", "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
"ranking": "whether the dataset is a preference dataset or not. (default: false)", "ranking": "whether the dataset is a preference dataset or not. (default: false)",
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})", "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
"columns": { "columns (optional)": {
"prompt": "the column name in the dataset containing the prompts. (default: instruction)", "prompt": "the column name in the dataset containing the prompts. (default: instruction)",
"query": "the column name in the dataset containing the queries. (default: input)", "query": "the column name in the dataset containing the queries. (default: input)",
"response": "the column name in the dataset containing the responses. (default: output)", "response": "the column name in the dataset containing the responses. (default: output)",
@ -20,14 +20,14 @@ If you are using a custom dataset, please provide your dataset definition in the
"system": "the column name in the dataset containing the system prompts. (default: None)", "system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)" "tools": "the column name in the dataset containing the tool description. (default: None)"
}, },
"tags": { "tags (optional, used for the sharegpt format)": {
"role_tag": "the key in the message represents the identity. (default: from)", "role_tag": "the key in the message represents the identity. (default: from)",
"content_tag": "the key in the message represents the content. (default: value)", "content_tag": "the key in the message represents the content. (default: value)",
"user_tag": "the value of the role_tag represents the user. (default: human)", "user_tag": "the value of the role_tag represents the user. (default: human)",
"assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)", "assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)",
"observation_tag": "the value of the role_tag represents the tool results. (default: observation)", "observation_tag": "the value of the role_tag represents the tool results. (default: observation)",
"function_tag": "the value of the role_tag represents the function call. (default: function_call)", "function_tag": "the value of the role_tag represents the function call. (default: function_call)",
"system_tag": "the value of the role_tag represents the system prompt. (default: None) incompatible with system column" "system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)"
} }
} }
``` ```

View File

@ -11,7 +11,7 @@
"folder": "Hugging Face 仓库的文件夹名称可选默认None", "folder": "Hugging Face 仓库的文件夹名称可选默认None",
"ranking": "是否为偏好数据集可选默认False", "ranking": "是否为偏好数据集可选默认False",
"formatting": "数据集格式可选默认alpaca可以为 alpaca 或 sharegpt", "formatting": "数据集格式可选默认alpaca可以为 alpaca 或 sharegpt",
"columns": { "columns(可选)": {
"prompt": "数据集代表提示词的表头名称默认instruction", "prompt": "数据集代表提示词的表头名称默认instruction",
"query": "数据集代表请求的表头名称默认input", "query": "数据集代表请求的表头名称默认input",
"response": "数据集代表回答的表头名称默认output", "response": "数据集代表回答的表头名称默认output",
@ -20,13 +20,14 @@
"system": "数据集代表系统提示的表头名称默认None", "system": "数据集代表系统提示的表头名称默认None",
"tools": "数据集代表工具描述的表头名称默认None" "tools": "数据集代表工具描述的表头名称默认None"
}, },
"tags": { "tags(可选,用于 sharegpt 格式)": {
"role_tag": "消息中代表发送者身份的键名默认from", "role_tag": "消息中代表发送者身份的键名默认from",
"content_tag": "消息中代表文本内容的键名默认value", "content_tag": "消息中代表文本内容的键名默认value",
"user_tag": "消息中代表用户的 role_tag默认human", "user_tag": "消息中代表用户的 role_tag默认human",
"assistant_tag": "消息中代表助手的 role_tag默认gpt", "assistant_tag": "消息中代表助手的 role_tag默认gpt",
"observation_tag": "消息中代表工具返回结果的 role_tag默认observation", "observation_tag": "消息中代表工具返回结果的 role_tag默认observation",
"function_tag": "消息中代表工具调用的 role_tag默认function_call" "function_tag": "消息中代表工具调用的 role_tag默认function_call",
"system_tag": "消息中代表系统提示的 role_tag默认system会覆盖 system 列)"
} }
} }
``` ```

View File

@ -15,9 +15,6 @@
"file_name": "alpaca_gpt4_data_zh.json", "file_name": "alpaca_gpt4_data_zh.json",
"file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845" "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845"
}, },
"alpaca-gpt4_de": {
"hf_hub_url": "mayflowergmbh/alpaca-gpt4_de"
},
"self_cognition": { "self_cognition": {
"file_name": "self_cognition.json", "file_name": "self_cognition.json",
"file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67" "file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67"
@ -42,9 +39,6 @@
"history": "history" "history": "history"
} }
}, },
"oasst_de": {
"hf_hub_url": "mayflowergmbh/oasst_de"
},
"lima": { "lima": {
"file_name": "lima.json", "file_name": "lima.json",
"file_sha1": "9db59f6b7007dc4b17529fc63379b9cd61640f37", "file_sha1": "9db59f6b7007dc4b17529fc63379b9cd61640f37",
@ -126,44 +120,8 @@
"system": "system_prompt" "system": "system_prompt"
} }
}, },
"slimorca": { "slimorca": {
"hf_hub_url": "Open-Orca/SlimOrca", "hf_hub_url": "Open-Orca/SlimOrca"
"formatting": "sharegpt",
"columns": {
"messages": "conversations"
},
"tags": {
"role_tag": "from",
"content_tag": "value",
"user_tag": "human",
"assistant_tag": "gpt",
"system_tag": "system"
}
},
"intel_orca_dpo_pairs_de" : {
"hf_hub_url": "mayflowergmbh/intel_orca_dpo_pairs_de",
"ranking": true
},
"airoboros-3.0_de": {
"hf_hub_url": "mayflowergmbh/airoboros-3.0_de"
},
"booksum_de": {
"hf_hub_url": "mayflowergmbh/booksum_de"
},
"dolphin_de": {
"hf_hub_url": "mayflowergmbh/dolphin_de"
},
"wiki_qa_de": {
"hf_hub_url": "mayflowergmbh/wiki_qa_de"
},
"evol-instruct_de": {
"hf_hub_url": "mayflowergmbh/evol-instruct_de"
},
"openschnabeltier_de": {
"hf_hub_url": "mayflowergmbh/openschnabeltier_de"
},
"dolly-15k_de": {
"hf_hub_url": "mayflowergmbh/dolly-15k_de"
}, },
"mathinstruct": { "mathinstruct": {
"hf_hub_url": "TIGER-Lab/MathInstruct", "hf_hub_url": "TIGER-Lab/MathInstruct",
@ -180,6 +138,13 @@
"response": "target" "response": "target"
} }
}, },
"wikiqa": {
"hf_hub_url": "wiki_qa",
"columns": {
"prompt": "question",
"response": "answer"
}
},
"webqa": { "webqa": {
"hf_hub_url": "suolyer/webqa", "hf_hub_url": "suolyer/webqa",
"ms_hub_url": "AI-ModelScope/webqa", "ms_hub_url": "AI-ModelScope/webqa",
@ -193,7 +158,8 @@
"ms_hub_url": "AI-ModelScope/webnovel_cn" "ms_hub_url": "AI-ModelScope/webnovel_cn"
}, },
"nectar_sft": { "nectar_sft": {
"hf_hub_url": "mlinmg/SFT-Nectar" "hf_hub_url": "mlinmg/SFT-Nectar",
"ms_hub_url": "AI-ModelScope/SFT-Nectar"
}, },
"deepctrl": { "deepctrl": {
"ms_hub_url": "deepctrl/deepctrl-sft-data" "ms_hub_url": "deepctrl/deepctrl-sft-data"
@ -229,9 +195,6 @@
}, },
"formatting": "sharegpt" "formatting": "sharegpt"
}, },
"ultrachat_chat_de": {
"hf_hub_url": "mayflowergmbh/ultra-chat_de"
},
"agent_instruct": { "agent_instruct": {
"hf_hub_url": "THUDM/AgentInstruct", "hf_hub_url": "THUDM/AgentInstruct",
"ms_hub_url": "ZhipuAI/AgentInstruct", "ms_hub_url": "ZhipuAI/AgentInstruct",
@ -253,8 +216,36 @@
}, },
"evol_instruct": { "evol_instruct": {
"hf_hub_url": "WizardLM/WizardLM_evol_instruct_V2_196k", "hf_hub_url": "WizardLM/WizardLM_evol_instruct_V2_196k",
"ms_hub_url": "AI-ModelScope/WizardLM_evol_instruct_V2_196k",
"formatting": "sharegpt" "formatting": "sharegpt"
}, },
"oasst_de": {
"hf_hub_url": "mayflowergmbh/oasst_de"
},
"dolly_15k_de": {
"hf_hub_url": "mayflowergmbh/dolly-15k_de"
},
"alpaca-gpt4_de": {
"hf_hub_url": "mayflowergmbh/alpaca-gpt4_de"
},
"openschnabeltier_de": {
"hf_hub_url": "mayflowergmbh/openschnabeltier_de"
},
"evol_instruct_de": {
"hf_hub_url": "mayflowergmbh/evol-instruct_de"
},
"dolphin_de": {
"hf_hub_url": "mayflowergmbh/dolphin_de"
},
"booksum_de": {
"hf_hub_url": "mayflowergmbh/booksum_de"
},
"airoboros_de": {
"hf_hub_url": "mayflowergmbh/airoboros-3.0_de"
},
"ultrachat_de": {
"hf_hub_url": "mayflowergmbh/ultra-chat_de"
},
"hh_rlhf_en": { "hh_rlhf_en": {
"script_url": "hh_rlhf_en", "script_url": "hh_rlhf_en",
"columns": { "columns": {
@ -298,6 +289,11 @@
}, },
"nectar_rm": { "nectar_rm": {
"hf_hub_url": "mlinmg/RLAIF-Nectar", "hf_hub_url": "mlinmg/RLAIF-Nectar",
"ms_hub_url": "AI-ModelScope/RLAIF-Nectar",
"ranking": true
},
"orca_dpo_de" : {
"hf_hub_url": "mayflowergmbh/intel_orca_dpo_pairs_de",
"ranking": true "ranking": true
}, },
"wiki_demo": { "wiki_demo": {
@ -329,6 +325,7 @@
}, },
"wikipedia_en": { "wikipedia_en": {
"hf_hub_url": "olm/olm-wikipedia-20221220", "hf_hub_url": "olm/olm-wikipedia-20221220",
"ms_hub_url": "AI-ModelScope/olm-wikipedia-20221220",
"columns": { "columns": {
"prompt": "text" "prompt": "text"
} }
@ -342,6 +339,7 @@
}, },
"pile": { "pile": {
"hf_hub_url": "EleutherAI/pile", "hf_hub_url": "EleutherAI/pile",
"ms_hub_url": "AI-ModelScope/pile",
"columns": { "columns": {
"prompt": "text" "prompt": "text"
}, },
@ -349,6 +347,7 @@
}, },
"skypile": { "skypile": {
"hf_hub_url": "Skywork/SkyPile-150B", "hf_hub_url": "Skywork/SkyPile-150B",
"ms_hub_url": "AI-ModelScope/SkyPile-150B",
"columns": { "columns": {
"prompt": "text" "prompt": "text"
} }

View File

@ -49,40 +49,32 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
dataset_attr.function_tag: Role.FUNCTION, dataset_attr.function_tag: Role.FUNCTION,
dataset_attr.system_tag: Role.SYSTEM, dataset_attr.system_tag: Role.SYSTEM,
} }
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
for i, messages in enumerate(examples[dataset_attr.messages]): for i, messages in enumerate(examples[dataset_attr.messages]):
if len(messages) <= 1: if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
if len(messages) == 0:
continue continue
prompt = [] aligned_messages = []
response = []
n_sys = 0
for turn_idx, message in enumerate(messages): for turn_idx, message in enumerate(messages):
if dataset_attr.system_tag and message[dataset_attr.role_tag] == dataset_attr.system_tag: if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
outputs["system"].append(message[dataset_attr.content_tag])
n_sys = 1
continue
if (turn_idx - n_sys) % 2 == 0:
accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag]
else:
accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag]
if message[dataset_attr.role_tag] not in accept_tags:
raise ValueError("Invalid role tag in {}.".format(messages)) raise ValueError("Invalid role tag in {}.".format(messages))
prompt.append( aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
) )
if len(prompt) % 2 == 1: outputs["prompt"].append(aligned_messages[:-1])
# Last message was neither from assistant nor function outputs["response"].append(aligned_messages[-1:])
prompt.pop(-1) outputs["system"].append(system)
last_message = prompt.pop(-1)
response.append(last_message)
outputs["prompt"].append(prompt)
outputs["response"].append(response)
if n_sys == 0:
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
return outputs return outputs
@ -93,8 +85,8 @@ def align_dataset(
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r""" r"""
Aligned dataset: Aligned dataset:
prompt: [{"role": "user", "content": "..."}] prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..." system: "..."
tools: "..." tools: "..."
""" """

View File

@ -30,6 +30,7 @@ def load_single_dataset(
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments", data_args: "DataArguments",
): ):
logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]: if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
data_path = dataset_attr.dataset_name data_path = dataset_attr.dataset_name
@ -60,7 +61,7 @@ def load_single_dataset(
if data_path is None: if data_path is None:
raise ValueError("File extension must be txt, csv, json or jsonl.") raise ValueError("File extension must be txt, csv, json or jsonl.")
checksum(data_files, dataset_attr.dataset_sha1) checksum(data_files, dataset_attr.file_sha1)
else: else:
raise NotImplementedError raise NotImplementedError
@ -157,7 +158,7 @@ def get_dataset(
with training_args.main_process_first(desc="load dataset"): with training_args.main_process_first(desc="load dataset"):
all_datasets = [] all_datasets = []
for dataset_attr in get_dataset_list(data_args): # TODO: add split for dataset_attr in get_dataset_list(data_args):
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args)) all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
dataset = merge_dataset(all_datasets, data_args, training_args) dataset = merge_dataset(all_datasets, data_args, training_args)
@ -185,6 +186,6 @@ def get_dataset(
try: try:
print_function(next(iter(dataset))) print_function(next(iter(dataset)))
except StopIteration: except StopIteration:
raise RuntimeError("Empty dataset!") raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
return dataset return dataset

View File

@ -1,7 +1,7 @@
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Literal, Optional from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
from ..extras.constants import DATA_CONFIG from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope from ..extras.misc import use_modelscope
@ -13,38 +13,44 @@ if TYPE_CHECKING:
@dataclass @dataclass
class DatasetAttr: class DatasetAttr:
r"""
Dataset attributes.
"""
""" basic configs """
load_from: Literal["hf_hub", "ms_hub", "script", "file"] load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: Optional[str] = None dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None """ extra configs """
file_sha1: Optional[str] = None
subset: Optional[str] = None subset: Optional[str] = None
folder: Optional[str] = None folder: Optional[str] = None
ranking: Optional[bool] = False ranking: Optional[bool] = False
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca" formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
""" columns """
system: Optional[str] = None system: Optional[str] = None
""" columns for the alpaca format """
prompt: Optional[str] = "instruction" prompt: Optional[str] = "instruction"
query: Optional[str] = "input" query: Optional[str] = "input"
response: Optional[str] = "output" response: Optional[str] = "output"
history: Optional[str] = None history: Optional[str] = None
""" columns for the sharegpt format """
messages: Optional[str] = "conversations" messages: Optional[str] = "conversations"
tools: Optional[str] = None tools: Optional[str] = None
""" tags for the sharegpt format """
role_tag: Optional[str] = "from" role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value" content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human" user_tag: Optional[str] = "human"
assistant_tag: Optional[str] = "gpt" assistant_tag: Optional[str] = "gpt"
observation_tag: Optional[str] = "observation" observation_tag: Optional[str] = "observation"
function_tag: Optional[str] = "function_call" function_tag: Optional[str] = "function_call"
system_tag: Optional[str] = None system_tag: Optional[str] = "system"
assert system_tag is None or system is None, f"Can not provide both system message (system_tag={system_tag}) and system column(system={system})"
def __repr__(self) -> str: def __repr__(self) -> str:
return self.dataset_name return self.dataset_name
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default))
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else [] dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
@ -77,30 +83,36 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
elif "script_url" in dataset_info[name]: elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else: else:
dataset_attr = DatasetAttr( dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
"file",
dataset_name=dataset_info[name]["file_name"],
dataset_sha1=dataset_info[name].get("file_sha1", None),
)
dataset_attr.subset = dataset_info[name].get("subset", None) dataset_attr.set_attr("file_sha1", dataset_info[name])
dataset_attr.folder = dataset_info[name].get("folder", None) dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.ranking = dataset_info[name].get("ranking", False) dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca") dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
if "columns" in dataset_info[name]: if "columns" in dataset_info[name]:
column_names = ["system"]
if dataset_attr.formatting == "alpaca": if dataset_attr.formatting == "alpaca":
column_names = ["prompt", "query", "response", "history"] column_names.extend(["prompt", "query", "response", "history"])
else: else:
column_names = ["messages", "tools"] column_names.extend(["messages", "tools"])
column_names += ["system"]
for column_name in column_names: for column_name in column_names:
setattr(dataset_attr, column_name, dataset_info[name]["columns"].get(column_name, None)) dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]: if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]: tag_names = (
setattr(dataset_attr, tag, dataset_info[name]["tags"].get(tag, None)) "role_tag",
"content_tag",
"user_tag",
"assistant_tag",
"observation_tag",
"function_tag",
"system_tag",
)
for tag in tag_names:
dataset_attr.set_attr(tag, dataset_info[name]["tags"])
dataset_list.append(dataset_attr) dataset_list.append(dataset_attr)

View File

@ -247,7 +247,7 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
logger.info("Replace eos token: {}".format(tokenizer.eos_token)) logger.info("Replace eos token: {}".format(tokenizer.eos_token))
if is_oov: if is_oov:
logger.warning("New token is added, you must enable `resize_vocab` to activate it.") logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
def get_template_and_fix_tokenizer( def get_template_and_fix_tokenizer(

View File

@ -19,9 +19,9 @@ logger = get_logger(__name__)
class Role(str, Enum): class Role(str, Enum):
USER = "user" USER = "user"
ASSISTANT = "assistant" ASSISTANT = "assistant"
SYSTEM = "system"
OBSERVATION = "observation" OBSERVATION = "observation"
FUNCTION = "function" FUNCTION = "function"
SYSTEM = "system"
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:

View File

@ -67,7 +67,7 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
raise ValueError("Quantized model only accepts a single adapter. Merge them first.") raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora": if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Only LoRA method has adapters.") raise ValueError("Adapter is only valid for the LoRA method.")
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
@ -125,6 +125,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
_verify_model_args(model_args, finetuning_args) _verify_model_args(model_args, finetuning_args)
if (
training_args.do_train
and finetuning_args.finetuning_type == "lora"
and model_args.resize_vocab
and finetuning_args.additional_target is None
):
logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm): if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.") logger.warning("We recommend enable `upcast_layernorm` in quantized training.")