robotics_transformer/transformer_network_test.py

230 lines
9.2 KiB
Python
Raw Permalink Normal View History

2022-12-10 03:58:47 +08:00
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for networks."""
from absl.testing import parameterized
from robotics_transformer import transformer_network
from robotics_transformer.transformer_network_test_set_up import BATCH_SIZE
from robotics_transformer.transformer_network_test_set_up import NAME_TO_INF_OBSERVATIONS
from robotics_transformer.transformer_network_test_set_up import NAME_TO_STATE_SPECS
from robotics_transformer.transformer_network_test_set_up import observations_list
from robotics_transformer.transformer_network_test_set_up import spec_names_list
from robotics_transformer.transformer_network_test_set_up import state_spec_list
from robotics_transformer.transformer_network_test_set_up import TIME_SEQUENCE_LENGTH
from robotics_transformer.transformer_network_test_set_up import TransformerNetworkTestUtils
import tensorflow as tf
from tf_agents.specs import tensor_spec
class TransformerNetworkTest(TransformerNetworkTestUtils):
  """Tests for `transformer_network.TransformerNetwork`.

  The fixtures used below (`self._action_spec`, `self._train_action`,
  `self._inference_action`, `self._define_specs`, `self._create_agent`) are
  not defined in this class; presumably they are provided by
  `TransformerNetworkTestUtils` -- confirm against the set-up module.
  """

  # pylint:disable=g-complex-comprehension
  @parameterized.named_parameters([
      dict(testcase_name='_' + name, state_spec=spec, train_observation=obs)
      for name, spec, obs in zip(spec_names_list(), state_spec_list(),
                                 observations_list())
  ])
  # pylint:enable=g-complex-comprehension
  def testTransformerTrainLossCall(self, state_spec, train_observation):
    """Calls the network with the train actions set and checks the results.

    Args:
      state_spec: Input tensor spec the network is built with.
      train_observation: Observation passed to the network call.
    """
    net = transformer_network.TransformerNetwork(
        input_tensor_spec=state_spec,
        output_tensor_spec=self._action_spec,
        time_sequence_length=TIME_SEQUENCE_LENGTH)
    net.create_variables()
    self.assertNotEmpty(net.variables)

    net.set_actions(self._train_action)
    initial_state = tensor_spec.sample_spec_nest(
        net.state_spec, outer_dims=[BATCH_SIZE])
    output_actions, _ = net(
        train_observation, step_type=None, network_state=initial_state)

    # NOTE(review): [2, 3] presumably corresponds to
    # [BATCH_SIZE, TIME_SEQUENCE_LENGTH] -- confirm against the set-up module.
    self.assertEqual(net.get_actor_loss().shape, tf.TensorShape([2, 3]))
    # The network must emit exactly the action keys it was given.
    self.assertCountEqual(self._train_action.keys(), output_actions.keys())

  # pylint:disable=g-complex-comprehension
  @parameterized.named_parameters([
      dict(testcase_name='_' + name, spec_name=name)
      for name in spec_names_list()
  ])
  # pylint:enable=g-complex-comprehension
  def testTransformerInferenceLossCall(self, spec_name):
    """Calls the network with a batch of one and checks the loss is zero.

    Args:
      spec_name: Key into `NAME_TO_STATE_SPECS` and
        `NAME_TO_INF_OBSERVATIONS` selecting the spec/observation pair.
    """
    inference_spec = NAME_TO_STATE_SPECS[spec_name]
    inference_observation = NAME_TO_INF_OBSERVATIONS[spec_name]
    action_order = [
        'terminate_episode', 'world_vector', 'rotation_delta',
        'gripper_closedness_action'
    ]
    net = transformer_network.TransformerNetwork(
        input_tensor_spec=inference_spec,
        output_tensor_spec=self._action_spec,
        time_sequence_length=TIME_SEQUENCE_LENGTH,
        action_order=action_order)
    net.create_variables()
    self.assertNotEmpty(net.variables)

    net.set_actions(self._inference_action)
    # inference currently only support batch size of 1
    initial_state = tensor_spec.sample_spec_nest(
        net.state_spec, outer_dims=[1])
    output_actions, _ = net(
        inference_observation, step_type=None, network_state=initial_state)

    tf.debugging.assert_equal(net.get_actor_loss(), 0.0)
    self.assertCountEqual(self._inference_action.keys(), output_actions.keys())

  # pylint:disable=g-complex-comprehension
  @parameterized.named_parameters([
      dict(testcase_name='_' + name, state_spec=spec, train_observation=obs)
      for name, spec, obs in zip(spec_names_list(), state_spec_list(),
                                 observations_list())
  ])
  # pylint:enable=g-complex-comprehension
  def testTransformerLogging(self, state_spec, train_observation):
    """Runs a forward pass and then calls `add_summaries` on the network.

    The test has no assertions on the summaries; it passes as long as the
    calls complete without raising.

    Args:
      state_spec: Input tensor spec the network is built with.
      train_observation: Observation passed to the network and to
        `add_summaries`.
    """
    action_order = [
        'terminate_episode', 'world_vector', 'rotation_delta',
        'gripper_closedness_action'
    ]
    net = transformer_network.TransformerNetwork(
        input_tensor_spec=state_spec,
        output_tensor_spec=self._action_spec,
        time_sequence_length=TIME_SEQUENCE_LENGTH,
        action_order=action_order)
    net.create_variables()
    self.assertNotEmpty(net.variables)

    net.set_actions(self._train_action)
    initial_state = tensor_spec.sample_spec_nest(
        net.state_spec, outer_dims=[BATCH_SIZE])
    # The outputs are irrelevant here; the call is made before summaries are
    # requested from `get_aux_info()`.
    _ = net(train_observation, step_type=None, network_state=initial_state)
    net.add_summaries(
        train_observation,
        net.get_aux_info(),
        debug_summaries=True,
        training=True)

  # pylint:disable=g-complex-comprehension
  @parameterized.named_parameters([
      dict(testcase_name='_' + name, state_spec=spec)
      for name, spec in zip(spec_names_list(), state_spec_list())
  ])
  # pylint:enable=g-complex-comprehension
  def testTransformerCausality(self, state_spec):
    """Tests the causality for the transformer.

    For every position t, all input tokens after t are zeroed out and the
    transformer outputs up to and including t are required to be unchanged.

    Args:
      state_spec: Which state spec to test the transformer with.
    """
    net = transformer_network.TransformerNetwork(
        input_tensor_spec=state_spec,
        output_tensor_spec=self._action_spec,
        time_sequence_length=TIME_SEQUENCE_LENGTH)
    net.create_variables()
    self.assertNotEmpty(net.variables)

    num_steps = net._time_sequence_length
    image_len = net._tokens_per_context_image
    action_len = net._tokens_per_action
    # Each time step contributes its image tokens followed by its action
    # tokens to the flat sequence.
    step_len = image_len + action_len
    total_len = num_steps * step_len

    def _split_tokens(flat_tokens):
      """Splits a flat token vector into image and action token inputs."""
      image_starts = [step_len * k for k in range(num_steps)]
      action_starts = [start + image_len for start in image_starts]

      image_ids = tf.stack(
          [flat_tokens[start:start + image_len] for start in image_starts],
          axis=0)
      action_ids = tf.stack(
          [flat_tokens[start:start + action_len] for start in action_starts],
          0)

      one_hot_images = tf.one_hot(image_ids, net._token_embedding_size)
      # Remove extra dimension before the end once b/254902773 is fixed.
      shape = one_hot_images.shape
      # Add batch dimension.
      one_hot_images = tf.reshape(one_hot_images,
                                  [1] + shape[:-1] + [1] + shape[-1:])
      # The action tokens are handed over as a single-element list.
      return one_hot_images, [action_ids]

    def _transformer_outputs(flat_tokens):
      """Returns element 0 of `_transformer_call` for the given tokens."""
      context_image_tokens, action_tokens = _split_tokens(flat_tokens)
      return net._transformer_call(
          context_image_tokens=context_image_tokens,
          action_tokens=action_tokens,
          attention_mask=net._default_attention_mask,
          batch_size=1,
          training=False)[0]

    # Generate some random tokens for image and actions.
    all_tokens = tf.random.uniform(
        shape=[total_len], dtype=tf.int32, maxval=10, minval=0)
    # Get the output tokens without any zeroed out input tokens.
    reference_outputs = _transformer_outputs(all_tokens)

    for t in range(total_len):
      # Zero out future input tokens.
      visible = all_tokens[:t + 1]
      zeroed = tf.zeros_like(all_tokens[t + 1:])
      tokens_up_to_t = tf.concat([visible, zeroed], 0)
      # Get the output tokens with zeroed out input tokens after t.
      outputs_up_to_t = _transformer_outputs(tokens_up_to_t)
      # The output token is unchanged if future input tokens are zeroed out.
      self.assertAllEqual(reference_outputs[:t + 1], outputs_up_to_t[:t + 1])

  def testLossMasks(self):
    """Checks `_action_tokens_mask` after overriding the token layout.

    The actor network is configured with 2 time steps, 3 tokens per context
    image and 2 tokens per action. `_generate_masks()` is run twice and the
    mask is compared against the same expected values after each run.
    """
    self._define_specs()
    self._create_agent()
    image_tokens = 3
    action_tokens = 2

    actor = self._agent._actor_network
    actor._time_sequence_length = 2
    actor._tokens_per_context_image = image_tokens
    actor._tokens_per_action = action_tokens

    first_step = [image_tokens, image_tokens + 1]
    second_step = [
        2 * image_tokens + action_tokens,
        2 * image_tokens + action_tokens + 1
    ]
    for _ in range(2):
      actor._generate_masks()
      self.assertAllEqual(
          actor._action_tokens_mask,
          tf.constant(first_step + second_step, tf.int32))
if __name__ == '__main__':
  # Run tf.functions eagerly before starting the test runner; this is useful
  # when stepping through the tests with ipdb.
  tf.config.run_functions_eagerly(True)
  tf.test.main()