modify style
This commit is contained in:
parent
3f4556454c
commit
771bed5bde
|
@ -8,7 +8,7 @@ from ...extras.logging import get_logger
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel, LlavaConfig
|
||||
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
@ -29,8 +29,10 @@ def autocast_projector_dtype(
|
|||
) -> "torch.Tensor":
|
||||
return output.to(model_args.compute_dtype)
|
||||
|
||||
if hasattr(model, mm_projector_name) and (getattr(model.config, "quantization_method", None)
|
||||
or "Yi" in getattr(model.config.text_config, "_name_or_path", None)):
|
||||
if hasattr(model, mm_projector_name) and (
|
||||
getattr(model.config, "quantization_method", None)
|
||||
or "Yi" in getattr(model.config.text_config, "_name_or_path", None)
|
||||
):
|
||||
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
|
||||
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
|
||||
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Seq2SeqTrainer, ProcessorMixin
|
||||
from transformers import ProcessorMixin, Seq2SeqTrainer
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
|
@ -127,4 +127,4 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||
if self.processor is not None:
|
||||
if output_dir is None:
|
||||
output_dir = self.args.output_dir
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
|
|
Loading…
Reference in New Issue