modify some style

BUAADreamer 2024-04-25 21:58:18 +08:00
parent 2cab2d42fb
commit 2d4ded535f
6 changed files with 26 additions and 158 deletions

View File

@@ -82,10 +82,7 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
            raise ValueError("Invalid role tag in {}.".format(messages))

        aligned_messages.append(
-            {
-                "role": tag_mapping[message[dataset_attr.role_tag]],
-                "content": message[dataset_attr.content_tag],
-            }
+            {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
        )

    outputs["prompt"].append(aligned_messages[:-1])
@@ -126,10 +123,7 @@ def convert_llava(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -
            raise ValueError("Invalid role tag in {}.".format(messages))

        aligned_messages.append(
-            {
-                "role": tag_mapping[message[dataset_attr.role_tag]],
-                "content": message[dataset_attr.content_tag],
-            }
+            {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
        )

    outputs["prompt"].append(aligned_messages[:-1])
@@ -143,9 +137,7 @@ def convert_llava(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -
def align_dataset(
-    dataset: Union["Dataset", "IterableDataset"],
-    dataset_attr: "DatasetAttr",
-    data_args: "DataArguments",
+    dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Aligned dataset:
@@ -165,16 +157,10 @@ def align_dataset(
    features = Features.from_dict(
        {
            "prompt": [
-                {
-                    "role": {"dtype": "string", "_type": "Value"},
-                    "content": {"dtype": "string", "_type": "Value"},
-                }
+                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
            ],
            "response": [
-                {
-                    "role": {"dtype": "string", "_type": "Value"},
-                    "content": {"dtype": "string", "_type": "Value"},
-                }
+                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
            ],
            "system": {"dtype": "string", "_type": "Value"},
            "tools": {"dtype": "string", "_type": "Value"},

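Note: the hunks above only collapse formatting, but they touch the core of the aligner, which maps every raw record into the uniform prompt/response schema declared via `Features.from_dict`. A minimal sketch of one aligned record (the example content is hypothetical, not from the repository):

# All turns except the last go into "prompt"; the final turn goes into "response",
# matching outputs["prompt"].append(aligned_messages[:-1]) above.
aligned_example = {
    "prompt": [
        {"role": "user", "content": "What is 2 + 2?"},
    ],
    "response": [
        {"role": "assistant", "content": "4"},
    ],
    "system": "",  # optional system message, stored as a plain string
    "tools": "",   # optional tool description, also a plain string
}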
View File

@@ -14,14 +14,11 @@ if TYPE_CHECKING:
    from ..hparams import DataArguments
    from .template import Template

logger = get_logger(__name__)

def preprocess_pretrain_dataset(
-    examples: Dict[str, List[Any]],
-    tokenizer: "PreTrainedTokenizer",
-    data_args: "DataArguments",
+    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
    text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
@@ -56,11 +53,7 @@ def preprocess_supervised_dataset(
) -> Dict[str, List[List[int]]]:
    # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
    # for multiturn examples, we only mask the prompt part in each prompt-response pair.
-    model_inputs = {
-        "input_ids": [],
-        "attention_mask": [],
-        "labels": [],
-    }
+    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
@@ -154,12 +147,7 @@ def preprocess_multimodal_supervised_dataset(
    # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
    # for multiturn examples, we only mask the prompt part in each prompt-response pair.
    tokenizer = processor.tokenizer
-    model_inputs = {
-        "input_ids": [],
-        "attention_mask": [],
-        "labels": [],
-        "pixel_values": [],
-    }
+    model_inputs = {"input_ids": [], "attention_mask": [], "labels": [], "pixel_values": []}

    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
@@ -284,10 +272,7 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
    print("label_ids:\n{}".format(example["labels"]))
    print(
        "labels:\n{}".format(
-            tokenizer.decode(
-                list(filter(lambda x: x != IGNORE_INDEX, example["labels"])),
-                skip_special_tokens=False,
-            )
+            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
        )
    )
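Note: the decode call above filters out IGNORE_INDEX before detokenizing, because the `<ignore> ... <ignore> Y <eos>` labeling masks every prompt position so the loss only covers the response. A toy illustration of that invariant (the token ids are made up):

IGNORE_INDEX = -100  # positions carrying this label are skipped by the loss

prompt_ids = [101, 7592]    # hypothetical tokens for the prompt turn
response_ids = [2129, 102]  # hypothetical tokens for the response + eos

input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
# Keeping only the supervised part, as print_supervised_dataset_example does:
supervised = [x for x in labels if x != IGNORE_INDEX]
assert supervised == response_ids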
@@ -320,33 +305,21 @@ def get_preprocess_and_print_func(
    elif stage == "sft" and not training_args.predict_with_generate:
        if data_args.packing:
            preprocess_func = partial(
-                preprocess_packed_supervised_dataset,
-                tokenizer=tokenizer,
-                template=template,
-                data_args=data_args,
+                preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
            )
        elif processor is not None:
            preprocess_func = partial(
-                preprocess_multimodal_supervised_dataset,
-                processor=processor,
-                template=template,
-                data_args=data_args,
+                preprocess_multimodal_supervised_dataset, processor=processor, template=template, data_args=data_args
            )
        else:
            preprocess_func = partial(
-                preprocess_supervised_dataset,
-                tokenizer=tokenizer,
-                template=template,
-                data_args=data_args,
+                preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
            )
        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    elif stage == "rm":
        preprocess_func = partial(
-            preprocess_pairwise_dataset,
-            tokenizer=tokenizer,
-            template=template,
-            data_args=data_args,
+            preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
        )
        print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
    else:
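Note: every branch above follows the same pattern: `functools.partial` freezes the tokenizer/template/data_args keyword arguments so the resulting callable takes only the batched examples, which is the one-argument signature `datasets.Dataset.map` expects. A self-contained sketch of that pattern (the function and field names here are illustrative, not the repository's):

from functools import partial
from typing import Any, Dict, List

def preprocess(examples: Dict[str, List[Any]], suffix: str) -> Dict[str, List[Any]]:
    # Stand-in for a real tokenization step: append a suffix to every prompt.
    return {"text": [x + suffix for x in examples["text"]]}

# partial() binds suffix in advance, leaving a function of the batch alone.
preprocess_func = partial(preprocess, suffix=" <eos>")
print(preprocess_func({"text": ["hello", "world"]}))
# {'text': ['hello <eos>', 'world <eos>']}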

View File

@@ -11,7 +11,6 @@ if TYPE_CHECKING:
from .formatter import SLOTS, Formatter
logger = get_logger(__name__)
@@ -368,8 +367,7 @@ def get_template_and_fix_tokenizer(
    if stop_words:
        num_added_tokens = tokenizer.add_special_tokens(
-            dict(additional_special_tokens=stop_words),
-            replace_additional_special_tokens=False,
+            dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
        )
        logger.info("Add {} to stop words.".format(",".join(stop_words)))
        if num_added_tokens > 0:
@@ -393,7 +391,6 @@ _register_template(
),
)
_register_template(
name="aquila",
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
@@ -406,36 +403,26 @@ _register_template(
    efficient_eos=True,
)

_register_template(
    name="atom",
    format_user=StringFormatter(
-        slots=[
-            {"bos_token"},
-            "Human: {{content}}\n",
-            {"eos_token"},
-            {"bos_token"},
-            "Assistant:",
-        ]
+        slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
)
_register_template(
name="baichuan",
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True,
)
_register_template(
name="baichuan2",
format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
efficient_eos=True,
)
_register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
@@ -444,13 +431,11 @@ _register_template(
force_system=True,
)
_register_template(
name="bluelm",
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)
_register_template(
name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
@@ -462,7 +447,6 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
@@ -472,7 +456,6 @@ _register_template(
force_system=True,
)
_register_template(
name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
@@ -480,40 +463,23 @@ _register_template(
    format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
    format_observation=StringFormatter(
-        slots=[
-            {"token": "<|observation|>"},
-            "\n",
-            "{{content}}",
-            {"token": "<|assistant|>"},
-        ]
+        slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
    ),
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,
    force_system=True,
)

_register_template(
    name="chatglm3_system",
    format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
    format_system=StringFormatter(
-        slots=[
-            {"token": "[gMASK]"},
-            {"token": "sop"},
-            {"token": "<|system|>"},
-            "\n",
-            "{{content}}",
-        ]
+        slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
    ),
    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
    format_observation=StringFormatter(
-        slots=[
-            {"token": "<|observation|>"},
-            "\n",
-            "{{content}}",
-            {"token": "<|assistant|>"},
-        ]
+        slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
    ),
default_system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
@@ -523,7 +489,6 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -534,7 +499,6 @@ _register_template(
replace_eos=True,
)
_register_template(
name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -546,14 +510,12 @@ _register_template(
replace_eos=True,
)
_register_template(
name="codegeex2",
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="cohere",
format_user=StringFormatter(
@@ -568,7 +530,6 @@ _register_template(
force_system=True,
)
_register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
@ -576,7 +537,6 @@ _register_template(
force_system=True,
)
_register_template(
name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -602,7 +562,6 @@ _register_template(
replace_eos=True,
)
_register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
@@ -610,7 +569,6 @@ _register_template(
force_system=True,
)
_register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
@@ -626,7 +584,6 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="default",
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
@@ -634,14 +591,12 @@ _register_template(
format_separator=EmptyFormatter(slots=["\n"]),
)
_register_template(
name="empty",
format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
)
_register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -649,14 +604,12 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="fewshot",
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
)
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
@@ -669,7 +622,6 @@ _register_template(
force_system=True,
)
_register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
@@ -678,7 +630,6 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -695,7 +646,6 @@ _register_template(
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
)
_register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@@ -712,7 +662,6 @@ _register_template(
),
)
_register_template(
name="llama2_zh",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@@ -720,7 +669,6 @@ _register_template(
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
)
_register_template(
name="llama3",
format_user=StringFormatter(
@@ -732,10 +680,7 @@ _register_template(
        ]
    ),
    format_system=StringFormatter(
-        slots=[
-            {"bos_token"},
-            "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>",
-        ]
+        slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
    ),
    format_observation=StringFormatter(
        slots=[
@@ -750,7 +695,6 @@ _register_template(
replace_eos=True,
)
_register_template(
name="mistral",
format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]),
@@ -758,7 +702,6 @@ _register_template(
force_system=True,
)
_register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
@@ -767,22 +710,14 @@ _register_template(
    force_system=True,
)

_register_template(
    name="openchat",
-    format_user=StringFormatter(
-        slots=[
-            "GPT4 Correct User: {{content}}",
-            {"eos_token"},
-            "GPT4 Correct Assistant:",
-        ]
-    ),
+    format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
_register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
@@ -790,7 +725,6 @@ _register_template(
force_system=True,
)
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
@@ -802,7 +736,6 @@ _register_template(
replace_eos=True,
)
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -814,7 +747,6 @@ _register_template(
replace_eos=True,
)
_register_template(
name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
@@ -822,7 +754,6 @@ _register_template(
efficient_eos=True,
)
_register_template(
name="starchat",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
@@ -833,7 +764,6 @@ _register_template(
force_system=True,
)
_register_template(
name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
@@ -843,7 +773,6 @@ _register_template(
),
)
_register_template(
name="xuanyuan",
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
@@ -854,13 +783,11 @@ _register_template(
),
)
_register_template(
name="xverse",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
)
_register_template(
name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
@@ -880,7 +807,6 @@ _register_template(
stop_words=["<|End|>"],
)
_register_template(
name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -889,7 +815,6 @@ _register_template(
replace_eos=True,
)
_register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
@@ -898,7 +823,6 @@ _register_template(
replace_eos=True,
)
_register_template(
name="zephyr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
@@ -907,7 +831,6 @@ _register_template(
default_system="You are a friendly chatbot who always responds in the style of a pirate",
)
_register_template(
name="ziya",
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
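Note: every registration above mixes three kinds of slots: literal strings, "{{content}}" placeholders, and set-wrapped special tokens such as {"eos_token"}. A rough sketch of how a user turn could be rendered from such a slot list (simplified; the real StringFormatter also handles {"token": ...} dicts and {{idx}} placeholders):

from typing import List, Set, Union

def render_slots(slots: List[Union[str, Set[str]]], content: str,
                 bos_token: str = "<s>", eos_token: str = "</s>") -> str:
    out = []
    for slot in slots:
        if isinstance(slot, set):  # {"bos_token"} / {"eos_token"} markers
            out.append(bos_token if "bos_token" in slot else eos_token)
        else:  # literal template string with a content placeholder
            out.append(slot.replace("{{content}}", content))
    return "".join(out)

# The "atom" template's user slots from one of the hunks above:
slots = [{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
print(render_slots(slots, "hi"))  # <s>Human: hi\n</s><s>Assistant: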

View File

@@ -260,7 +260,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
-    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo", "sft_mm"] = field(
+    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)
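Note: dropping "sft_mm" here works because multimodal SFT is now routed through the regular "sft" stage with an optional processor (see run_sft below). The Literal annotation is what constrains the CLI flag: HfArgumentParser turns it into argparse choices, so an unknown stage fails at parse time. A standalone sketch of that pattern (this toy dataclass is not the repository's FinetuningArguments):

from dataclasses import dataclass, field
from typing import Literal

from transformers import HfArgumentParser

@dataclass
class StageArguments:
    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field(
        default="sft",
        metadata={"help": "Which stage will be performed in training."},
    )

parser = HfArgumentParser(StageArguments)
(args,) = parser.parse_args_into_dataclasses(["--stage", "dpo"])
print(args.stage)  # dpo; "--stage sft_mm" would now be rejected at parse time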

View File

@@ -41,9 +41,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
    }

-def load_tokenizer(
-    model_args: "ModelArguments",
-) -> Dict[str, Union["PreTrainedTokenizer", "AutoProcessor"]]:
+def load_tokenizer(model_args: "ModelArguments") -> Dict[str, Union["PreTrainedTokenizer", "AutoProcessor"]]:
    r"""
    Loads pretrained tokenizer.
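Note: the signature change is cosmetic, but the return type is worth spelling out: the function returns a dict bundling the tokenizer with a processor entry for multimodal models. A hedged usage sketch mirroring how run_sft unpacks it below (the assumption that "processor" is None for text-only models is mine, not stated in this diff):

tokenizer_modules = load_tokenizer(model_args)
tokenizer = tokenizer_modules["tokenizer"]  # always present
processor = tokenizer_modules["processor"]  # e.g. a multimodal AutoProcessor, or None (assumption)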

View File

@@ -17,12 +17,7 @@ from .trainer import CustomSeq2SeqTrainer
if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

-    from ...hparams import (
-        DataArguments,
-        FinetuningArguments,
-        GeneratingArguments,
-        ModelArguments,
-    )
+    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

def run_sft(
@@ -36,14 +31,7 @@ def run_sft(
    tokenizer_modules = load_tokenizer(model_args)
    tokenizer = tokenizer_modules["tokenizer"]
    processor = tokenizer_modules["processor"]
-    dataset = get_dataset(
-        tokenizer,
-        model_args,
-        data_args,
-        training_args,
-        stage="sft",
-        processor=processor,
-    )
+    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft", processor=processor)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    if training_args.predict_with_generate:
@@ -54,7 +42,7 @@ def run_sft(
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
-        pad_to_multiple_of=(8 if tokenizer.padding_side == "right" else None),  # for shift short attention
+        pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None,  # for shift short attention
        label_pad_token_id=(IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id),
    )
@@ -72,7 +60,7 @@ def run_sft(
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
-        compute_metrics=(ComputeMetrics(tokenizer) if training_args.predict_with_generate else None),
+        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
        **split_dataset(dataset, data_args, training_args),
    )
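Note: the last two hunks only drop redundant parentheses; behavior is unchanged. The collator still pads labels with IGNORE_INDEX and, when the tokenizer right-pads, rounds the batch length up to a multiple of 8. A small standalone check of that behavior (the gpt2 checkpoint is just a convenient stand-in):

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

IGNORE_INDEX = -100  # sentinel used to mask prompt tokens in the loss

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None,
    label_pad_token_id=IGNORE_INDEX,
)
batch = collator([
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [4, 5]},
])
print(batch["input_ids"].shape)  # torch.Size([2, 8]): padded up to a multiple of 8
print(batch["labels"][1])        # response ids followed by -100 padding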