# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for networks."""
import copy
from typing import Optional, Tuple, Union
from absl.testing import parameterized
import numpy as np
from robotics_transformer import sequence_agent
from robotics_transformer import transformer_network
from tensor2robot.utils import tensorspec_utils
import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts
BATCH_SIZE = 2
TIME_SEQUENCE_LENGTH = 3
HEIGHT = 256
WIDTH = 320
NUM_IMAGE_TOKENS = 2


def spec_names_list() -> list[str]:
"""Lists the different types of specs accepted by the transformer."""
return ['default']


def state_spec_list() -> list[tensorspec_utils.TensorSpecStruct]:
"""Lists the different types of state spec accepted by the transformer."""
state_spec = tensorspec_utils.TensorSpecStruct()
state_spec.image = tensor_spec.BoundedTensorSpec([HEIGHT, WIDTH, 3],
dtype=tf.float32,
name='image',
minimum=0.,
maximum=1.)
state_spec.natural_language_embedding = tensor_spec.TensorSpec(
shape=[512], dtype=tf.float32, name='natural_language_embedding')
state_spec_mask = copy.deepcopy(state_spec)
state_spec_mask.initial_binary_mask = tensor_spec.BoundedTensorSpec(
[HEIGHT, WIDTH, 1],
dtype=tf.int32,
name='initial_binary_mask',
minimum=0,
maximum=255)
state_spec_tcl = copy.deepcopy(state_spec)
state_spec_tcl.original_image = tensor_spec.BoundedTensorSpec(
[HEIGHT, WIDTH, 3],
dtype=tf.float32,
name='original_image',
minimum=0.,
maximum=1.)
return [
state_spec,
state_spec_mask,
state_spec_tcl,
]


def observations_list(training: bool = True) -> list[dict[str, tf.Tensor]]:
"""Lists the different types of observations accepted by the transformer."""
if training:
image_shape = [BATCH_SIZE, TIME_SEQUENCE_LENGTH, HEIGHT, WIDTH, 3]
emb_shape = [BATCH_SIZE, TIME_SEQUENCE_LENGTH, 512]
mask_shape = [BATCH_SIZE, TIME_SEQUENCE_LENGTH, HEIGHT, WIDTH, 1]
else:
    # Inference currently only supports a batch size of 1.
image_shape = [1, HEIGHT, WIDTH, 3]
emb_shape = [1, 512]
mask_shape = [1, HEIGHT, WIDTH, 1]
return [
{
'image': tf.constant(0.5, shape=image_shape),
'natural_language_embedding': tf.constant(1., shape=emb_shape),
},
{
'image': tf.constant(0.5, shape=image_shape),
'natural_language_embedding': tf.constant(1., shape=emb_shape),
'initial_binary_mask': tf.constant(192, shape=mask_shape),
},
{ # This is used for TCL.
'image': tf.constant(0.5, shape=image_shape),
'original_image': tf.constant(0.4, shape=image_shape),
'natural_language_embedding': tf.constant(1., shape=emb_shape),
},
]
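

# Note: zip() truncates to the shorter sequence, so these dicts only contain
# entries for the names returned by spec_names_list().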
NAME_TO_STATE_SPECS = dict(zip(spec_names_list(), state_spec_list()))
NAME_TO_OBSERVATIONS = dict(zip(spec_names_list(), observations_list()))
NAME_TO_INF_OBSERVATIONS = dict(
zip(spec_names_list(), observations_list(False)))


class FakeImageTokenizer(tf.keras.layers.Layer):
"""Fake Image Tokenizer for testing Transformer."""
def __init__(self,
encoder: ...,
position_embedding: ...,
embedding_output_dim: int,
patch_size: int,
use_token_learner: bool = False,
num_tokens: int = NUM_IMAGE_TOKENS,
use_initial_binary_mask: bool = False,
**kwargs):
del encoder, position_embedding, patch_size, use_token_learner
super().__init__(**kwargs)
self.tokens_per_context_image = num_tokens
if use_initial_binary_mask:
self.tokens_per_context_image += 1
self.embedding_output_dim = embedding_output_dim
self.use_initial_binary_mask = use_initial_binary_mask
def __call__(self,
image: tf.Tensor,
context: Optional[tf.Tensor] = None,
initial_binary_mask: Optional[tf.Tensor] = None,
training: bool = False) -> tf.Tensor:
if self.use_initial_binary_mask:
assert initial_binary_mask is not None
image_shape = tf.shape(image)
seq_size = image_shape[1]
batch_size = image_shape[0]
all_tokens = []
num_tokens = self.tokens_per_context_image
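    # Fill every token of step t with the first batch element's top-left pixel
    # value, so tests can trace which frame produced a given token.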
for t in range(seq_size):
tokens = tf.ones([batch_size, 1, num_tokens, self.embedding_output_dim
]) * image[0][t][0][0]
all_tokens.append(tokens)
return tf.concat(all_tokens, axis=1)


class TransformerNetworkTestUtils(tf.test.TestCase, parameterized.TestCase):
"""Defines specs, SequenceAgent, and various other testing utilities."""
def _define_specs(self,
train_batch_size=BATCH_SIZE,
inference_batch_size=1,
time_sequence_length=TIME_SEQUENCE_LENGTH,
inference_sequence_length=TIME_SEQUENCE_LENGTH,
token_embedding_size=512,
image_width=WIDTH,
image_height=HEIGHT):
"""Defines specs and observations (both training and inference)."""
self.train_batch_size = train_batch_size
self.inference_batch_size = inference_batch_size
self.time_sequence_length = time_sequence_length
self.inference_sequence_length = inference_sequence_length
self.token_embedding_size = token_embedding_size
action_spec = tensorspec_utils.TensorSpecStruct()
action_spec.world_vector = tensor_spec.BoundedTensorSpec(
(3,), dtype=tf.float32, minimum=-1., maximum=1., name='world_vector')
action_spec.rotation_delta = tensor_spec.BoundedTensorSpec(
(3,),
dtype=tf.float32,
minimum=-np.pi / 2,
maximum=np.pi / 2,
name='rotation_delta')
action_spec.gripper_closedness_action = tensor_spec.BoundedTensorSpec(
(1,),
dtype=tf.float32,
minimum=-1.,
maximum=1.,
name='gripper_closedness_action')
action_spec.terminate_episode = tensor_spec.BoundedTensorSpec(
(2,), dtype=tf.int32, minimum=0, maximum=1, name='terminate_episode')
state_spec = tensorspec_utils.TensorSpecStruct()
state_spec.image = tensor_spec.BoundedTensorSpec(
[image_height, image_width, 3],
dtype=tf.float32,
name='image',
minimum=0.,
maximum=1.)
state_spec.natural_language_embedding = tensor_spec.TensorSpec(
shape=[self.token_embedding_size],
dtype=tf.float32,
name='natural_language_embedding')
self._policy_info_spec = {
'return':
tensor_spec.BoundedTensorSpec((),
dtype=tf.float32,
minimum=0.0,
maximum=1.0,
name='return'),
'discounted_return':
tensor_spec.BoundedTensorSpec((),
dtype=tf.float32,
minimum=0.0,
maximum=1.0,
name='discounted_return'),
}
self._state_spec = state_spec
self._action_spec = action_spec
self._inference_observation = {
'image':
tf.constant(
1,
shape=[self.inference_batch_size, image_height, image_width, 3],
dtype=tf.dtypes.float32),
'natural_language_embedding':
tf.constant(
1.,
shape=[self.inference_batch_size, self.token_embedding_size],
dtype=tf.dtypes.float32),
}
self._train_observation = {
'image':
tf.constant(
0.5,
shape=[
self.train_batch_size, self.time_sequence_length,
image_height, image_width, 3
]),
'natural_language_embedding':
tf.constant(
1.,
shape=[
self.train_batch_size, self.time_sequence_length,
self.token_embedding_size
]),
}
self._inference_action = {
'world_vector':
tf.constant(0.5, shape=[self.inference_batch_size, 3]),
'rotation_delta':
tf.constant(0.5, shape=[self.inference_batch_size, 3]),
'terminate_episode':
tf.constant(
[0, 1] * self.inference_batch_size,
shape=[self.inference_batch_size, 2]),
'gripper_closedness_action':
tf.constant(0.5, shape=[self.inference_batch_size, 1]),
}
self._train_action = {
'world_vector':
tf.constant(
0.5,
shape=[self.train_batch_size, self.time_sequence_length, 3]),
'rotation_delta':
tf.constant(
0.5,
shape=[self.train_batch_size, self.time_sequence_length, 3]),
'terminate_episode':
tf.constant(
[0, 1] * self.train_batch_size * self.time_sequence_length,
shape=[self.train_batch_size, self.time_sequence_length, 2]),
'gripper_closedness_action':
tf.constant(
0.5,
shape=[self.train_batch_size, self.time_sequence_length, 1]),
}
def _create_agent(self, actor_network=None):
"""Creates SequenceAgent using custom actor_network."""
time_step_spec = ts.time_step_spec(observation_spec=self._state_spec)
if actor_network is None:
actor_network = transformer_network.TransformerNetwork
self._agent = sequence_agent.SequenceAgent(
time_step_spec=time_step_spec,
action_spec=self._action_spec,
actor_network=actor_network,
actor_optimizer=tf.keras.optimizers.Adam(),
train_step_counter=tf.compat.v1.train.get_or_create_global_step(),
time_sequence_length=TIME_SEQUENCE_LENGTH)
self._num_action_tokens = (
# pylint:disable=protected-access
self._agent._actor_network._action_tokenizer._tokens_per_action)
# pylint:enable=protected-access
def setUp(self):
self._define_specs()
super().setUp()
def get_image_value(self, step_idx: int) -> float:
return float(step_idx) / self.time_sequence_length
def get_action_logits(self, batch_size: int, value: int,
vocab_size: int) -> tf.Tensor:
return tf.broadcast_to(
tf.one_hot(value % vocab_size, vocab_size)[tf.newaxis, tf.newaxis, :],
[batch_size, 1, vocab_size])
def create_obs(self, value) -> dict[str, tf.Tensor]:
observations = {}
observations['image'] = value * self._inference_observation['image']
observations[
'natural_language_embedding'] = value * self._inference_observation[
'natural_language_embedding']
return observations
def fake_action_token_emb(self, action_tokens) -> tf.Tensor:
"""Just pad with zeros."""
shape = action_tokens.shape
assert self.vocab_size > self.token_embedding_size
assert len(shape) == 4
return action_tokens[:, :, :, :self.token_embedding_size]
def fake_transformer(
self, all_tokens, training,
attention_mask) -> Union[tf.Tensor, Tuple[tf.Tensor, list[tf.Tensor]]]:
"""Fakes the call to TransformerNetwork._transformer."""
del training
del attention_mask
# We expect ST00 ST01 A00 A01...
# Where:
# * ST01 is token 1 of state 0.
# * A01 is token 1 of action 0.
shape = all_tokens.shape.as_list()
batch_size = shape[0]
self.assertEqual(batch_size, 1)
emb_size = self.token_embedding_size
# transform to [batch_size, num_tokens, token_size]
all_tokens = tf.reshape(all_tokens, [batch_size, -1, emb_size])
# Pads tokens to be of vocab_size.
self.assertGreater(self.vocab_size, self.token_embedding_size)
all_shape = all_tokens.shape
self.assertLen(all_shape.as_list(), 3)
output_tokens = tf.concat([
all_tokens,
tf.zeros([
all_shape[0], all_shape[1],
self.vocab_size - self.token_embedding_size
])
],
axis=-1)
num_tokens_per_step = NUM_IMAGE_TOKENS + self._num_action_tokens
# Check state/action alignment.
window_range = min(self._step_idx + 1, self.time_sequence_length)
for j in range(window_range):
      # Absolute index of the step stored at window position j = 0.
first_step_idx = max(0, self._step_idx + 1 - self.time_sequence_length)
image_idx = j * num_tokens_per_step
action_start_index = image_idx + NUM_IMAGE_TOKENS
for t in range(NUM_IMAGE_TOKENS):
self.assertAllEqual(
self.get_image_value(first_step_idx + j) *
tf.ones_like(all_tokens[0][image_idx][:self.token_embedding_size]),
all_tokens[0][image_idx + t][:self.token_embedding_size])
      # If j is not the current step in the window, all action dimensions from
      # previous steps have already been inferred and can be checked.
action_dims_range = self.action_inf_idx if j == window_range - 1 else self._num_action_tokens
for t in range(action_dims_range):
token_idx = action_start_index + t
action_value = (first_step_idx + j) * self._num_action_tokens + t
self.assertAllEqual(
self.get_action_logits(
batch_size=batch_size,
value=action_value,
vocab_size=self.vocab_size)[0][0][:self.token_embedding_size],
all_tokens[0][token_idx][:self.token_embedding_size])
# Output the right action dimension value.
image_token_index = (
min(self._step_idx, self.time_sequence_length - 1) *
num_tokens_per_step)
transformer_shift = -1
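    # The shift accounts for the autoregressive decoding convention assumed
    # here: the logits for a token are read from the preceding output position.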
action_index = (
image_token_index + NUM_IMAGE_TOKENS + self.action_inf_idx +
transformer_shift)
action_value = self._step_idx * self._num_action_tokens + self.action_inf_idx
action_logits = self.get_action_logits(
batch_size=batch_size, value=action_value, vocab_size=self.vocab_size)
output_tokens = tf.concat([
output_tokens[:, :action_index, :], action_logits[:, :, :],
output_tokens[:, action_index + 1:, :]
],
axis=1)
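    # Advance to the action dimension that will be inferred on the next call.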
self.action_inf_idx = (self.action_inf_idx + 1) % self._num_action_tokens
attention_scores = []
return output_tokens, attention_scores