75 lines
2.5 KiB
Python
75 lines
2.5 KiB
Python
# Copyright 2022 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""ResNet variants model for Keras with Film-Conditioning.
|
|
|
|
Related papers/blogs:
|
|
- https://arxiv.org/abs/1512.03385
|
|
- https://arxiv.org/pdf/1603.05027v2.pdf
|
|
- http://torch.ch/blog/2016/02/04/resnets.html
|
|
- https://arxiv.org/abs/1709.07871
|
|
"""
|
|
import tensorflow.compat.v2 as tf
|
|
|
|
layers = tf.keras.layers
|
|
|
|
|
|
class FilmConditioning(tf.keras.layers.Layer):
  """Layer that adds FiLM conditioning.

  This is intended to be applied after a convolutional layer. It learns a
  multiplicative and an additive factor, both linearly projected from a
  conditioning vector, and applies them to each channel of the convolution's
  output:

    output = (1 + mult(conditioning)) * conv_filters + add(conditioning)

  Both projections are zero-initialized (kernel and bias), so a freshly
  constructed layer returns `conv_filters` unchanged.

  Conv layer can be rank 2 or 4. For rank 4 the projected factors are reshaped
  from [B, D] to [B, 1, 1, D], i.e. the channel axis is expected to be last.

  For further details, see: https://arxiv.org/abs/1709.07871
  """

  def __init__(self, num_channels: int, **kwargs):
    """Constructs a FiLM conditioning layer.

    Args:
      num_channels: Number of filter channels to expect in the input.
      **kwargs: Standard `tf.keras.layers.Layer` keyword arguments (for example
        `name` or `dtype`), forwarded unchanged to the base class.
    """
    super().__init__(**kwargs)
    # Kept so that `get_config` can reproduce the constructor arguments.
    self._num_channels = num_channels
    # Note that we initialize with zeros because empirically we have found
    # this works better than initializing with glorot.
    self._projection_add = layers.Dense(
        num_channels,
        activation=None,
        kernel_initializer='zeros',
        bias_initializer='zeros')
    self._projection_mult = layers.Dense(
        num_channels,
        activation=None,
        kernel_initializer='zeros',
        bias_initializer='zeros')

  def get_config(self):
    """Returns the layer config so the layer can be re-created from it.

    Returns:
      The base layer config extended with `num_channels`.
    """
    config = super().get_config()
    config.update({'num_channels': self._num_channels})
    return config

  def call(
      self, conv_filters: tf.Tensor, conditioning: tf.Tensor
  ) -> tf.Tensor:
    """Applies the FiLM transform to `conv_filters`.

    Args:
      conv_filters: Rank-2 or rank-4 tensor to modulate; its last dimension
        must be broadcast-compatible with `num_channels`.
      conditioning: Rank-2 conditioning tensor that is projected to the
        per-channel additive and multiplicative factors.

    Returns:
      `(1 + mult) * conv_filters + add`, with the same rank as `conv_filters`.

    Raises:
      Whatever `tf.debugging.assert_rank` raises when `conditioning` is not
      rank 2, or when `conv_filters` is neither rank 4 nor rank 2.
    """
    tf.debugging.assert_rank(conditioning, 2)
    projected_cond_add = self._projection_add(conditioning)
    projected_cond_mult = self._projection_mult(conditioning)

    if len(conv_filters.shape) == 4:
      # [B, D] -> [B, 1, 1, D] so the factors broadcast over the two middle
      # axes of the rank-4 input.
      projected_cond_add = projected_cond_add[:, tf.newaxis, tf.newaxis]
      projected_cond_mult = projected_cond_mult[:, tf.newaxis, tf.newaxis]
    else:
      # Anything that is not rank 4 must be rank 2; other ranks are rejected.
      tf.debugging.assert_rank(conv_filters, 2)

    # Original FiLM paper argues that 1 + gamma centers the initialization at
    # identity transform.
    result = (1 + projected_cond_mult) * conv_filters + projected_cond_add
    return result
|