Update loader.py

hoshi-hiyouga 2024-04-26 03:22:40 +08:00 committed by GitHub
parent f8c26e6a34
commit 7d812ed841
1 changed file with 21 additions and 33 deletions


@@ -1,12 +1,6 @@
-from typing import TYPE_CHECKING, Any, Dict, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict

-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    AutoModelForVision2Seq,
-    AutoProcessor,
-    AutoTokenizer,
-)
+from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead

 from ..extras.logging import get_logger
@@ -19,13 +13,19 @@ from .utils.unsloth import load_unsloth_pretrained_model

 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
+    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

     from ..hparams import FinetuningArguments, ModelArguments


 logger = get_logger(__name__)


+class TokenizerModule(TypedDict):
+    tokenizer: "PreTrainedTokenizer"
+    processor: Optional["ProcessorMixin"]
+
+
 def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     r"""
     Gets arguments to load config/tokenizer/model.
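The new TokenizerModule TypedDict replaces the loose Dict[str, Union[...]] return annotation: each key gets its own type, so a type checker can see that "tokenizer" is always a PreTrainedTokenizer while "processor" may be None. A minimal standalone sketch of how a caller benefits (the consume helper is hypothetical, for illustration only):

from typing import Optional, TypedDict

from transformers import PreTrainedTokenizer, ProcessorMixin


class TokenizerModule(TypedDict):
    tokenizer: PreTrainedTokenizer
    processor: Optional[ProcessorMixin]


def consume(module: TokenizerModule) -> None:  # hypothetical helper, not part of the commit
    tokenizer = module["tokenizer"]  # checker infers PreTrainedTokenizer here
    if module["processor"] is not None:  # and Optional[ProcessorMixin] here
        pass  # multimodal-specific preprocessing would go through the processor
    _ = tokenizer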
@@ -41,7 +41,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     }


-def load_tokenizer(model_args: "ModelArguments") -> Dict[str, Union["PreTrainedTokenizer", "AutoProcessor"]]:
+def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     r"""
     Loads pretrained tokenizer.
@@ -75,25 +75,14 @@ def load_tokenizer(model_args: "ModelArguments") -> Dict[str, Union["PreTrainedT
         logger.warning("New tokens have been added, changed `resize_vocab` to True.")

     patch_tokenizer(tokenizer)

-    tokenizer_modules = {"tokenizer": tokenizer, "processor": None}
-    if model_args.use_mllm:
-        try:
-            processor = AutoProcessor.from_pretrained(
-                model_args.model_name_or_path,
-                use_fast=model_args.use_fast_tokenizer,
-                split_special_tokens=model_args.split_special_tokens,
-                padding_side="right",
-                **init_kwargs,
-            )
-        except Exception:  # try the fast one
-            processor = AutoProcessor.from_pretrained(
-                model_args.model_name_or_path,
-                use_fast=True,
-                padding_side="right",
-                **init_kwargs,
-            )
-        tokenizer_modules["processor"] = processor
-
-    return tokenizer_modules
+    if model_args.visual_inputs:
+        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
+        setattr(processor, "tokenizer", tokenizer)
+    else:
+        processor = None
+
+    return {"tokenizer": tokenizer, "processor": processor}


 def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
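After this change, load_tokenizer always returns both keys, the processor is only built when visual_inputs is set, and the processor is handed a reference to the already patched tokenizer. A caller-side sketch of the new contract (assumes load_tokenizer and a populated model_args are in scope, as in the module above):

# Sketch of the caller-side contract; not code from this commit.
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
processor = tokenizer_module["processor"]  # None unless model_args.visual_inputs is True

if processor is not None:
    # The processor was given the patched tokenizer via setattr above,
    # so text and image preprocessing stay in sync.
    assert processor.tokenizer is tokenizer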
@@ -132,11 +121,10 @@ def load_model(
     if model_args.mixture_of_depths == "load":
         model = load_mod_pretrained_model(**init_kwargs)
+    elif model_args.visual_inputs:
+        model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
     else:
-        if model_args.use_mllm:
-            model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
-        else:
-            model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
+        model = AutoModelForCausalLM.from_pretrained(**init_kwargs)

     if model_args.mixture_of_depths == "convert":
         model = convert_pretrained_model_to_mod(model, config, model_args)
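The nested if model_args.use_mllm block is flattened into an elif on the renamed visual_inputs flag, so mixture-of-depths loading keeps priority and the Auto class choice reads as a single chain. A minimal sketch of just that selection step (the helper name is hypothetical; MoD handling is elided):

from transformers import AutoModelForCausalLM, AutoModelForVision2Seq


def select_auto_class(visual_inputs: bool):  # hypothetical helper, for illustration only
    # Vision-language checkpoints need the Vision2Seq head; text-only models use CausalLM.
    return AutoModelForVision2Seq if visual_inputs else AutoModelForCausalLM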