diff --git a/src/llamafactory/model/utils/visual.py b/src/llamafactory/model/utils/visual.py index e3b81bd1..c8260b7f 100644 --- a/src/llamafactory/model/utils/visual.py +++ b/src/llamafactory/model/utils/visual.py @@ -69,7 +69,7 @@ def autocast_projector_dtype( ) -> "torch.Tensor": return output.to(model_args.compute_dtype) - if hasattr(model, mm_projector_name) and getattr(model.config, "quantization_method", None): + if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None): logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype)) mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name) mm_projector.register_forward_hook(_mm_projector_forward_post_hook)