From 2d4ded535faa44b460f88028d49e4b8c8b430db5 Mon Sep 17 00:00:00 2001
From: BUAADreamer <1428195643@qq.com>
Date: Thu, 25 Apr 2024 21:58:18 +0800
Subject: [PATCH] modify some style

---
 src/llmtuner/data/aligner.py            | 24 ++-----
 src/llmtuner/data/preprocess.py         | 43 +++---------
 src/llmtuner/data/template.py           | 91 ++-----------------------
 src/llmtuner/hparams/finetuning_args.py |  2 +-
 src/llmtuner/model/loader.py            |  4 +-
 src/llmtuner/train/sft/workflow.py      | 20 ++----
 6 files changed, 26 insertions(+), 158 deletions(-)

diff --git a/src/llmtuner/data/aligner.py b/src/llmtuner/data/aligner.py
index 9d440aff..17b9fc6d 100644
--- a/src/llmtuner/data/aligner.py
+++ b/src/llmtuner/data/aligner.py
@@ -82,10 +82,7 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
             raise ValueError("Invalid role tag in {}.".format(messages))
 
         aligned_messages.append(
-            {
-                "role": tag_mapping[message[dataset_attr.role_tag]],
-                "content": message[dataset_attr.content_tag],
-            }
+            {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
         )
 
     outputs["prompt"].append(aligned_messages[:-1])
@@ -126,10 +123,7 @@ def convert_llava(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -
             raise ValueError("Invalid role tag in {}.".format(messages))
 
         aligned_messages.append(
-            {
-                "role": tag_mapping[message[dataset_attr.role_tag]],
-                "content": message[dataset_attr.content_tag],
-            }
+            {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
         )
 
     outputs["prompt"].append(aligned_messages[:-1])
@@ -143,9 +137,7 @@ def convert_llava(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -
 
 
 def align_dataset(
-    dataset: Union["Dataset", "IterableDataset"],
-    dataset_attr: "DatasetAttr",
-    data_args: "DataArguments",
+    dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
 ) -> Union["Dataset", "IterableDataset"]:
     r"""
     Aligned dataset:
@@ -165,16 +157,10 @@ def align_dataset(
     features = Features.from_dict(
         {
             "prompt": [
-                {
-                    "role": {"dtype": "string", "_type": "Value"},
-                    "content": {"dtype": "string", "_type": "Value"},
-                }
+                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
             ],
             "response": [
-                {
-                    "role": {"dtype": "string", "_type": "Value"},
-                    "content": {"dtype": "string", "_type": "Value"},
-                }
+                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
             ],
             "system": {"dtype": "string", "_type": "Value"},
             "tools": {"dtype": "string", "_type": "Value"},
diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py
index 1c8c64a6..51af8060 100644
--- a/src/llmtuner/data/preprocess.py
+++ b/src/llmtuner/data/preprocess.py
@@ -14,14 +14,11 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments
     from .template import Template
 
-
 logger = get_logger(__name__)
 
 
 def preprocess_pretrain_dataset(
-    examples: Dict[str, List[Any]],
-    tokenizer: "PreTrainedTokenizer",
-    data_args: "DataArguments",
+    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
     # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
     text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
@@ -56,11 +53,7 @@ def preprocess_supervised_dataset(
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
     # for multiturn examples, we only mask the prompt part in each prompt-response pair.
-    model_inputs = {
-        "input_ids": [],
-        "attention_mask": [],
-        "labels": [],
-    }
+    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
 
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
@@ -154,12 +147,7 @@ def preprocess_multimodal_supervised_dataset(
     # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
     # for multiturn examples, we only mask the prompt part in each prompt-response pair.
     tokenizer = processor.tokenizer
-    model_inputs = {
-        "input_ids": [],
-        "attention_mask": [],
-        "labels": [],
-        "pixel_values": [],
-    }
+    model_inputs = {"input_ids": [], "attention_mask": [], "labels": [], "pixel_values": []}
 
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
@@ -284,10 +272,7 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
     print("label_ids:\n{}".format(example["labels"]))
     print(
         "labels:\n{}".format(
-            tokenizer.decode(
-                list(filter(lambda x: x != IGNORE_INDEX, example["labels"])),
-                skip_special_tokens=False,
-            )
+            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
         )
     )
 
@@ -320,33 +305,21 @@ def get_preprocess_and_print_func(
 
     elif stage == "sft" and not training_args.predict_with_generate:
         if data_args.packing:
             preprocess_func = partial(
-                preprocess_packed_supervised_dataset,
-                tokenizer=tokenizer,
-                template=template,
-                data_args=data_args,
+                preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
             )
         elif processor is not None:
             preprocess_func = partial(
-                preprocess_multimodal_supervised_dataset,
-                processor=processor,
-                template=template,
-                data_args=data_args,
+                preprocess_multimodal_supervised_dataset, processor=processor, template=template, data_args=data_args
             )
         else:
             preprocess_func = partial(
-                preprocess_supervised_dataset,
-                tokenizer=tokenizer,
-                template=template,
-                data_args=data_args,
+                preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
             )
         print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
     elif stage == "rm":
         preprocess_func = partial(
-            preprocess_pairwise_dataset,
-            tokenizer=tokenizer,
-            template=template,
-            data_args=data_args,
+            preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
         )
         print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
     else:
diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index cf21e932..f798ba5a 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -11,7 +11,6 @@ if TYPE_CHECKING:
 
     from .formatter import SLOTS, Formatter
 
-
 logger = get_logger(__name__)
 
 
@@ -368,8 +367,7 @@ def get_template_and_fix_tokenizer(
 
     if stop_words:
         num_added_tokens = tokenizer.add_special_tokens(
-            dict(additional_special_tokens=stop_words),
-            replace_additional_special_tokens=False,
+            dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
         )
         logger.info("Add {} to stop words.".format(",".join(stop_words)))
         if num_added_tokens > 0:
@@ -393,7 +391,6 @@ _register_template(
     ),
 )
 
-
 _register_template(
     name="aquila",
     format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
@@ -406,36 +403,26 @@ _register_template(
     efficient_eos=True,
 )
 
-
 _register_template(
     name="atom",
     format_user=StringFormatter(
-        slots=[
-            {"bos_token"},
-            "Human: {{content}}\n",
-            {"eos_token"},
-            {"bos_token"},
-            "Assistant:",
-        ]
+        slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
     ),
     format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
 )
 
-
 _register_template(
     name="baichuan",
     format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
     efficient_eos=True,
 )
 
-
 _register_template(
     name="baichuan2",
     format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
     efficient_eos=True,
 )
 
-
 _register_template(
     name="belle",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
@@ -444,13 +431,11 @@ _register_template(
     force_system=True,
 )
 
-
 _register_template(
     name="bluelm",
     format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
 )
 
-
 _register_template(
     name="breeze",
     format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
@@ -462,7 +447,6 @@ _register_template(
     efficient_eos=True,
 )
 
-
 _register_template(
     name="chatglm2",
     format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
@@ -472,7 +456,6 @@ _register_template(
     force_system=True,
 )
 
-
 _register_template(
     name="chatglm3",
     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
@@ -480,40 +463,23 @@ _register_template(
     format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
     format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
     format_observation=StringFormatter(
-        slots=[
-            {"token": "<|observation|>"},
-            "\n",
-            "{{content}}",
-            {"token": "<|assistant|>"},
-        ]
+        slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
     ),
     stop_words=["<|user|>", "<|observation|>"],
     efficient_eos=True,
     force_system=True,
 )
 
-
 _register_template(
     name="chatglm3_system",
     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
     format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
     format_system=StringFormatter(
-        slots=[
-            {"token": "[gMASK]"},
-            {"token": "sop"},
-            {"token": "<|system|>"},
-            "\n",
-            "{{content}}",
-        ]
+        slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
     ),
     format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
     format_observation=StringFormatter(
-        slots=[
-            {"token": "<|observation|>"},
-            "\n",
-            "{{content}}",
-            {"token": "<|assistant|>"},
-        ]
+        slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
     ),
     default_system=(
         "You are ChatGLM3, a large language model trained by Zhipu.AI. "
@@ -523,7 +489,6 @@ _register_template(
     efficient_eos=True,
 )
 
-
 _register_template(
     name="chatml",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -534,7 +499,6 @@ _register_template(
     replace_eos=True,
 )
 
-
 _register_template(
     name="chatml_de",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -546,14 +510,12 @@ _register_template(
     replace_eos=True,
 )
 
-
 _register_template(
     name="codegeex2",
     format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
     force_system=True,
 )
 
-
 _register_template(
     name="cohere",
     format_user=StringFormatter(
@@ -568,7 +530,6 @@ _register_template(
     force_system=True,
 )
 
-
 _register_template(
     name="cpm",
     format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
@@ -576,7 +537,6 @@ _register_template(
     force_system=True,
 )
 
-
 _register_template(
     name="dbrx",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -602,7 +562,6 @@ _register_template(
     replace_eos=True,
 )
 
-
 _register_template(
     name="deepseek",
     format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
@@ -610,7 +569,6 @@ _register_template(
     force_system=True,
 )
 
-
 _register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
@@ -626,7 +584,6 @@ _register_template(
     efficient_eos=True,
 )
 
-
 _register_template(
     name="default",
     format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
@@ -634,14 +591,12 @@ _register_template(
     format_separator=EmptyFormatter(slots=["\n"]),
 )
 
-
 _register_template(
     name="empty",
     format_user=StringFormatter(slots=["{{content}}"]),
     format_assistant=StringFormatter(slots=["{{content}}"]),
 )
 
-
 _register_template(
     name="falcon",
     format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -649,14 +604,12 @@ _register_template(
     efficient_eos=True,
 )
 
-
 _register_template(
     name="fewshot",
     format_separator=EmptyFormatter(slots=["\n\n"]),
     efficient_eos=True,
 )
 
-
 _register_template(
     name="gemma",
     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
@@ -669,7 +622,6 @@ _register_template(
     force_system=True,
 )
 
-
 _register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
@@ -678,7 +630,6 @@ _register_template(
     efficient_eos=True,
 )
 
-
 _register_template(
     name="intern2",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -695,7 +646,6 @@ _register_template(
     efficient_eos=True,  # internlm2 tokenizer cannot set eos_token_id
 )
 
-
 _register_template(
     name="llama2",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@@ -712,7 +662,6 @@ _register_template(
     ),
 )
 
-
 _register_template(
     name="llama2_zh",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@@ -720,7 +669,6 @@ _register_template(
     default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
 )
 
-
 _register_template(
     name="llama3",
     format_user=StringFormatter(
@@ -732,10 +680,7 @@ _register_template(
         ]
     ),
     format_system=StringFormatter(
-        slots=[
-            {"bos_token"},
-            "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>",
-        ]
+        slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
     ),
     format_observation=StringFormatter(
         slots=[
@@ -750,7 +695,6 @@ _register_template(
     replace_eos=True,
 )
 
-
 _register_template(
     name="mistral",
     format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]),
@@ -758,7 +702,6 @@ _register_template(
     force_system=True,
 )
 
-
 _register_template(
     name="olmo",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
@@ -767,22 +710,14 @@ _register_template(
     force_system=True,
 )
 
-
 _register_template(
     name="openchat",
-    format_user=StringFormatter(
-        slots=[
-            "GPT4 Correct User: {{content}}",
-            {"eos_token"},
-            "GPT4 Correct Assistant:",
-        ]
-    ),
+    format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
     format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
     force_system=True,
 )
 
-
 _register_template(
     name="orion",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
@@ -790,7 +725,6 @@ _register_template(
     force_system=True,
 )
 
-
 _register_template(
     name="phi",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
@@ -802,7 +736,6 @@ _register_template(
     replace_eos=True,
 )
 
-
 _register_template(
     name="qwen",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -814,7 +747,6 @@ _register_template(
     replace_eos=True,
 )
 
-
 _register_template(
     name="solar",
     format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
@@ -822,7 +754,6 @@ _register_template(
     efficient_eos=True,
 )
 
-
 _register_template(
     name="starchat",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
@@ -833,7 +764,6 @@ _register_template(
     force_system=True,
 )
 
-
 _register_template(
     name="vicuna",
     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
@@ -843,7 +773,6 @@ _register_template(
     ),
 )
 
-
 _register_template(
     name="xuanyuan",
     format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
@@ -854,13 +783,11 @@ _register_template(
     ),
 )
 
-
 _register_template(
     name="xverse",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
 )
 
-
 _register_template(
     name="yayi",
     format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
@@ -880,7 +807,6 @@ _register_template(
     stop_words=["<|End|>"],
 )
 
-
 _register_template(
     name="yi",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -889,7 +815,6 @@ _register_template(
     replace_eos=True,
 )
 
-
 _register_template(
     name="yuan",
     format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
@@ -898,7 +823,6 @@ _register_template(
     replace_eos=True,
 )
 
-
 _register_template(
     name="zephyr",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
@@ -907,7 +831,6 @@ _register_template(
     default_system="You are a friendly chatbot who always responds in the style of a pirate",
 )
 
-
 _register_template(
     name="ziya",
     format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index cb525699..f4f71bc5 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -260,7 +260,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
     )
-    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo", "sft_mm"] = field(
+    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field(
         default="sft",
         metadata={"help": "Which stage will be performed in training."},
     )
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index dd7eb44c..5b5c0a4d 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -41,9 +41,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     }
 
 
-def load_tokenizer(
-    model_args: "ModelArguments",
-) -> Dict[str, Union["PreTrainedTokenizer", "AutoProcessor"]]:
+def load_tokenizer(model_args: "ModelArguments") -> Dict[str, Union["PreTrainedTokenizer", "AutoProcessor"]]:
     r"""
     Loads pretrained tokenizer.
 
diff --git a/src/llmtuner/train/sft/workflow.py b/src/llmtuner/train/sft/workflow.py
index 50833a99..205142e5 100644
--- a/src/llmtuner/train/sft/workflow.py
+++ b/src/llmtuner/train/sft/workflow.py
@@ -17,12 +17,7 @@ from .trainer import CustomSeq2SeqTrainer
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
 
-    from ...hparams import (
-        DataArguments,
-        FinetuningArguments,
-        GeneratingArguments,
-        ModelArguments,
-    )
+    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 
 
 def run_sft(
@@ -36,14 +31,7 @@ def run_sft(
     tokenizer_modules = load_tokenizer(model_args)
     tokenizer = tokenizer_modules["tokenizer"]
     processor = tokenizer_modules["processor"]
-    dataset = get_dataset(
-        tokenizer,
-        model_args,
-        data_args,
-        training_args,
-        stage="sft",
-        processor=processor,
-    )
+    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft", processor=processor)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
     if training_args.predict_with_generate:
@@ -54,7 +42,7 @@ def run_sft(
 
     data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
-        pad_to_multiple_of=(8 if tokenizer.padding_side == "right" else None),  # for shift short attention
+        pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None,  # for shift short attention
         label_pad_token_id=(IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id),
     )
 
@@ -72,7 +60,7 @@ def run_sft(
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        compute_metrics=(ComputeMetrics(tokenizer) if training_args.predict_with_generate else None),
+        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
         **split_dataset(dataset, data_args, training_args),
     )