From d2bf69740043012a0025dd9d80c7adf979dc3a88 Mon Sep 17 00:00:00 2001
From: BUAADreamer <1428195643@qq.com>
Date: Wed, 15 May 2024 11:22:15 +0800
Subject: [PATCH] cast dtype in mm_proj

---
 src/llmtuner/model/utils/visual.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/llmtuner/model/utils/visual.py b/src/llmtuner/model/utils/visual.py
index 0dc844f5..b8696096 100644
--- a/src/llmtuner/model/utils/visual.py
+++ b/src/llmtuner/model/utils/visual.py
@@ -8,7 +8,7 @@ from ...extras.logging import get_logger
 
 
 if TYPE_CHECKING:
-    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, LlavaForConditionalGeneration
 
     from ...hparams import ModelArguments
 
@@ -29,10 +29,7 @@ def autocast_projector_dtype(
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)
 
-    if hasattr(model, mm_projector_name) and (
-        getattr(model.config, "quantization_method", None)
-        or "Yi" in getattr(model.config.text_config, "_name_or_path", None)
-    ):
+    if hasattr(model, mm_projector_name) and getattr(model.config, "quantization_method", None):
         logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
         mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
         mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
@@ -48,11 +45,13 @@ class LlavaMultiModalProjectorYiVL(nn.Module):
         self.act = nn.GELU()
 
     def forward(self, image_features):
+        dtype_ = self.linear_1.weight.dtype
         hidden_states = self.linear_1(image_features)
         hidden_states = self.linear_2(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.linear_3(hidden_states)
         hidden_states = self.linear_4(hidden_states)
+        hidden_states = hidden_states.to(dtype_)
         return hidden_states
 
 
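Note on the change (illustration, not part of the patch): forward() now records the dtype of linear_1's weights and casts the final hidden states back to it, which is why the "Yi"-specific branch can be dropped from autocast_projector_dtype(). Below is a minimal, self-contained sketch of that pattern, assuming hypothetical hidden sizes and LayerNorm for linear_2/linear_4; it is not the repository's actual module.

import torch
import torch.nn as nn


class TinyYiVLProjector(nn.Module):
    """Sketch of a Yi-VL-style projector that preserves its weight dtype."""

    def __init__(self, vision_hidden: int = 32, text_hidden: int = 64) -> None:
        super().__init__()
        # Layer types and sizes are assumptions for this example only.
        self.linear_1 = nn.Linear(vision_hidden, text_hidden, bias=True)
        self.linear_2 = nn.LayerNorm(text_hidden)
        self.linear_3 = nn.Linear(text_hidden, text_hidden, bias=True)
        self.linear_4 = nn.LayerNorm(text_hidden)
        self.act = nn.GELU()

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        dtype_ = self.linear_1.weight.dtype  # remember the parameter dtype
        hidden_states = self.linear_1(image_features)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_3(hidden_states)
        hidden_states = self.linear_4(hidden_states)
        # Cast back so downstream layers see the projector's own dtype even if
        # an intermediate op upcast the activations.
        return hidden_states.to(dtype_)


if __name__ == "__main__":
    proj = TinyYiVLProjector().to(torch.bfloat16)
    out = proj(torch.randn(1, 8, 32, dtype=torch.bfloat16))
    print(out.dtype)  # torch.bfloat16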