robotics_transformer/film_efficientnet/film_conditioning_layer.py

75 lines
2.5 KiB
Python
Raw Normal View History

2022-12-10 03:58:47 +08:00
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ResNet variants model for Keras with Film-Conditioning.
Related papers/blogs:
- https://arxiv.org/abs/1512.03385
- https://arxiv.org/pdf/1603.05027v2.pdf
- http://torch.ch/blog/2016/02/04/resnets.html
- https://arxiv.org/abs/1709.07871
"""
import tensorflow.compat.v2 as tf
# Short module-level alias for the Keras layers namespace (used for the Dense
# projections inside FilmConditioning).
layers = tf.keras.layers
class FilmConditioning(tf.keras.layers.Layer):
  """Layer that adds FiLM conditioning.

  This is intended to be applied after a convolutional layer. It learns a
  multiplicative and an additive factor, each a linear projection of a
  conditioning vector, and applies them to each channel (the last axis) of
  the convolution's output:

      result = (1 + mult(conditioning)) * conv_filters + add(conditioning)

  The conv input may have any rank >= 2 (e.g. [B, D] or [B, H, W, D]); the
  projected per-channel factors are broadcast over every axis between the
  batch axis and the channel axis.

  For further details, see: https://arxiv.org/abs/1709.07871
  """

  def __init__(self, num_channels: int, **kwargs):
    """Constructs a FiLM conditioning layer.

    Args:
      num_channels: Number of filter channels to expect in the input. Must
        match the size of the last axis of `conv_filters` passed to `call`.
      **kwargs: Standard Keras layer keyword arguments (e.g. `name`, `dtype`),
        forwarded to `tf.keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    # Kept so the layer can be re-created from `get_config()`.
    self._num_channels = num_channels
    # Note that we initialize with zeros because empirically we have found
    # this works better than initializing with glorot. With both projections
    # at zero, `call` starts out as the identity: (1 + 0) * x + 0.
    self._projection_add = layers.Dense(
        num_channels,
        activation=None,
        kernel_initializer='zeros',
        bias_initializer='zeros')
    self._projection_mult = layers.Dense(
        num_channels,
        activation=None,
        kernel_initializer='zeros',
        bias_initializer='zeros')

  def call(self, conv_filters: tf.Tensor, conditioning: tf.Tensor) -> tf.Tensor:
    """Applies FiLM modulation to `conv_filters`.

    Args:
      conv_filters: Tensor of rank >= 2 whose last axis has `num_channels`
        entries, e.g. [B, D] or [B, H, W, D].
      conditioning: Rank-2 conditioning tensor, [B, cond_dim].

    Returns:
      A tensor with the same shape as `conv_filters`.

    Raises:
      ValueError: If `conv_filters` has an unknown rank, or if a statically
        known rank is invalid (conditioning not rank 2, conv_filters rank < 2).
    """
    tf.debugging.assert_rank(conditioning, 2)
    projected_cond_add = self._projection_add(conditioning)
    projected_cond_mult = self._projection_mult(conditioning)

    rank = conv_filters.shape.rank
    if rank is None:
      raise ValueError(
          'FilmConditioning requires conv_filters with a statically known '
          'rank.')
    tf.debugging.assert_rank_at_least(conv_filters, 2)
    # [B, D] -> [B, 1, ..., 1, D] so the per-channel factors broadcast over
    # all axes between batch and channels (H and W for a rank-4 input).
    for _ in range(rank - 2):
      projected_cond_add = projected_cond_add[:, tf.newaxis]
      projected_cond_mult = projected_cond_mult[:, tf.newaxis]

    # Original FiLM paper argues that 1 + gamma centers the initialization at
    # identity transform.
    return (1 + projected_cond_mult) * conv_filters + projected_cond_add

  def get_config(self):
    """Returns the layer config, including `num_channels`, for serialization."""
    config = super().get_config()
    config.update({'num_channels': self._num_channels})
    return config