Update loader.py

commit 7d812ed841 (parent f8c26e6a34)
@@ -1,12 +1,6 @@
-from typing import TYPE_CHECKING, Any, Dict, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict

-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    AutoModelForVision2Seq,
-    AutoProcessor,
-    AutoTokenizer,
-)
+from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead

 from ..extras.logging import get_logger
@@ -19,13 +13,19 @@ from .utils.unsloth import load_unsloth_pretrained_model


 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
+    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

     from ..hparams import FinetuningArguments, ModelArguments


 logger = get_logger(__name__)


+class TokenizerModule(TypedDict):
+    tokenizer: "PreTrainedTokenizer"
+    processor: Optional["ProcessorMixin"]
+
+
 def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     r"""
     Gets arguments to load config/tokenizer/model.
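Note: the new `TokenizerModule` TypedDict gives callers a statically typed view of the tokenizer/processor pair instead of a loose dict. A minimal standalone sketch of the pattern this hunk introduces (stub classes stand in for the transformers types, so the example runs without the library installed):

# Standalone sketch of the TypedDict pattern introduced above; the real
# PreTrainedTokenizer/ProcessorMixin classes are stubbed for illustration.
from typing import Optional, TypedDict

class PreTrainedTokenizer: ...  # stub for transformers.PreTrainedTokenizer
class ProcessorMixin: ...       # stub for transformers.ProcessorMixin

class TokenizerModule(TypedDict):
    tokenizer: PreTrainedTokenizer
    processor: Optional[ProcessorMixin]

def load_tokenizer(visual_inputs: bool) -> TokenizerModule:
    # Mirrors the diff: a processor is attached only for multimodal models.
    processor = ProcessorMixin() if visual_inputs else None
    return {"tokenizer": PreTrainedTokenizer(), "processor": processor}

module = load_tokenizer(visual_inputs=False)
print(module["processor"])  # None for text-only models

Type checkers can now verify both the keys and the value types at every call site, which a `Dict[str, Union[...]]` annotation could not.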
@@ -41,7 +41,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     }


-def load_tokenizer(model_args: "ModelArguments") -> Dict[str, Union["PreTrainedTokenizer", "AutoProcessor"]]:
+def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     r"""
     Loads pretrained tokenizer.

@@ -75,25 +75,14 @@ def load_tokenizer(model_args: "ModelArguments") -> Dict[str, Union["PreTrainedTokenizer", "AutoProcessor"]]:
         logger.warning("New tokens have been added, changed `resize_vocab` to True.")

     patch_tokenizer(tokenizer)
-    tokenizer_modules = {"tokenizer": tokenizer, "processor": None}
-    if model_args.use_mllm:
-        try:
-            processor = AutoProcessor.from_pretrained(
-                model_args.model_name_or_path,
-                use_fast=model_args.use_fast_tokenizer,
-                split_special_tokens=model_args.split_special_tokens,
-                padding_side="right",
-                **init_kwargs,
-            )
-        except Exception:  # try the fast one
-            processor = AutoProcessor.from_pretrained(
-                model_args.model_name_or_path,
-                use_fast=True,
-                padding_side="right",
-                **init_kwargs,
-            )
-        tokenizer_modules["processor"] = processor
-    return tokenizer_modules
+    if model_args.visual_inputs:
+        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
+        setattr(processor, "tokenizer", tokenizer)
+    else:
+        processor = None
+
+    return {"tokenizer": tokenizer, "processor": processor}


 def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
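Note: the new path drops the slow/fast retry loop and loads the processor with a single `from_pretrained` call sharing `init_kwargs` with the tokenizer, then attaches the tokenizer to the processor so both use one vocabulary. A hedged usage sketch of that pairing outside the trainer (the checkpoint name is illustrative, not taken from this commit):

# Loading a tokenizer+processor pair the way the new code path does.
# "llava-hf/llava-1.5-7b-hf" is an example multimodal checkpoint.
from transformers import AutoProcessor, AutoTokenizer

name = "llava-hf/llava-1.5-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(name)
processor = AutoProcessor.from_pretrained(name)
setattr(processor, "tokenizer", tokenizer)  # as the diff does: share one vocab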
@@ -132,11 +121,10 @@ def load_model(

     if model_args.mixture_of_depths == "load":
         model = load_mod_pretrained_model(**init_kwargs)
+    elif model_args.visual_inputs:
+        model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
     else:
-        if model_args.use_mllm:
-            model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
-        else:
-            model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
+    model = AutoModelForCausalLM.from_pretrained(**init_kwargs)

     if model_args.mixture_of_depths == "convert":
         model = convert_pretrained_model_to_mod(model, config, model_args)
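Note: the nested branching is flattened into one if/elif/else chain, and the renamed `visual_inputs` flag (formerly `use_mllm`) is now checked at the top level, after mixture-of-depths. A control-flow sketch with stub return values (not the real loaders) to make the precedence explicit:

# Dispatch order of the new load_model: MoD checkpoints win over vision
# models, which win over plain causal LMs. Strings stand in for loaders.
def pick_loader(mixture_of_depths: str, visual_inputs: bool) -> str:
    if mixture_of_depths == "load":
        return "load_mod_pretrained_model"
    elif visual_inputs:
        return "AutoModelForVision2Seq"
    else:
        return "AutoModelForCausalLM"

assert pick_loader("load", True) == "load_mod_pretrained_model"  # MoD first
assert pick_loader("convert", True) == "AutoModelForVision2Seq"
assert pick_loader("convert", False) == "AutoModelForCausalLM"

The precedence is unchanged from the old nested form; only the nesting depth is reduced.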