Merge pull request #3450 from BUAADreamer/mllm
Add Multimodal LLM Finetuning
This commit is contained in:
commit
c20f750d11
|
@ -58,6 +58,21 @@
|
|||
"tools": "tools"
|
||||
}
|
||||
},
|
||||
"mllm_demo": {
|
||||
"file_name": "mllm_demo.json",
|
||||
"file_sha1": "b6709b23657d5c42a701f1c5574f3a6edaa40a20",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "messages",
|
||||
"images": "images"
|
||||
},
|
||||
"tags": {
|
||||
"role_tag": "role",
|
||||
"content_tag": "content",
|
||||
"user_tag": "user",
|
||||
"assistant_tag": "assistant"
|
||||
}
|
||||
},
|
||||
"example": {
|
||||
"script_url": "example_dataset",
|
||||
"columns": {
|
||||
|
@ -185,6 +200,7 @@
|
|||
"ultrachat_200k": {
|
||||
"hf_hub_url": "HuggingFaceH4/ultrachat_200k",
|
||||
"ms_hub_url": "AI-ModelScope/ultrachat_200k",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "messages"
|
||||
},
|
||||
|
@ -193,8 +209,7 @@
|
|||
"content_tag": "content",
|
||||
"user_tag": "user",
|
||||
"assistant_tag": "assistant"
|
||||
},
|
||||
"formatting": "sharegpt"
|
||||
}
|
||||
},
|
||||
"agent_instruct": {
|
||||
"hf_hub_url": "THUDM/AgentInstruct",
|
||||
|
@ -204,6 +219,7 @@
|
|||
"lmsys_chat": {
|
||||
"hf_hub_url": "lmsys/lmsys-chat-1m",
|
||||
"ms_hub_url": "AI-ModelScope/lmsys-chat-1m",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "conversation"
|
||||
},
|
||||
|
@ -212,8 +228,7 @@
|
|||
"content_tag": "content",
|
||||
"user_tag": "human",
|
||||
"assistant_tag": "assistant"
|
||||
},
|
||||
"formatting": "sharegpt"
|
||||
}
|
||||
},
|
||||
"evol_instruct": {
|
||||
"hf_hub_url": "WizardLM/WizardLM_evol_instruct_V2_196k",
|
||||
|
@ -340,7 +355,7 @@
|
|||
"history": "history"
|
||||
}
|
||||
},
|
||||
"orca_dpo_de" : {
|
||||
"orca_dpo_de": {
|
||||
"hf_hub_url": "mayflowergmbh/intel_orca_dpo_pairs_de",
|
||||
"ranking": true
|
||||
},
|
||||
|
@ -414,4 +429,4 @@
|
|||
},
|
||||
"folder": "python"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 12 KiB |
Binary file not shown.
After Width: | Height: | Size: 48 KiB |
Binary file not shown.
After Width: | Height: | Size: 68 KiB |
|
@ -0,0 +1,71 @@
|
|||
[
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "Who are they?<image>",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "They're Kane and Gretzka from Bayern Munich.",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"content": "What are they doing?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "They are celebrating on the soccer field",
|
||||
"role": "assistant"
|
||||
}
|
||||
],
|
||||
"images": [
|
||||
"images/1.jpg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "Who is he?<image>",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "He's Thomas Muller from Bayern Munich.",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"content": "Why is he on the ground?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "Because he's sliding on his knees to celebrate.",
|
||||
"role": "assistant"
|
||||
}
|
||||
],
|
||||
"images": [
|
||||
"images/2.jpg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "Please describe this image<image>",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "Chinese astronaut Gui Haichao is giving a speech.",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"content": "What has he accomplished?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "He was appointed to be a payload specialist on Shenzhou 16 mission in June 2022, thus becoming the first Chinese civilian of Group 3 in space on 30 May 2023. He is responsible for the on-orbit operation of space science experimental payloads.",
|
||||
"role": "assistant"
|
||||
}
|
||||
],
|
||||
"images": [
|
||||
"images/3.jpg"
|
||||
]
|
||||
}
|
||||
]
|
|
@ -0,0 +1,32 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft_mm \
|
||||
--do_train \
|
||||
--model_name_or_path llava-hf/llava-1.5-7b-hf \
|
||||
--dataset mllm_instruct_example \
|
||||
--dataset_dir data \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target all \
|
||||
--output_dir saves/llava-1.5-7b/lora/sft \
|
||||
--overwrite_cache \
|
||||
--overwrite_output_dir \
|
||||
--cutoff_len 1024 \
|
||||
--preprocessing_num_workers 16 \
|
||||
--per_device_train_batch_size 3 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 1 \
|
||||
--warmup_steps 20 \
|
||||
--save_steps 100 \
|
||||
--eval_steps 100 \
|
||||
--evaluation_strategy steps \
|
||||
--load_best_model_at_end \
|
||||
--learning_rate 5e-5 \
|
||||
--num_train_epochs 100 \
|
||||
--max_samples 3000 \
|
||||
--val_size 0.1 \
|
||||
--plot_loss \
|
||||
--bf16
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
|
@ -13,8 +14,10 @@ if TYPE_CHECKING:
|
|||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||
def convert_alpaca(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
for i in range(len(examples[dataset_attr.prompt])):
|
||||
prompt = []
|
||||
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
|
||||
|
@ -44,12 +47,19 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
|
|||
outputs["response"].append(response)
|
||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
outputs["tools"].append("")
|
||||
outputs["images"].append(
|
||||
[os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]]
|
||||
if dataset_attr.images
|
||||
else []
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||
def convert_sharegpt(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
tag_mapping = {
|
||||
dataset_attr.user_tag: Role.USER.value,
|
||||
dataset_attr.assistant_tag: Role.ASSISTANT.value,
|
||||
|
@ -84,6 +94,11 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
|
|||
outputs["response"].append(aligned_messages[-1:])
|
||||
outputs["system"].append(system)
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
outputs["images"].append(
|
||||
[os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]]
|
||||
if dataset_attr.images
|
||||
else []
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
@ -96,12 +111,13 @@ def align_dataset(
|
|||
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
|
||||
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
|
||||
system: "..."
|
||||
tools: "..."
|
||||
tools: "...",
|
||||
images: [],
|
||||
"""
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
|
||||
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
|
||||
else:
|
||||
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
|
||||
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
|
||||
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
features = Features.from_dict(
|
||||
|
@ -114,6 +130,7 @@ def align_dataset(
|
|||
],
|
||||
"system": {"dtype": "string", "_type": "Value"},
|
||||
"tools": {"dtype": "string", "_type": "Value"},
|
||||
"images": [{"_type": "Image"}],
|
||||
}
|
||||
)
|
||||
kwargs = {}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import inspect
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Literal, Union
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
|
@ -16,7 +16,7 @@ from .utils import checksum, merge_dataset
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments, ModelArguments
|
||||
|
@ -115,11 +115,12 @@ def load_single_dataset(
|
|||
|
||||
|
||||
def get_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
|
@ -149,7 +150,7 @@ def get_dataset(
|
|||
|
||||
with training_args.main_process_first(desc="pre-process dataset"):
|
||||
preprocess_func, print_function = get_preprocess_and_print_func(
|
||||
tokenizer, template, data_args, training_args, stage
|
||||
data_args, training_args, stage, template, tokenizer, processor
|
||||
)
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
|
|
|
@ -28,6 +28,7 @@ class DatasetAttr:
|
|||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
""" columns """
|
||||
system: Optional[str] = None
|
||||
images: Optional[str] = None
|
||||
""" columns for the alpaca format """
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
|
@ -105,7 +106,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
|||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system"]
|
||||
column_names = ["system", "images"]
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
column_names.extend(["prompt", "query", "response", "history"])
|
||||
else:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX
|
||||
from ..extras.logging import get_logger
|
||||
|
@ -8,7 +8,9 @@ from .utils import Role
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from PIL import Image
|
||||
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments
|
||||
|
@ -18,6 +20,14 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _preprocess_visual_inputs(model_inputs: Dict[str, Any], processor: "ProcessorMixin", image: "Image") -> None:
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"][0]
|
||||
if "pixel_values" not in model_inputs:
|
||||
model_inputs["pixel_values"] = []
|
||||
model_inputs["pixel_values"].append(pixel_values)
|
||||
|
||||
|
||||
def preprocess_pretrain_dataset(
|
||||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
|
@ -48,8 +58,9 @@ def preprocess_pretrain_dataset(
|
|||
|
||||
def preprocess_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
|
@ -89,14 +100,16 @@ def preprocess_supervised_dataset(
|
|||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
if processor is not None and "images" in examples:
|
||||
_preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_packed_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
|
@ -141,8 +154,9 @@ def preprocess_packed_supervised_dataset(
|
|||
|
||||
def preprocess_unsupervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||
|
@ -172,14 +186,17 @@ def preprocess_unsupervised_dataset(
|
|||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
if processor is not None and "images" in examples:
|
||||
_preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_pairwise_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
|
@ -214,6 +231,8 @@ def preprocess_pairwise_dataset(
|
|||
model_inputs["prompt_ids"].append(prompt_ids)
|
||||
model_inputs["chosen_ids"].append(chosen_ids)
|
||||
model_inputs["rejected_ids"].append(rejected_ids)
|
||||
if processor is not None and "images" in examples:
|
||||
_preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
@ -244,11 +263,12 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
|
|||
|
||||
|
||||
def get_preprocess_and_print_func(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[Callable, Callable]:
|
||||
if stage == "pt":
|
||||
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
|
||||
|
@ -256,22 +276,37 @@ def get_preprocess_and_print_func(
|
|||
elif stage == "sft" and not training_args.predict_with_generate:
|
||||
if data_args.packing:
|
||||
preprocess_func = partial(
|
||||
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
preprocess_packed_supervised_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
data_args=data_args,
|
||||
)
|
||||
else:
|
||||
preprocess_func = partial(
|
||||
preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
preprocess_supervised_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
|
||||
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "rm":
|
||||
preprocess_func = partial(
|
||||
preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
preprocess_pairwise_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
||||
else:
|
||||
preprocess_func = partial(
|
||||
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
preprocess_unsupervised_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
|
||||
|
|
|
@ -81,6 +81,10 @@ class ModelArguments:
|
|||
default=False,
|
||||
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
|
||||
)
|
||||
visual_inputs: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
|
||||
)
|
||||
moe_aux_loss_coef: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from typing import TYPE_CHECKING, Any, Dict
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
|
@ -13,7 +13,7 @@ from .utils.unsloth import load_unsloth_pretrained_model
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
|
||||
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
|
@ -21,6 +21,11 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TokenizerModule(TypedDict):
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
processor: Optional["ProcessorMixin"]
|
||||
|
||||
|
||||
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
||||
r"""
|
||||
Gets arguments to load config/tokenizer/model.
|
||||
|
@ -36,7 +41,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
|
||||
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
||||
r"""
|
||||
Loads pretrained tokenizer.
|
||||
|
||||
|
@ -70,7 +75,14 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
|
|||
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
|
||||
|
||||
patch_tokenizer(tokenizer)
|
||||
return tokenizer
|
||||
|
||||
if model_args.visual_inputs:
|
||||
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
||||
setattr(processor, "tokenizer", tokenizer)
|
||||
else:
|
||||
processor = None
|
||||
|
||||
return {"tokenizer": tokenizer, "processor": processor}
|
||||
|
||||
|
||||
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
|
||||
|
@ -109,6 +121,8 @@ def load_model(
|
|||
|
||||
if model_args.mixture_of_depths == "load":
|
||||
model = load_mod_pretrained_model(**init_kwargs)
|
||||
elif model_args.visual_inputs:
|
||||
model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
|
||||
|
||||
|
|
|
@ -28,9 +28,10 @@ def run_sft(
|
|||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=training_args.do_train)
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
tokenizer.padding_side = "left" # use left-padding in generation
|
||||
|
@ -47,6 +48,7 @@ def run_sft(
|
|||
# Override the decoding parameters of Seq2SeqTrainer
|
||||
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
|
||||
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
|
||||
training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = CustomSeq2SeqTrainer(
|
||||
|
|
Loading…
Reference in New Issue