robotics_transformer/film_efficientnet/pretrained_efficientnet_enc...

123 lines
4.6 KiB
Python

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Encoder based on Efficientnet."""
from typing import Optional
import gin
from robotics_transformer.film_efficientnet import film_conditioning_layer
from robotics_transformer.film_efficientnet import film_efficientnet_encoder
import tensorflow as tf
# Registry of supported EfficientNet variants -> backbone constructor.
_MODELS = dict(
    b3=film_efficientnet_encoder.EfficientNetB3,
)
# Side length (height == width) that input images are resized to for each
# variant before being fed to the backbone.
_SIZES = dict(
    b3=300,
)
@gin.configurable
class EfficientNetEncoder(tf.keras.layers.Layer):
    """Applies a pretrained EfficientNet-based encoder.

    Input images are validated, resized to the variant's native resolution,
    rescaled to [0, 255], preprocessed, and run through the (optionally
    FiLM-conditioned) backbone. When a context tensor is supplied, the backbone
    features are projected to 512 channels with a 1x1 convolution and
    FiLM-conditioned on the context. The result is optionally global-average
    pooled over the spatial axes.
    """

    def __init__(self,
                 model_variant: str = 'b3',
                 freeze: bool = False,
                 early_film: bool = True,
                 weights: Optional[str] = 'imagenet',
                 include_top: bool = False,
                 pooling: bool = True,
                 **kwargs):
        """Initialize the model.

        Args:
          model_variant: Key of the EfficientNet variant to build. Only the
            variants registered in `_MODELS` are accepted (currently just
            'b3'); see https://arxiv.org/abs/1905.11946 for the variant family.
          freeze: Whether to block gradients through the pretrained encoder
            (seems to not work well). Implemented as `tf.stop_gradient` on the
            backbone output, so no gradient reaches the backbone (including
            any early FiLM layers inside it) or flows back to `context`
            through the backbone.
          early_film: Whether to inject FiLM layers into the EfficientNet
            encoder (seems to be essential to getting strong performance).
          weights: Which pretrained weights to use. Either 'imagenet', a path
            to the pretrained weights, or None for from scratch.
          include_top: Whether to add the top fully connected layer. If True,
            this will cause encoding to fail and is used only for unit testing
            purposes.
          pooling: If False, return the feature map instead of its global
            average pool.
          **kwargs: Keras specific layer kwargs.

        Raises:
          ValueError: If `model_variant` is not a registered variant.
        """
        super().__init__(**kwargs)
        if model_variant not in _MODELS:
            raise ValueError(f'Unknown variant {model_variant}')
        self.model_variant = model_variant
        self.early_film = early_film
        self.freeze = freeze
        # Projects backbone features to the 512 channels expected by the
        # late FiLM layer below; only applied when a context is given.
        self.conv1x1 = tf.keras.layers.Conv2D(
            filters=512,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding='SAME',
            use_bias=False,
            kernel_initializer=tf.keras.initializers.VarianceScaling())
        self.net = _MODELS[model_variant](
            include_top=include_top,
            weights=weights,
            include_film=early_film,
        )
        self.film_layer = film_conditioning_layer.FilmConditioning(
            num_channels=512)
        self._pooling = pooling

    def _prepare_image(self, image: tf.Tensor) -> tf.Tensor:
        """Validate, resize and rescale `image` for the backbone.

        Static checks: the image must be rank 4 with 3 channels, and each
        statically known spatial dim must lie within [size / 4, size * 4],
        where `size` is the variant's resize target. Spatial dims that are
        unknown at trace time (None) are not checked statically.

        Runtime check: a `tf.Assert` fails if pixel values fall outside
        [0, 1].

        Raises:
          ValueError: On a bad static shape or out-of-range static size.
        """
        shape = tf.TensorShape(image.shape)
        dims = shape.as_list() if shape.rank == 4 else None
        if dims is None or dims[-1] != 3:
            raise ValueError('Provided image should have shape (b, h, w, 3).')
        _, height, width, _ = dims
        size = _SIZES[self.model_variant]
        # Dynamic (None) spatial dims cannot be compared; validate only the
        # dims that are known when the layer is traced.
        known_dims = [d for d in (height, width) if d is not None]
        if any(d < size / 4 for d in known_dims):
            raise ValueError('Provided image is too small.')
        if any(d > size * 4 for d in known_dims):
            raise ValueError('Provided image is too large.')
        image = tf.image.resize(image, (size, size))
        max_value = tf.reduce_max(image)
        min_value = tf.reduce_min(image)
        c1 = tf.Assert(max_value <= 1, data=[max_value])
        c2 = tf.Assert(min_value >= 0, data=[min_value])
        # Tie the range assertions to the rescale so they also run in graph
        # mode.
        with tf.control_dependencies([c1, c2]):
            image *= 255  # The image is expected to be in range(0, 255).
        image = film_efficientnet_encoder.preprocess_input(image)
        return image

    def _encode(self, image: tf.Tensor, context: Optional[tf.Tensor],
                training: bool) -> tf.Tensor:
        """Run the image through the efficientnet encoder."""
        image = self._prepare_image(image)
        if self.early_film:
            # The FiLM-enabled backbone takes an (image, context) tuple.
            return self.net((image, context), training=training)
        return self.net(image, training=training)

    def call(self,
             image: tf.Tensor,
             context: Optional[tf.Tensor] = None,
             training: bool = True) -> tf.Tensor:
        """Encode a batch of images, optionally conditioned on `context`.

        Args:
          image: Batch of images of shape (b, h, w, 3) with values in [0, 1].
          context: Optional conditioning tensor. It is fed to the backbone
            when `early_film` is set, and to the late FiLM layer when not
            None. NOTE(review): with `early_film=True` and `context=None` the
            backbone receives `(image, None)`; confirm the backbone supports
            this.
          training: Forwarded to the backbone. Note that the default is True.

        Returns:
          With a context: backbone features projected to 512 channels and
          FiLM-conditioned. Without one: the raw backbone features. If
          `pooling` is set, these are averaged over axes 1 and 2 (presumably
          the spatial axes of a (b, h', w', c) feature map; confirm against
          the backbone); otherwise the feature map itself is returned.
        """
        features = self._encode(image, context, training)
        if self.freeze:
            features = tf.stop_gradient(features)
        if context is not None:
            features = self.conv1x1(features)
            features = self.film_layer(features, context)
        if not self._pooling:
            return features
        # Global average pool.
        return tf.reduce_mean(features, [1, 2])