robotics_transformer/film_efficientnet/preprocessors.py

109 lines
4.2 KiB
Python

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing functions for transforming the image for training."""
from typing import Optional
import gin
import tensorflow.compat.v2 as tf
# Default side length, in pixels, of the square crop produced by
# convert_dtype_and_crop_images (used as the default for its `crop_size`).
CROP_SIZE = 472
@gin.configurable(
    denylist=['images', 'crop_size', 'training', 'convert_dtype', 'seed'])
def convert_dtype_and_crop_images(images,
                                  crop_size: int = CROP_SIZE,
                                  training: bool = True,
                                  pad_then_crop: bool = False,
                                  convert_dtype: bool = True,
                                  seed: Optional[tf.Tensor] = None):
  """Optionally converts images to float32 and then crops them.

  Two cropping modes are supported:

  * Standard (`pad_then_crop=False`): a `crop_size` x `crop_size` window is
    cut out of each image. The window position is random when `training` is
    True and centered otherwise.
  * Pad-then-crop (`pad_then_crop=True`): when `training` is True the images
    are zero-padded on every side and a window of the *original* height and
    width is cut out at a random position, so the output keeps the input
    spatial size. When `training` is False the images are returned without
    any cropping.

  Args:
    images: [B, H, W, 3] uint8 tensor of images. Height and width are read
      from the static shape (axes -3 and -2).
    crop_size: Side length of the square crop used in standard mode. Ignored
      when `pad_then_crop` is True.
    training: Whether to crop at a random position (True) or use the
      deterministic evaluation behavior (False).
    pad_then_crop: If True, pad the images and crop back to the original
      image size, which allows the full field of view to be extracted.
    convert_dtype: Whether to convert the images to float32 with
      `tf.image.convert_image_dtype` before cropping.
    seed: Optional seed of shape (2,) passed to
      `tf.random.stateless_uniform`. When None, a seed is drawn with
      `tf.random.uniform`.

  Returns:
    The processed images. In standard mode the shape is
    [B, crop_size, crop_size, 3]; in pad-then-crop mode the spatial size
    equals that of the input. The dtype is float32 when `convert_dtype` is
    True, otherwise it is left as provided.

  Raises:
    ValueError: If `pad_then_crop` and `training` are both True and the image
      height is neither 512 nor 256.
  """
  # The seed is always materialized and split, even on paths that end up not
  # using it, so random-state consumption does not depend on the mode.
  if seed is None:
    seed = tf.random.uniform(shape=(2,), maxval=2**30, dtype=tf.int32)
  # `seed` drives the vertical offset; a seed derived from it drives the
  # horizontal offset.
  seed_x = tf.random.experimental.stateless_split(seed, num=1)[0]

  def draw_offset(largest_offset, offset_seed):
    # `maxval` is exclusive, hence the +1 to make `largest_offset` reachable.
    return tf.random.stateless_uniform((),
                                       maxval=largest_offset + 1,
                                       dtype=tf.int32,
                                       seed=offset_seed)

  if convert_dtype:
    images = tf.image.convert_image_dtype(images, tf.float32)

  static_shape = images.get_shape().as_list()
  image_height = static_shape[-3]
  image_width = static_shape[-2]

  if not pad_then_crop:
    # Standard cropping.
    slack_y = image_height - crop_size
    slack_x = image_width - crop_size
    if training:
      top = draw_offset(slack_y, seed)
      left = draw_offset(slack_x, seed_x)
    else:
      # Center crop.
      top = slack_y // 2
      left = slack_x // 2
    return tf.image.crop_to_bounding_box(images, top, left, crop_size,
                                         crop_size)

  if not training:
    # Pad-then-crop leaves evaluation images untouched.
    return images

  # Image height -> (top/bottom padding, left/right padding).
  padding_by_height = {512: (40, 100), 256: (20, 50)}
  if image_height not in padding_by_height:
    raise ValueError(
        'convert_dtype_and_crop_images only supports image height 512 or '
        '256.')
  pad_y, pad_x = padding_by_height[image_height]

  padded = tf.image.pad_to_bounding_box(
      images,
      offset_height=pad_y,
      offset_width=pad_x,
      target_height=image_height + 2 * pad_y,
      target_width=image_width + 2 * pad_x)
  # Any offset in [0, 2 * pad] keeps an original-size window inside the
  # padded image.
  top = draw_offset(2 * pad_y, seed)
  left = draw_offset(2 * pad_x, seed_x)
  return tf.image.crop_to_bounding_box(padded, top, left, image_height,
                                       image_width)