diff --git a/data/dataset_info.json b/data/dataset_info.json index 6c7088f6..f9adf108 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -418,6 +418,17 @@ "hf_hub_url": "HuggingFaceH4/llava-instruct-mix-vsft" }, "mllm_instruct_example": { - "hf_hub_url": "data/mllm_example_dataset" + "file_name": "llava_instruct_example.json", + "formatting": "llava", + "columns": { + "messages": "messages", + "images": "images" + }, + "tags": { + "role_tag": "role", + "content_tag": "content", + "user_tag": "user", + "assistant_tag": "assistant" + } } -} \ No newline at end of file +} diff --git a/data/llava_instruct_example.json b/data/llava_instruct_example.json index 908b0695..b5c97387 100644 --- a/data/llava_instruct_example.json +++ b/data/llava_instruct_example.json @@ -2,7 +2,7 @@ { "messages": [ { - "content": "Who are they?", + "content": "Who are they?", "role": "user" }, { @@ -18,12 +18,14 @@ "role": "assistant" } ], - "image": "1.jpg" + "images": [ + "data/images/1.jpg" + ] }, { "messages": [ { - "content": "Who is he?", + "content": "Who is he?", "role": "user" }, { @@ -39,12 +41,14 @@ "role": "assistant" } ], - "image": "2.jpg" + "images": [ + "data/images/2.jpg" + ] }, { "messages": [ { - "content": "Please describe this image", + "content": "Please describe this image", "role": "user" }, { @@ -60,6 +64,8 @@ "role": "assistant" } ], - "image": "3.jpg" + "images": [ + "data/images/3.jpg" + ] } -] \ No newline at end of file +] diff --git a/examples/mllm/sft_instructblip.sh b/examples/mllm/sft_instructblip.sh deleted file mode 100644 index b3923655..00000000 --- a/examples/mllm/sft_instructblip.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/bin/bash - -CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ - --stage sft_mm \ - --do_train \ - --model_name_or_path Salesforce/instructblip-vicuna-7b \ - --dataset mllm_instruct_example \ - --dataset_dir data \ - --template default \ - --finetuning_type lora \ - --lora_target all \ - --output_dir saves/instructblip-vicuna-7b/lora/sft \ - --overwrite_cache \ - --overwrite_output_dir \ - --cutoff_len 1024 \ - --preprocessing_num_workers 16 \ - --per_device_train_batch_size 3 \ - --per_device_eval_batch_size 1 \ - --gradient_accumulation_steps 1 \ - --lr_scheduler_type cosine \ - --logging_steps 1 \ - --warmup_steps 20 \ - --save_steps 100 \ - --eval_steps 100 \ - --evaluation_strategy steps \ - --load_best_model_at_end \ - --learning_rate 1e-5 \ - --num_train_epochs 50 \ - --max_samples 3000 \ - --val_size 0.1 \ - --plot_loss \ - --bf16 \ No newline at end of file diff --git a/scripts/test_mllm.py b/scripts/test_mllm.py index 961f02bf..94d8670b 100644 --- a/scripts/test_mllm.py +++ b/scripts/test_mllm.py @@ -29,7 +29,10 @@ def get_processor(model_path): def apply_lora(base_model_path, model_path, lora_path): print(f"Loading the base model from {base_model_path}") base_model = AutoModelForVision2Seq.from_pretrained( - base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="cuda", + base_model_path, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + device_map="cuda", ) processor = get_processor(base_model_path) tokenizer = processor.tokenizer @@ -60,11 +63,14 @@ def main( if not os.path.exists(model_path) or do_merge: apply_lora(base_model_path, model_path, lora_model_path) model = AutoModelForVision2Seq.from_pretrained( - model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="cuda" + model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + device_map="cuda", ) processor = 
get_processor(model_path) raw_datasets = load_dataset(dataset_name) - train_dataset = raw_datasets['train'] + train_dataset = raw_datasets["train"] examples = train_dataset.select(range(3)) texts = [] images = [] @@ -81,5 +87,5 @@ def main( print(res) -if __name__ == '__main__': +if __name__ == "__main__": fire.Fire(main) diff --git a/src/llmtuner/data/__init__.py b/src/llmtuner/data/__init__.py index 27a2f3b8..00a82d73 100644 --- a/src/llmtuner/data/__init__.py +++ b/src/llmtuner/data/__init__.py @@ -1,12 +1,11 @@ from .collator import PairwiseDataCollatorWithPadding -from .loader import get_dataset, get_mm_dataset +from .loader import get_dataset from .template import Template, get_template_and_fix_tokenizer, templates from .utils import Role, split_dataset __all__ = [ "PairwiseDataCollatorWithPadding", "get_dataset", - "get_mm_dataset", "Template", "get_template_and_fix_tokenizer", "templates", diff --git a/src/llmtuner/data/aligner.py b/src/llmtuner/data/aligner.py index 4de37e6d..85202ea8 100644 --- a/src/llmtuner/data/aligner.py +++ b/src/llmtuner/data/aligner.py @@ -13,7 +13,9 @@ if TYPE_CHECKING: from .parser import DatasetAttr -def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]: +def convert_alpaca( + examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr" +) -> Dict[str, List[Any]]: -    outputs = {"prompt": [], "response": [], "system": [], "tools": []} +    outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []} for i in range(len(examples[dataset_attr.prompt])): prompt = [] @@ -31,24 +33,38 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) - if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list): + if dataset_attr.response and isinstance( + examples[dataset_attr.response][i], list + ): response = [ - {"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i] + {"role": Role.ASSISTANT.value, "content": content} + for content in examples[dataset_attr.response][i] + ] + elif dataset_attr.response and isinstance( + examples[dataset_attr.response][i], str + ): + response = [ + { + "role": Role.ASSISTANT.value, + "content": examples[dataset_attr.response][i], + } ] - elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): - response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}] else: response = [] outputs["prompt"].append(prompt) outputs["response"].append(response) - outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") + outputs["system"].append( + examples[dataset_attr.system][i] if dataset_attr.system else "" + ) outputs["tools"].append("") - + outputs["images"].append([]) return outputs -def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]: +def convert_sharegpt( + examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr" +) -> Dict[str, List[Any]]: -    outputs = {"prompt": [], "response": [], "system": [], "tools": []} +    outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []} tag_mapping = { dataset_attr.user_tag: Role.USER.value, dataset_attr.assistant_tag: Role.ASSISTANT.value, dataset_attr.observation_tag: Role.OBSERVATION.value, dataset_attr.function_tag: Role.FUNCTION.value, dataset_attr.system_tag: Role.SYSTEM.value, } @@ -61,7 +77,10 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr" even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag) accept_tags = (odd_tags, even_tags) for i, messages in enumerate(examples[dataset_attr.messages]): - if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag: + if ( + dataset_attr.system_tag
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag + ): system = messages[0][dataset_attr.content_tag] messages = messages[1:] else: @@ -77,19 +96,81 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr" raise ValueError("Invalid role tag in {}.".format(messages)) aligned_messages.append( - {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} + { + "role": tag_mapping[message[dataset_attr.role_tag]], + "content": message[dataset_attr.content_tag], + } ) outputs["prompt"].append(aligned_messages[:-1]) outputs["response"].append(aligned_messages[-1:]) outputs["system"].append(system) - outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") + outputs["tools"].append( + examples[dataset_attr.tools][i] if dataset_attr.tools else "" + ) + outputs["images"].append([]) + + return outputs + + +def convert_llava( + examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr" +) -> Dict[str, List[Any]]: + outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []} + tag_mapping = { + dataset_attr.user_tag: Role.USER.value, + dataset_attr.assistant_tag: Role.ASSISTANT.value, + dataset_attr.observation_tag: Role.OBSERVATION.value, + dataset_attr.function_tag: Role.FUNCTION.value, + dataset_attr.system_tag: Role.SYSTEM.value, + } + odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag) + even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag) + accept_tags = (odd_tags, even_tags) + for i, messages in enumerate(examples[dataset_attr.messages]): + if ( + dataset_attr.system_tag + and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag + ): + system = messages[0][dataset_attr.content_tag] + messages = messages[1:] + else: + system = examples[dataset_attr.system][i] if dataset_attr.system else "" + + messages = messages[: len(messages) // 2 * 2] # should be multiples of 2 + if len(messages) == 0: + continue + + aligned_messages = [] + for turn_idx, message in enumerate(messages): + if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: + raise ValueError("Invalid role tag in {}.".format(messages)) + + aligned_messages.append( + { + "role": tag_mapping[message[dataset_attr.role_tag]], + "content": message[dataset_attr.content_tag], + } + ) + + outputs["prompt"].append(aligned_messages[:-1]) + outputs["response"].append(aligned_messages[-1:]) + outputs["system"].append(system) + outputs["tools"].append( + examples[dataset_attr.tools][i] if dataset_attr.tools else "" + ) + outputs["images"].append( + examples[dataset_attr.images][i] if dataset_attr.images else [] + ) return outputs def align_dataset( - dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments" + dataset: Union["Dataset", "IterableDataset"], + dataset_attr: "DatasetAttr", + data_args: "DataArguments", ) -> Union["Dataset", "IterableDataset"]: r""" Aligned dataset: @@ -100,6 +181,8 @@ """ if dataset_attr.formatting == "alpaca": convert_func = partial(convert_alpaca, dataset_attr=dataset_attr) + elif dataset_attr.formatting == "llava": + convert_func = partial(convert_llava, dataset_attr=dataset_attr) else: convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr) @@ -107,13 +190,20 @@ features = Features.from_dict( { "prompt": [ - {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}} +
{ + "role": {"dtype": "string", "_type": "Value"}, + "content": {"dtype": "string", "_type": "Value"}, + } ], "response": [ - {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}} + { + "role": {"dtype": "string", "_type": "Value"}, + "content": {"dtype": "string", "_type": "Value"}, + } ], "system": {"dtype": "string", "_type": "Value"}, "tools": {"dtype": "string", "_type": "Value"}, + "images": {"feature": {"_type": "Image"}, "_type": "Sequence"}, } ) kwargs = {} diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index 18665731..c373e196 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -1,6 +1,6 @@ import inspect import os -from typing import TYPE_CHECKING, Literal, Union +from typing import TYPE_CHECKING, Literal, Union, Optional from datasets import load_dataset, load_from_disk @@ -25,9 +25,9 @@ logger = get_logger(__name__) def load_single_dataset( - dataset_attr: "DatasetAttr", - model_args: "ModelArguments", - data_args: "DataArguments", + dataset_attr: "DatasetAttr", + model_args: "ModelArguments", + data_args: "DataArguments", ) -> Union["Dataset", "IterableDataset"]: logger.info("Loading dataset {}...".format(dataset_attr)) data_path, data_name, data_dir, data_files = None, None, None, None @@ -78,14 +78,20 @@ def load_single_dataset( split=data_args.split, cache_dir=cache_dir, token=model_args.ms_hub_token, - use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")), + use_streaming=( + data_args.streaming and (dataset_attr.load_from != "file") + ), ) if isinstance(dataset, MsDataset): dataset = dataset.to_hf_dataset() except ImportError: - raise ImportError("Please install modelscope via `pip install modelscope -U`") + raise ImportError( + "Please install modelscope via `pip install modelscope -U`" + ) else: - if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0 + if ( + "trust_remote_code" in inspect.signature(load_dataset).parameters + ): # for datasets==2.16.0 kwargs = {"trust_remote_code": True} else: kwargs = {} @@ -102,7 +108,9 @@ def load_single_dataset( **kwargs, ) - if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True + if data_args.streaming and ( + dataset_attr.load_from == "file" + ): # faster than specifying streaming=True dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter if data_args.max_samples is not None: # truncate dataset @@ -113,11 +121,12 @@ def load_single_dataset( def get_dataset( - tokenizer: "PreTrainedTokenizer", - model_args: "ModelArguments", - data_args: "DataArguments", - training_args: "Seq2SeqTrainingArguments", - stage: Literal["pt", "sft", "rm", "ppo"], + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo"], + processor: Optional["AutoProcessor"] = None, ) -> Union["Dataset", "IterableDataset"]: template = get_template_and_fix_tokenizer(tokenizer, data_args.template) if data_args.train_on_prompt and template.efficient_eos: @@ -126,9 +135,13 @@ def get_dataset( # Load tokenized dataset if data_args.tokenized_path is not None: if has_tokenized_data(data_args.tokenized_path): - logger.warning("Loading dataset from disk will ignore other data arguments.") + logger.warning( + "Loading dataset from disk will ignore other data arguments." 
+ ) dataset = load_from_disk(data_args.tokenized_path) - logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path)) + logger.info( + "Loaded tokenized dataset from {}.".format(data_args.tokenized_path) + ) if data_args.streaming: dataset = dataset.to_iterable_dataset() return dataset @@ -139,15 +152,21 @@ def get_dataset( with training_args.main_process_first(desc="load dataset"): all_datasets = [] for dataset_attr in get_dataset_list(data_args): - if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): - raise ValueError("The dataset is not applicable in the current training stage.") + if (stage == "rm" and dataset_attr.ranking is False) or ( + stage != "rm" and dataset_attr.ranking is True + ): + raise ValueError( + "The dataset is not applicable in the current training stage." + ) - all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args)) + all_datasets.append( + load_single_dataset(dataset_attr, model_args, data_args) + ) dataset = merge_dataset(all_datasets, data_args, training_args) with training_args.main_process_first(desc="pre-process dataset"): preprocess_func, print_function = get_preprocess_and_print_func( - tokenizer, template, data_args, training_args, stage + tokenizer, template, data_args, training_args, stage, processor ) column_names = list(next(iter(dataset)).keys()) kwargs = {} @@ -158,13 +177,21 @@ def get_dataset( desc="Running tokenizer on dataset", ) - dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) + dataset = dataset.map( + preprocess_func, batched=True, remove_columns=column_names, **kwargs + ) if data_args.tokenized_path is not None: if training_args.should_save: dataset.save_to_disk(data_args.tokenized_path) - logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) - logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path)) + logger.info( + "Tokenized dataset saved at {}.".format(data_args.tokenized_path) + ) + logger.info( + "Please restart the training with `--tokenized_path {}`.".format( + data_args.tokenized_path + ) + ) exit(0) @@ -172,34 +199,8 @@ def get_dataset( try: print_function(next(iter(dataset))) except StopIteration: - raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.") + raise RuntimeError( + "Cannot find valid samples, check `data/README.md` for the data format." 
+ ) return dataset - - -def get_mm_dataset( - processor: "AutoProcessor", - model_args: "ModelArguments", - data_args: "DataArguments", - training_args: "Seq2SeqTrainingArguments", - stage: Literal["pt", "sft", "rm", "ppo"], -) -> Union["Dataset", "IterableDataset"]: - if data_args.tokenized_path is not None: - if has_tokenized_data(data_args.tokenized_path): - logger.warning("Loading dataset from disk will ignore other data arguments.") - dataset = load_from_disk(data_args.tokenized_path) - logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path)) - if data_args.streaming: - dataset = dataset.to_iterable_dataset() - return dataset - - if data_args.streaming: - raise ValueError("Turn off `streaming` when saving dataset to disk.") - - with training_args.main_process_first(desc="load dataset"): - all_datasets = [] - for dataset_attr in get_dataset_list(data_args): - all_datasets.append(load_dataset(dataset_attr.dataset_name)['train']) - dataset = merge_dataset(all_datasets, data_args, training_args) - - return dataset diff --git a/src/llmtuner/data/parser.py b/src/llmtuner/data/parser.py index b9c8782a..79d6ed4e 100644 --- a/src/llmtuner/data/parser.py +++ b/src/llmtuner/data/parser.py @@ -25,7 +25,7 @@ class DatasetAttr: subset: Optional[str] = None folder: Optional[str] = None ranking: bool = False - formatting: Literal["alpaca", "sharegpt"] = "alpaca" + formatting: Literal["alpaca", "sharegpt", "llava"] = "alpaca" """ columns """ system: Optional[str] = None """ columns for the alpaca format """ @@ -44,11 +44,15 @@ class DatasetAttr: observation_tag: Optional[str] = "observation" function_tag: Optional[str] = "function_call" system_tag: Optional[str] = "system" + """ columns for the mllm format """ + images: Optional[str] = None def __repr__(self) -> str: return self.dataset_name - def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None: + def set_attr( + self, key: str, obj: Dict[str, Any], default: Optional[Any] = None + ) -> None: setattr(self, key, obj.get(key, default)) @@ -67,12 +71,16 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: except Exception as err: if len(dataset_names) != 0: raise ValueError( - "Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)) + "Cannot open {} due to {}.".format( + os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err) + ) ) dataset_info = None if data_args.interleave_probs is not None: - data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")] + data_args.interleave_probs = [ + float(prob.strip()) for prob in data_args.interleave_probs.split(",") + ] dataset_list: List[DatasetAttr] = [] for name in dataset_names: @@ -90,31 +98,42 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: if has_hf_url or has_ms_url: if (use_modelscope() and has_ms_url) or (not has_hf_url): - dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]) + dataset_attr = DatasetAttr( + "ms_hub", dataset_name=dataset_info[name]["ms_hub_url"] + ) else: - dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) + dataset_attr = DatasetAttr( + "hf_hub", dataset_name=dataset_info[name]["hf_hub_url"] + ) elif "script_url" in dataset_info[name]: - dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) + dataset_attr = DatasetAttr( + "script", dataset_name=dataset_info[name]["script_url"] + ) else: - dataset_attr 
= DatasetAttr("file", dataset_name=dataset_info[name]["file_name"]) + dataset_attr = DatasetAttr( + "file", dataset_name=dataset_info[name]["file_name"] + ) dataset_attr.set_attr("file_sha1", dataset_info[name]) dataset_attr.set_attr("subset", dataset_info[name]) dataset_attr.set_attr("folder", dataset_info[name]) dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") + dataset_attr.set_attr("images", dataset_info[name], default="") if "columns" in dataset_info[name]: column_names = ["system"] if dataset_attr.formatting == "alpaca": column_names.extend(["prompt", "query", "response", "history"]) + elif dataset_attr.formatting == "llava": + column_names.extend(["messages", "images"]) else: column_names.extend(["messages", "tools"]) for column_name in column_names: dataset_attr.set_attr(column_name, dataset_info[name]["columns"]) - if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]: + if dataset_attr.formatting != "alpaca" and "tags" in dataset_info[name]: tag_names = ( "role_tag", "content_tag", diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py index 8494ba7e..dc72483f 100644 --- a/src/llmtuner/data/preprocess.py +++ b/src/llmtuner/data/preprocess.py @@ -1,6 +1,6 @@ from functools import partial from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple, Optional from ..extras.constants import IGNORE_INDEX from ..extras.logging import get_logger @@ -9,7 +9,7 @@ from .utils import Role if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers.tokenization_utils import PreTrainedTokenizer, AutoProcessor from ..hparams import DataArguments from .template import Template @@ -19,19 +19,27 @@ logger = get_logger(__name__) def preprocess_pretrain_dataset( - examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments" + examples: Dict[str, List[Any]], + tokenizer: "PreTrainedTokenizer", + data_args: "DataArguments", ) -> Dict[str, List[List[int]]]: # build grouped texts with format `X1 X2 X3 ...` if packing is enabled - text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]] + text_examples = [ + messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"] + ] if not data_args.packing: if data_args.template == "gemma": text_examples = [tokenizer.bos_token + example for example in text_examples] - result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len) + result = tokenizer( + text_examples, add_special_tokens=False, max_length=data_args.cutoff_len + ) else: tokenized_examples = tokenizer(text_examples, add_special_tokens=False) - concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} + concatenated_examples = { + k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys() + } total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) block_size = data_args.cutoff_len total_length = (total_length // block_size) * block_size @@ -54,7 +62,11 @@ def preprocess_supervised_dataset( ) -> Dict[str, List[List[int]]]: # build inputs with format ` X Y ` and labels with format ` ... 
Y ` # for multiturn examples, we only mask the prompt part in each prompt-response pair. - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + model_inputs = { + "input_ids": [], + "attention_mask": [], + "labels": [], + } for i in range(len(examples["prompt"])): if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: @@ -75,7 +87,9 @@ def preprocess_supervised_dataset( if data_args.train_on_prompt: source_mask = source_ids elif turn_idx != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * ( + len(source_ids) - 1 + ) else: source_mask = [IGNORE_INDEX] * len(source_ids) @@ -114,7 +128,9 @@ def preprocess_packed_supervised_dataset( if data_args.train_on_prompt: source_mask = source_ids elif len(input_ids) != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * ( + len(source_ids) - 1 + ) else: source_mask = [IGNORE_INDEX] * len(source_ids) @@ -139,6 +155,64 @@ def preprocess_packed_supervised_dataset( return model_inputs +def preprocess_multimodal_supervised_dataset( + examples: Dict[str, List[Any]], + processor: "AutoProcessor", + template: "Template", + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build inputs with format ` X Y ` and labels with format ` ... Y ` + # for multiturn examples, we only mask the prompt part in each prompt-response pair. + tokenizer = processor.tokenizer + model_inputs = { + "input_ids": [], + "attention_mask": [], + "labels": [], + "pixel_values": [], + } + + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: + continue + + messages = examples["prompt"][i] + examples["response"][i] + input_ids, labels = [], [] + for turn_idx, (source_ids, target_ids) in enumerate( + template.encode_multiturn( + tokenizer, + messages, + examples["system"][i], + examples["tools"][i], + data_args.cutoff_len, + data_args.reserved_label_len, + ) + ): + if data_args.train_on_prompt: + source_mask = source_ids + elif turn_idx != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * ( + len(source_ids) - 1 + ) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) + + input_ids += source_ids + target_ids + labels += source_mask + target_ids + + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + pixel_values = processor.image_processor( + examples["images"][0], return_tensors="pt" + )["pixel_values"][0] + model_inputs["pixel_values"].append(pixel_values) + return model_inputs + + def preprocess_unsupervised_dataset( examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", @@ -155,7 +229,9 @@ def preprocess_unsupervised_dataset( if len(examples["response"][i]) == 1: messages = examples["prompt"][i] + examples["response"][i] else: - messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}] + messages = examples["prompt"][i] + [ + {"role": Role.ASSISTANT.value, "content": ""} + ] input_ids, labels = template.encode_oneturn( tokenizer, @@ -218,29 +294,58 @@ def preprocess_pairwise_dataset( return model_inputs -def print_supervised_dataset_example(example: 
Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: +def print_supervised_dataset_example( + example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer" +) -> None: print("input_ids:\n{}".format(example["input_ids"])) - print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print( + "inputs:\n{}".format( + tokenizer.decode(example["input_ids"], skip_special_tokens=False) + ) + ) print("label_ids:\n{}".format(example["labels"])) print( "labels:\n{}".format( - tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False) + tokenizer.decode( + list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), + skip_special_tokens=False, + ) ) ) -def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: +def print_pairwise_dataset_example( + example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer" +) -> None: print("prompt_ids:\n{}".format(example["prompt_ids"])) - print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False))) + print( + "prompt:\n{}".format( + tokenizer.decode(example["prompt_ids"], skip_special_tokens=False) + ) + ) print("chosen_ids:\n{}".format(example["chosen_ids"])) - print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False))) + print( + "chosen:\n{}".format( + tokenizer.decode(example["chosen_ids"], skip_special_tokens=False) + ) + ) print("rejected_ids:\n{}".format(example["rejected_ids"])) - print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False))) + print( + "rejected:\n{}".format( + tokenizer.decode(example["rejected_ids"], skip_special_tokens=False) + ) + ) -def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: +def print_unsupervised_dataset_example( + example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer" +) -> None: print("input_ids:\n{}".format(example["input_ids"])) - print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print( + "inputs:\n{}".format( + tokenizer.decode(example["input_ids"], skip_special_tokens=False) + ) + ) def get_preprocess_and_print_func( @@ -249,30 +354,56 @@ def get_preprocess_and_print_func( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo"], + processor: Optional["AutoProcessor"] = None, ) -> Tuple[Callable, Callable]: if stage == "pt": - preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args) - print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) + preprocess_func = partial( + preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args + ) + print_function = partial( + print_unsupervised_dataset_example, tokenizer=tokenizer + ) elif stage == "sft" and not training_args.predict_with_generate: if data_args.packing: preprocess_func = partial( - preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args + preprocess_packed_supervised_dataset, + tokenizer=tokenizer, + template=template, + data_args=data_args, + ) + elif processor is not None: + preprocess_func = partial( + preprocess_multimodal_supervised_dataset, + processor=processor, + template=template, + data_args=data_args, ) else: preprocess_func = partial( - preprocess_supervised_dataset, tokenizer=tokenizer, template=template, 
data_args=data_args + preprocess_supervised_dataset, + tokenizer=tokenizer, + template=template, + data_args=data_args, ) print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer) elif stage == "rm": preprocess_func = partial( - preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args + preprocess_pairwise_dataset, + tokenizer=tokenizer, + template=template, + data_args=data_args, ) print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer) else: preprocess_func = partial( - preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args + preprocess_unsupervised_dataset, + tokenizer=tokenizer, + template=template, + data_args=data_args, + ) + print_function = partial( + print_unsupervised_dataset_example, tokenizer=tokenizer ) - print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) - return preprocess_func, print_function \ No newline at end of file + return preprocess_func, print_function diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index 73b22eb7..311660aa 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -42,7 +42,9 @@ class Template: r""" Returns a single pair of token ids representing prompt and response respectively. """ - encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) + encoded_pairs = self._encode( + tokenizer, messages, system, tools, cutoff_len, reserved_label_len + ) prompt_ids = [] for query_ids, resp_ids in encoded_pairs[:-1]: prompt_ids += query_ids + resp_ids @@ -62,7 +64,9 @@ class Template: r""" Returns multiple pairs of token ids representing prompts and responses respectively. """ - return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) + return self._encode( + tokenizer, messages, system, tools, cutoff_len, reserved_label_len + ) def _encode( self, @@ -89,7 +93,9 @@ class Template: elements += self.format_separator.apply() if message["role"] == Role.USER.value: - elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) + elements += self.format_user.apply( + content=message["content"], idx=str(i // 2) + ) elif message["role"] == Role.ASSISTANT.value: elements += self.format_assistant.apply(content=message["content"]) elif message["role"] == Role.OBSERVATION.value: @@ -104,7 +110,9 @@ class Template: return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) def _convert_elements_to_ids( - self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]] + self, + tokenizer: "PreTrainedTokenizer", + elements: List[Union[str, Dict[str, str]]], ) -> List[int]: r""" Converts elements to token ids. 
@@ -122,7 +130,11 @@ class Template: elif "eos_token" in elem and tokenizer.eos_token_id is not None: token_ids += [tokenizer.eos_token_id] else: - raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem))) + raise ValueError( + "Input must be string, set[str] or dict[str, str], got {}".format( + type(elem) + ) + ) return token_ids @@ -180,7 +192,9 @@ class Llama2Template(Template): elements += self.format_separator.apply() if message["role"] == Role.USER.value: - elements += self.format_user.apply(content=system_text + message["content"]) + elements += self.format_user.apply( + content=system_text + message["content"] + ) elif message["role"] == Role.ASSISTANT.value: elements += self.format_assistant.apply(content=message["content"]) elif message["role"] == Role.OBSERVATION.value: @@ -243,7 +257,9 @@ def _register_template( template_class = Llama2Template if name.startswith("llama2") else Template default_user_formatter = StringFormatter(slots=["{{content}}"]) default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots) - default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots) + default_function_formatter = FunctionFormatter( + slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots + ) default_tool_formatter = ToolFormatter(tool_format="default") default_separator_formatter = EmptyFormatter() templates[name] = template_class( @@ -279,7 +295,9 @@ def _jinja_escape(content: str) -> str: return content.replace("\n", r"\n").replace("'", r"\'") -def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str: +def _convert_slots_to_jinja( + slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content" +) -> str: slot_items = [] for slot in slots: if isinstance(slot, str): @@ -293,7 +311,9 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl elif isinstance(slot, set): if "bos_token" in slot: slot_items.append("'" + tokenizer.bos_token + "'") - elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced + elif ( + "eos_token" in slot + ): # do not use {{ eos_token }} since it may be replaced slot_items.append("'" + tokenizer.eos_token + "'") elif isinstance(slot, dict): raise ValueError("Dict is not supported.") @@ -305,25 +325,37 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") jinja_template = "" if template.default_system: - jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}" + jinja_template += ( + "{% set system_message = '" + + _jinja_escape(template.default_system) + + "' %}" + ) jinja_template += ( - "{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}" + "{% if messages[0]['role'] == 'system' %}" + "{% set system_message = messages[0]['content'] %}" + "{% endif %}" ) - system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message") + system_message = _convert_slots_to_jinja( + template.format_system.apply(), tokenizer, placeholder="system_message" + ) if isinstance(template, Llama2Template): pass elif template.force_system: jinja_template += "{{ " + system_message + " }}" else: - jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}" + jinja_template += ( + "{% if system_message is defined %}{{ " + system_message + " }}{% 
endif %}" + ) jinja_template += "{% for message in messages %}" jinja_template += "{% set content = message['content'] %}" if isinstance(template, Llama2Template): jinja_template += "{% if loop.index0 == 0 and system_message is defined %}" - jinja_template += "{% set content = " + system_message + " + message['content'] %}" + jinja_template += ( + "{% set content = " + system_message + " + message['content'] %}" + ) jinja_template += "{% endif %}" jinja_template += "{% if message['role'] == 'user' %}" user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer) @@ -366,11 +398,14 @@ def get_template_and_fix_tokenizer( if stop_words: num_added_tokens = tokenizer.add_special_tokens( - dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False + dict(additional_special_tokens=stop_words), + replace_additional_special_tokens=False, ) logger.info("Add {} to stop words.".format(",".join(stop_words))) if num_added_tokens > 0: - logger.warning("New tokens have been added, make sure `resize_vocab` is True.") + logger.warning( + "New tokens have been added, make sure `resize_vocab` is True." + ) try: tokenizer.chat_template = _get_jinja_template(template, tokenizer) @@ -382,7 +417,9 @@ def get_template_and_fix_tokenizer( _register_template( name="alpaca", - format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), + format_user=StringFormatter( + slots=["### Instruction:\n{{content}}\n\n### Response:\n"] + ), format_separator=EmptyFormatter(slots=["\n\n"]), default_system=( "Below is an instruction that describes a task. " @@ -407,7 +444,13 @@ _register_template( _register_template( name="atom", format_user=StringFormatter( - slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"] + slots=[ + {"bos_token"}, + "Human: {{content}}\n", + {"eos_token"}, + {"bos_token"}, + "Assistant:", + ] ), format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]), ) @@ -415,7 +458,9 @@ _register_template( _register_template( name="baichuan", - format_user=StringFormatter(slots=[{"token": ""}, "{{content}}", {"token": ""}]), + format_user=StringFormatter( + slots=[{"token": ""}, "{{content}}", {"token": ""}] + ), efficient_eos=True, ) @@ -438,7 +483,9 @@ _register_template( _register_template( name="bluelm", - format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]), + format_user=StringFormatter( + slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}] + ), ) @@ -457,7 +504,9 @@ _register_template( _register_template( name="chatglm2", format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), - format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), + format_system=StringFormatter( + slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"] + ), format_separator=EmptyFormatter(slots=["\n\n"]), efficient_eos=True, force_system=True, @@ -466,12 +515,21 @@ _register_template( _register_template( name="chatglm3", - format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), + format_user=StringFormatter( + slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] + ), format_assistant=StringFormatter(slots=["\n", "{{content}}"]), - format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), + format_system=StringFormatter( + slots=[{"token": "[gMASK]"}, {"token": "sop"}, 
"{{content}}"] + ), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_observation=StringFormatter( - slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] + slots=[ + {"token": "<|observation|>"}, + "\n", + "{{content}}", + {"token": "<|assistant|>"}, + ] ), stop_words=["<|user|>", "<|observation|>"], efficient_eos=True, @@ -481,14 +539,27 @@ _register_template( _register_template( name="chatglm3_system", - format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), + format_user=StringFormatter( + slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] + ), format_assistant=StringFormatter(slots=["\n", "{{content}}"]), format_system=StringFormatter( - slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"] + slots=[ + {"token": "[gMASK]"}, + {"token": "sop"}, + {"token": "<|system|>"}, + "\n", + "{{content}}", + ] ), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_observation=StringFormatter( - slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] + slots=[ + {"token": "<|observation|>"}, + "\n", + "{{content}}", + {"token": "<|assistant|>"}, + ] ), default_system=( "You are ChatGLM3, a large language model trained by Zhipu.AI. " @@ -501,9 +572,15 @@ _register_template( _register_template( name="chatml", - format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_user=StringFormatter( + slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"] + ), + format_system=StringFormatter( + slots=["<|im_start|>system\n{{content}}<|im_end|>\n"] + ), + format_observation=StringFormatter( + slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"] + ), format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|im_end|>", "<|im_start|>"], replace_eos=True, @@ -512,9 +589,15 @@ _register_template( _register_template( name="chatml_de", - format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_user=StringFormatter( + slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"] + ), + format_system=StringFormatter( + slots=["<|im_start|>system\n{{content}}<|im_end|>\n"] + ), + format_observation=StringFormatter( + slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"] + ), format_separator=EmptyFormatter(slots=["\n"]), default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.", stop_words=["<|im_end|>", "<|im_start|>"], @@ -524,7 +607,9 @@ _register_template( _register_template( name="codegeex2", - format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), + format_system=StringFormatter( + slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"] + ), force_system=True, ) @@ -554,9 +639,15 @@ _register_template( _register_template( name="dbrx", - 
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_user=StringFormatter( + slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"] + ), + format_system=StringFormatter( + slots=["<|im_start|>system\n{{content}}<|im_end|>\n"] + ), + format_observation=StringFormatter( + slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"] + ), format_separator=EmptyFormatter(slots=["\n"]), default_system=( "You are DBRX, created by Databricks. You were last updated in December 2023. " @@ -634,7 +725,9 @@ _register_template( _register_template( name="gemma", - format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_user=StringFormatter( + slots=["user\n{{content}}\nmodel\n"] + ), format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_observation=StringFormatter( slots=["tool\n{{content}}\nmodel\n"] @@ -647,7 +740,9 @@ _register_template( _register_template( name="intern", - format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": ""}, "\n<|Bot|>:"]), + format_user=StringFormatter( + slots=["<|User|>:{{content}}", {"token": ""}, "\n<|Bot|>:"] + ), format_separator=EmptyFormatter(slots=[{"token": ""}, "\n"]), stop_words=[""], efficient_eos=True, @@ -656,8 +751,12 @@ _register_template( _register_template( name="intern2", - format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_user=StringFormatter( + slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"] + ), + format_system=StringFormatter( + slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"] + ), format_separator=EmptyFormatter(slots=["\n"]), default_system=( "You are an AI assistant whose name is InternLM (书生·浦语).\n" @@ -707,7 +806,10 @@ _register_template( ] ), format_system=StringFormatter( - slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"] + slots=[ + {"bos_token"}, + "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>", + ] ), format_observation=StringFormatter( slots=[ @@ -742,7 +844,13 @@ _register_template( _register_template( name="openchat", - format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]), + format_user=StringFormatter( + slots=[ + "GPT4 Correct User: {{content}}", + {"eos_token"}, + "GPT4 Correct Assistant:", + ] + ), format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]), format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), force_system=True, @@ -751,7 +859,9 @@ _register_template( _register_template( name="orion", - format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), + format_user=StringFormatter( + slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}] + ), format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), force_system=True, ) @@ -759,9 +869,15 @@ _register_template( _register_template( name="phi", - format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), - format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]), - 
format_observation=StringFormatter(slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]), + format_user=StringFormatter( + slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"] + ), + format_system=StringFormatter( + slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"] + ), + format_observation=StringFormatter( + slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"] + ), format_separator=EmptyFormatter(slots=["\n"]), default_system="You are a helpful AI assistant.", stop_words=["<|end|>"], @@ -771,9 +887,15 @@ _register_template( _register_template( name="qwen", - format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_user=StringFormatter( + slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"] + ), + format_system=StringFormatter( + slots=["<|im_start|>system\n{{content}}<|im_end|>\n"] + ), + format_observation=StringFormatter( + slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"] + ), format_separator=EmptyFormatter(slots=["\n"]), default_system="You are a helpful assistant.", stop_words=["<|im_end|>"], @@ -829,8 +951,12 @@ _register_template( _register_template( name="yayi", - format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]), - format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]), + format_user=StringFormatter( + slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"] + ), + format_system=StringFormatter( + slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"] + ), format_separator=EmptyFormatter(slots=["\n\n"]), default_system=( "You are a helpful, respectful and honest assistant named YaYi " @@ -849,7 +975,9 @@ _register_template( _register_template( name="yi", - format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_user=StringFormatter( + slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"] + ), format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|im_end|>"], replace_eos=True, @@ -867,7 +995,9 @@ _register_template( _register_template( name="zephyr", - format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]), + format_user=StringFormatter( + slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"] + ), format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]), format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]), default_system="You are a friendly chatbot who always responds in the style of a pirate", @@ -879,3 +1009,13 @@ _register_template( format_user=StringFormatter(slots=[":{{content}}\n:"]), format_separator=EmptyFormatter(slots=["\n"]), ) + +_register_template( + name="llava", + format_user=StringFormatter(slots=["USER: {{content}} "]), + format_assistant=StringFormatter(slots=["ASSISTANT: {{content}}"]), + default_system=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." 
+ ), +) diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index a6e4b710..63fc7f02 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -15,23 +15,33 @@ class ModelArguments: ) adapter_name_or_path: Optional[str] = field( default=None, - metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}, + metadata={ + "help": "Path to the adapter weight or identifier from huggingface.co/models." + }, ) cache_dir: Optional[str] = field( default=None, - metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."}, + metadata={ + "help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn." + }, ) use_fast_tokenizer: bool = field( default=True, - metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."}, + metadata={ + "help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)." + }, ) resize_vocab: bool = field( default=False, - metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}, + metadata={ + "help": "Whether or not to resize the tokenizer vocab and the embedding layers." + }, ) split_special_tokens: bool = field( default=False, - metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}, + metadata={ + "help": "Whether or not the special tokens should be split during the tokenization process." + }, ) new_special_tokens: Optional[str] = field( default=None, @@ -39,7 +49,9 @@ class ModelArguments: ) model_revision: str = field( default="main", - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + metadata={ + "help": "The specific model version to use (can be a branch name, tag name or commit id)." + }, ) low_cpu_mem_usage: bool = field( default=True, @@ -47,7 +59,9 @@ class ModelArguments: ) quantization_bit: Optional[int] = field( default=None, - metadata={"help": "The number of bits to quantize the model using bitsandbytes."}, + metadata={ + "help": "The number of bits to quantize the model using bitsandbytes." + }, ) quantization_type: Literal["fp4", "nf4"] = field( default="nf4", @@ -55,15 +69,21 @@ class ModelArguments: ) double_quantization: bool = field( default=True, - metadata={"help": "Whether or not to use double quantization in int4 training."}, + metadata={ + "help": "Whether or not to use double quantization in int4 training." + }, ) quantization_device_map: Optional[Literal["auto"]] = field( default=None, - metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."}, + metadata={ + "help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0." + }, ) rope_scaling: Optional[Literal["linear", "dynamic"]] = field( default=None, - metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}, + metadata={ + "help": "Which scaling strategy should be adopted for the RoPE embeddings." + }, ) flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field( default="auto", @@ -71,19 +91,27 @@ class ModelArguments: ) shift_attn: bool = field( default=False, - metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}, + metadata={ + "help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA." 
+ }, ) mixture_of_depths: Optional[Literal["convert", "load"]] = field( default=None, - metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."}, + metadata={ + "help": "Convert the model to mixture-of-depths (MoD) or load the MoD model." + }, ) use_unsloth: bool = field( default=False, - metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}, + metadata={ + "help": "Whether or not to use unsloth's optimization for the LoRA training." + }, ) moe_aux_loss_coef: Optional[float] = field( default=None, - metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."}, + metadata={ + "help": "Coefficient of the auxiliary router loss in mixture-of-experts model." + }, ) disable_gradient_checkpointing: bool = field( default=False, @@ -107,7 +135,9 @@ class ModelArguments: ) vllm_gpu_util: float = field( default=0.9, - metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."}, + metadata={ + "help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine." + }, ) vllm_enforce_eager: bool = field( default=False, @@ -147,7 +177,9 @@ class ModelArguments: ) export_quantization_dataset: Optional[str] = field( default=None, - metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}, + metadata={ + "help": "Path to the dataset or dataset name to use in quantizing the exported model." + }, ) export_quantization_nsamples: int = field( default=128, @@ -155,19 +187,27 @@ class ModelArguments: ) export_quantization_maxlen: int = field( default=1024, - metadata={"help": "The maximum length of the model inputs used for quantization."}, + metadata={ + "help": "The maximum length of the model inputs used for quantization." + }, ) export_legacy_format: bool = field( default=False, - metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}, + metadata={ + "help": "Whether or not to save the `.bin` files instead of `.safetensors`." + }, ) export_hub_model_id: Optional[str] = field( default=None, - metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}, + metadata={ + "help": "The name of the repository if push the model to the Hugging Face hub." + }, ) print_param_status: bool = field( default=False, - metadata={"help": "For debugging purposes, print the status of the parameters in the model."}, + metadata={ + "help": "For debugging purposes, print the status of the parameters in the model." + }, ) use_mllm: bool = field( default=False, @@ -180,18 +220,39 @@ class ModelArguments: self.model_max_length = None if self.split_special_tokens and self.use_fast_tokenizer: - raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") + raise ValueError( + "`split_special_tokens` is only supported for slow tokenizers." 
+            )

-        if self.adapter_name_or_path is not None:  # support merging multiple lora weights
-            self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
+        if (
+            self.adapter_name_or_path is not None
+        ):  # support merging multiple lora weights
+            self.adapter_name_or_path = [
+                path.strip() for path in self.adapter_name_or_path.split(",")
+            ]

         if self.new_special_tokens is not None:  # support multiple special tokens
-            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
+            self.new_special_tokens = [
+                token.strip() for token in self.new_special_tokens.split(",")
+            ]

-        assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
-        assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
+        assert self.quantization_bit in [
+            None,
+            8,
+            4,
+        ], "We only accept 4-bit or 8-bit quantization."
+        assert self.export_quantization_bit in [
+            None,
+            8,
+            4,
+            3,
+            2,
+        ], "We only accept 2/3/4/8-bit quantization."

-        if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
+        if (
+            self.export_quantization_bit is not None
+            and self.export_quantization_dataset is None
+        ):
             raise ValueError("Quantization dataset is necessary for exporting.")

     def to_dict(self) -> Dict[str, Any]:
diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py
index bcefee92..e65798b7 100644
--- a/src/llmtuner/model/adapter.py
+++ b/src/llmtuner/model/adapter.py
@@ -11,7 +11,7 @@ from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model

 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel, AutoModelForVision2Seq
+    from transformers import PretrainedConfig, PreTrainedModel

     from ..hparams import FinetuningArguments, ModelArguments

@@ -21,11 +21,11 @@ logger = get_logger(__name__)

 def init_adapter(
     config: "PretrainedConfig",
-    model: Union["PreTrainedModel","AutoModelForVision2Seq"],
+    model: "PreTrainedModel",
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool,
-) -> Union["PreTrainedModel","AutoModelForVision2Seq"]:
+) -> "PreTrainedModel":
     r"""
     Initializes the adapters.
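Note on the `ModelArguments.__post_init__` hunk above: the rewrap keeps the comma-splitting contract for merging multiple LoRA adapters. A minimal sketch of that behavior, assuming `ModelArguments` can be constructed directly and that `model_name_or_path` is its only other required field:

```python
# Hedged sketch; field names follow the dataclass in this diff.
from llmtuner.hparams.model_args import ModelArguments

args = ModelArguments(
    model_name_or_path="meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
    adapter_name_or_path="saves/lora-step-100, saves/lora-step-200",
    quantization_bit=4,  # anything outside {None, 8, 4} trips the assertion
)
# __post_init__ splits the comma-separated string into a list of adapter paths
assert args.adapter_name_or_path == ["saves/lora-step-100", "saves/lora-step-200"]
```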
@@ -38,7 +38,9 @@ def init_adapter( logger.info("Adapter is not found at evaluation, load the base model.") return model - if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None): + if finetuning_args.finetuning_type != "lora" and getattr( + model, "quantization_method", None + ): raise ValueError("You can only use lora for quantized models.") if finetuning_args.finetuning_type == "full" and is_trainable: @@ -49,9 +51,9 @@ def init_adapter( if finetuning_args.finetuning_type == "freeze" and is_trainable: logger.info("Fine-tuning method: Freeze") num_layers = ( - getattr(model.config, "num_hidden_layers", None) - or getattr(model.config, "num_layers", None) - or getattr(model.config, "n_layer", None) + getattr(model.config, "num_hidden_layers", None) + or getattr(model.config, "num_layers", None) + or getattr(model.config, "n_layer", None) ) if not num_layers: raise ValueError("Current model does not support freeze tuning.") @@ -66,8 +68,12 @@ def init_adapter( stride = num_layers // finetuning_args.num_layer_trainable trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) - elif finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 - trainable_layer_ids = range(num_layers - finetuning_args.num_layer_trainable, num_layers) + elif ( + finetuning_args.num_layer_trainable > 0 + ): # fine-tuning the last n layers if num_layer_trainable > 0 + trainable_layer_ids = range( + num_layers - finetuning_args.num_layer_trainable, num_layers + ) else: # fine-tuning the first n layers if num_layer_trainable < 0 trainable_layer_ids = range(-finetuning_args.num_layer_trainable) @@ -82,11 +88,15 @@ def init_adapter( for module_name in finetuning_args.name_module_trainable: if module_name not in freeze_modules: raise ValueError( - "Module {} is not found, please choose from {}".format(module_name, ", ".join(freeze_modules)) + "Module {} is not found, please choose from {}".format( + module_name, ", ".join(freeze_modules) + ) ) for idx in trainable_layer_ids: - trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else "")) + trainable_layers.append( + ".{:d}.{}".format(idx, module_name if module_name != "all" else "") + ) for name, param in model.named_parameters(): if any(trainable_layer in name for trainable_layer in trainable_layers): @@ -95,27 +105,43 @@ def init_adapter( else: param.requires_grad_(False) - logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids)))) + logger.info( + "Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))) + ) if finetuning_args.finetuning_type == "lora": - logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) + logger.info( + "Fine-tuning method: {}".format( + "DoRA" if finetuning_args.use_dora else "LoRA" + ) + ) adapter_to_resume = None if model_args.adapter_name_or_path is not None: is_mergeable = True - if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable - assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter." + if getattr( + model, "quantization_method", None + ): # merge lora in quantized model is unstable + assert ( + len(model_args.adapter_name_or_path) == 1 + ), "Quantized model only accepts a single adapter." is_mergeable = False if is_deepspeed_zero3_enabled(): - assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3." 
+                assert (
+                    len(model_args.adapter_name_or_path) == 1
+                ), "Cannot use multiple adapters in DeepSpeed ZeRO-3."
                 is_mergeable = False

             if model_args.use_unsloth:
-                assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
+                assert (
+                    len(model_args.adapter_name_or_path) == 1
+                ), "Unsloth model only accepts a single adapter."
                 is_mergeable = False

-            if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
+            if (is_trainable and not finetuning_args.create_new_adapter) or (
+                not is_mergeable
+            ):
                 adapter_to_merge = model_args.adapter_name_or_path[:-1]
                 adapter_to_resume = model_args.adapter_name_or_path[-1]
             else:
@@ -132,7 +158,9 @@ def init_adapter(

         if adapter_to_resume is not None:  # resume lora training
             if model_args.use_unsloth:
-                model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
+                model = load_unsloth_peft_model(
+                    config, model_args, is_trainable=is_trainable
+                )
             else:
                 model = PeftModel.from_pretrained(
                     model,
@@ -141,19 +169,27 @@ def init_adapter(
                     offload_folder=model_args.offload_folder,
                 )

-        if is_trainable and adapter_to_resume is None:  # create new lora weights while training
-            if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
+        if (
+            is_trainable and adapter_to_resume is None
+        ):  # create new lora weights while training
+            if (
+                len(finetuning_args.lora_target) == 1
+                and finetuning_args.lora_target[0] == "all"
+            ):
                 target_modules = find_all_linear_modules(model)
             else:
                 target_modules = finetuning_args.lora_target

             if finetuning_args.use_llama_pro:
-                target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
+                target_modules = find_expanded_modules(
+                    model, target_modules, finetuning_args.num_layer_trainable
+                )

             if (
-                finetuning_args.use_dora
-                and getattr(model, "quantization_method", None) is not None
-                and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
+                finetuning_args.use_dora
+                and getattr(model, "quantization_method", None) is not None
+                and getattr(model, "quantization_method", None)
+                != QuantizationMethod.BITS_AND_BYTES
             ):
                 raise ValueError("DoRA is not compatible with PTQ-quantized models.")

@@ -166,7 +202,11 @@ def init_adapter(
                     module_names.add(name.split(".")[-1])

                 finetuning_args.additional_target = module_names
-                logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
+                logger.warning(
+                    "Vocab has been resized, add {} to trainable params.".format(
+                        ",".join(module_names)
+                    )
+                )

             peft_kwargs = {
                 "r": finetuning_args.lora_rank,
@@ -193,6 +233,10 @@ def init_adapter(
                 param.data = param.data.to(torch.float32)

     if model_args.adapter_name_or_path is not None:
-        logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
+        logger.info(
+            "Loaded adapter(s): {}".format(
+                ",".join(model_args.adapter_name_or_path)
+            )
+        )

-    return model
\ No newline at end of file
+    return model
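Note on the LoRA branch above: when `lora_target` is `["all"]`, `find_all_linear_modules` resolves every linear layer name and hands the list to PEFT. A standalone sketch of the same idea using `peft` directly; the checkpoint and the head-skipping filter are illustrative stand-ins for the repo's helper, not its exact code:

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# rough stand-in for find_all_linear_modules(): collect leaf Linear layer
# names, skipping the output head as the repo's helper does
target_modules = {
    name.split(".")[-1]
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear) and "lm_head" not in name
}

peft_config = LoraConfig(r=8, lora_alpha=16, target_modules=list(target_modules))
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # only the injected LoRA weights are trainable
```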
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 3712a592..18b0cf79 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -1,6 +1,12 @@
 from typing import TYPE_CHECKING, Any, Dict, Union

-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoProcessor, AutoModelForVision2Seq
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    AutoProcessor,
+    AutoModelForVision2Seq,
+)
 from trl import AutoModelForCausalLMWithValueHead

 from ..extras.logging import get_logger
@@ -62,10 +68,14 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
             dict(additional_special_tokens=model_args.new_special_tokens),
             replace_additional_special_tokens=False,
         )
-        logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
+        logger.info(
+            "Add {} to special tokens.".format(",".join(model_args.new_special_tokens))
+        )
         if num_added_tokens > 0 and not model_args.resize_vocab:
             model_args.resize_vocab = True
-            logger.warning("New tokens have been added, changed `resize_vocab` to True.")
+            logger.warning(
+                "New tokens have been added, changed `resize_vocab` to True."
+            )

     patch_tokenizer(tokenizer)
     return tokenizer
@@ -111,7 +121,7 @@ def load_model(
     finetuning_args: "FinetuningArguments",
     is_trainable: bool = False,
     add_valuehead: bool = False,
-) -> Union["PreTrainedModel", "AutoModelForVision2Seq"]:
+) -> "PreTrainedModel":
     r"""
     Loads pretrained model.
     """
@@ -170,8 +180,10 @@ def load_model(
     trainable_params, all_param = count_parameters(model)
     if is_trainable:
-        param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
-            trainable_params, all_param, 100 * trainable_params / all_param
+        param_stats = (
+            "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
+                trainable_params, all_param, 100 * trainable_params / all_param
+            )
         )
     else:
         param_stats = "all params: {:d}".format(all_param)
@@ -185,4 +197,4 @@ def load_model(
         )
     )

-    return model
\ No newline at end of file
+    return model
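Note on the loader changes above: the multimodal path loads an `AutoProcessor` and an `AutoModelForVision2Seq` instead of a plain tokenizer and causal LM. A minimal sketch of what that enables; the checkpoint name is illustrative:

```python
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq

# illustrative LLaVA checkpoint; any Vision2Seq-compatible repo works
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
model = AutoModelForVision2Seq.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# the processor bundles the tokenizer; run_sft_mm now hands this to load_model
tokenizer = processor.tokenizer
```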
diff --git a/src/llmtuner/train/sftmm/collator.py b/src/llmtuner/train/sftmm/collator.py
index 95dbd939..2931dd9c 100644
--- a/src/llmtuner/train/sftmm/collator.py
+++ b/src/llmtuner/train/sftmm/collator.py
@@ -19,7 +19,9 @@ class DataCollatorForVis2Seq:
             texts.append(text)
             images.append(example["images"][0])

-        batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)
+        batch = self.processor(
+            text=texts, images=images, return_tensors="pt", padding=True
+        )

         labels = batch["input_ids"].clone()
         if self.processor.tokenizer.pad_token_id is not None:
@@ -27,3 +29,15 @@
         batch["labels"] = labels

         return batch
+
+
+@dataclass
+class DataCollatorForMLLM:
+    processor: AutoProcessor
+
+    def __call__(self, examples):
+        # NOTE: debugging stub; inspects the first example and returns an empty batch.
+        print(examples[0].keys())
+        print(examples[0]["input_ids"])
+        batch = {}
+        return batch
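Note on the collator above: `DataCollatorForVis2Seq` tokenizes the templated chat strings and preprocesses the images in one processor call, then masks pad positions out of the labels. A minimal standalone sketch of that contract; the example fields are assumptions, not the repo's exact schema:

```python
def collate(processor, examples):
    # each example is assumed to carry a templated chat string and PIL images
    texts = [example["text"] for example in examples]
    images = [example["images"][0] for example in examples]

    # one processor call tokenizes the text batch and preprocesses the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # supervise on input_ids, but ignore pad positions in the loss
    labels = batch["input_ids"].clone()
    if processor.tokenizer.pad_token_id is not None:
        labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels
    return batch
```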
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids + gen_kwargs["eos_token_id"] = [ + tokenizer.eos_token_id + ] + tokenizer.additional_special_tokens_ids gen_kwargs["pad_token_id"] = tokenizer.pad_token_id gen_kwargs["logits_processor"] = get_logits_processor() # Training if training_args.do_train: - train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + train_result = trainer.train( + resume_from_checkpoint=training_args.resume_from_checkpoint + ) trainer.save_model() trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) @@ -75,19 +101,27 @@ def run_sft_mm( # Evaluation if training_args.do_eval: metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) - if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled + if ( + training_args.predict_with_generate + ): # eval_loss will be wrong if predict_with_generate is enabled metrics.pop("eval_loss", None) trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) # Predict if training_args.do_predict: - predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) - if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled + predict_results = trainer.predict( + dataset, metric_key_prefix="predict", **gen_kwargs + ) + if ( + training_args.predict_with_generate + ): # predict_loss will be wrong if predict_with_generate is enabled predict_results.metrics.pop("predict_loss", None) trainer.log_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics) trainer.save_predictions(predict_results) # Create model card - create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) + create_modelcard_and_push( + trainer, model_args, data_args, training_args, finetuning_args + )