# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for networks."""
|
||
|
|
||
|
from absl.testing import parameterized
|
||
|
|
||
|
from robotics_transformer import transformer_network
|
||
|
from robotics_transformer.transformer_network_test_set_up import BATCH_SIZE
|
||
|
from robotics_transformer.transformer_network_test_set_up import NAME_TO_INF_OBSERVATIONS
|
||
|
from robotics_transformer.transformer_network_test_set_up import NAME_TO_STATE_SPECS
|
||
|
from robotics_transformer.transformer_network_test_set_up import observations_list
|
||
|
from robotics_transformer.transformer_network_test_set_up import spec_names_list
|
||
|
from robotics_transformer.transformer_network_test_set_up import state_spec_list
|
||
|
from robotics_transformer.transformer_network_test_set_up import TIME_SEQUENCE_LENGTH
|
||
|
from robotics_transformer.transformer_network_test_set_up import TransformerNetworkTestUtils
|
||
|
|
||
|
import tensorflow as tf
|
||
|
from tf_agents.specs import tensor_spec
|
||
|
|
||
|
|
||
|
class TransformerNetworkTest(TransformerNetworkTestUtils):
  """Tests for TransformerNetwork.

  Covers the train-mode loss, inference-mode loss, summary logging, the
  causality of the transformer's attention, and the generated loss masks.
  All specs/observations come from transformer_network_test_set_up.
  """

  # pylint:disable=g-complex-comprehension
  @parameterized.named_parameters([{
      'testcase_name': '_' + name,
      'state_spec': spec,
      'train_observation': obs,
  } for (name, spec,
         obs) in zip(spec_names_list(), state_spec_list(), observations_list())]
                                 )
  # pylint:enable=g-complex-comprehension
  def testTransformerTrainLossCall(self, state_spec, train_observation):
    """Checks a train-mode call yields a per-batch, per-timestep actor loss."""
    network = transformer_network.TransformerNetwork(
        input_tensor_spec=state_spec,
        output_tensor_spec=self._action_spec,
        time_sequence_length=TIME_SEQUENCE_LENGTH)

    network.create_variables()
    self.assertNotEmpty(network.variables)

    network.set_actions(self._train_action)
    network_state = tensor_spec.sample_spec_nest(
        network.state_spec, outer_dims=[BATCH_SIZE])
    output_actions, network_state = network(
        train_observation, step_type=None, network_state=network_state)
    # The actor loss is reported per batch element and timestep. Use the
    # shared constants instead of magic numbers so this assertion stays in
    # sync with transformer_network_test_set_up.
    expected_shape = [BATCH_SIZE, TIME_SEQUENCE_LENGTH]
    self.assertEqual(network.get_actor_loss().shape,
                     tf.TensorShape(expected_shape))
    self.assertCountEqual(self._train_action.keys(), output_actions.keys())

  # pylint:disable=g-complex-comprehension
  @parameterized.named_parameters([{
      'testcase_name': '_' + name,
      'spec_name': name,
  } for name in spec_names_list()])
  # pylint:enable=g-complex-comprehension
  def testTransformerInferenceLossCall(self, spec_name):
    """Checks an inference-mode call emits actions and a zero actor loss."""
    state_spec = NAME_TO_STATE_SPECS[spec_name]
    observation = NAME_TO_INF_OBSERVATIONS[spec_name]

    network = transformer_network.TransformerNetwork(
        input_tensor_spec=state_spec,
        output_tensor_spec=self._action_spec,
        time_sequence_length=TIME_SEQUENCE_LENGTH,
        action_order=[
            'terminate_episode', 'world_vector', 'rotation_delta',
            'gripper_closedness_action'
        ])
    network.create_variables()
    self.assertNotEmpty(network.variables)

    network.set_actions(self._inference_action)
    # inference currently only support batch size of 1
    network_state = tensor_spec.sample_spec_nest(
        network.state_spec, outer_dims=[1])

    output_actions, network_state = network(
        observation, step_type=None, network_state=network_state)

    # No ground-truth actions are available at inference, so no loss accrues.
    tf.debugging.assert_equal(network.get_actor_loss(), 0.0)
    self.assertCountEqual(self._inference_action.keys(), output_actions.keys())

  # pylint:disable=g-complex-comprehension
  @parameterized.named_parameters([{
      'testcase_name': '_' + name,
      'state_spec': spec,
      'train_observation': obs,
  } for name, spec, obs in zip(spec_names_list(), state_spec_list(),
                               observations_list())])
  # pylint:enable=g-complex-comprehension
  def testTransformerLogging(self, state_spec, train_observation):
    """Checks add_summaries runs cleanly after a train-mode forward pass."""
    network = transformer_network.TransformerNetwork(
        input_tensor_spec=state_spec,
        output_tensor_spec=self._action_spec,
        time_sequence_length=TIME_SEQUENCE_LENGTH,
        action_order=[
            'terminate_episode', 'world_vector', 'rotation_delta',
            'gripper_closedness_action'
        ])

    network.create_variables()
    self.assertNotEmpty(network.variables)

    network.set_actions(self._train_action)
    network_state = tensor_spec.sample_spec_nest(
        network.state_spec, outer_dims=[BATCH_SIZE])
    _ = network(train_observation, step_type=None, network_state=network_state)
    network.add_summaries(
        train_observation,
        network.get_aux_info(),
        debug_summaries=True,
        training=True)

  # pylint:disable=g-complex-comprehension
  @parameterized.named_parameters([{
      'testcase_name': '_' + name,
      'state_spec': spec,
  } for name, spec in zip(spec_names_list(), state_spec_list())])
  # pylint:enable=g-complex-comprehension
  def testTransformerCausality(self, state_spec):
    """Tests the causality for the transformer.

    Output tokens at position <= t must not depend on input tokens after t,
    so zeroing out the future inputs must leave them unchanged.

    Args:
      state_spec: Which state spec to test the transformer with
    """
    network = transformer_network.TransformerNetwork(
        input_tensor_spec=state_spec,
        output_tensor_spec=self._action_spec,
        time_sequence_length=TIME_SEQUENCE_LENGTH)
    network.create_variables()
    self.assertNotEmpty(network.variables)

    time_sequence_length = network._time_sequence_length
    tokens_per_image = network._tokens_per_context_image
    tokens_per_action = network._tokens_per_action

    def _split_image_and_action_tokens(all_tokens):
      # The flat token sequence interleaves, per timestep, `tokens_per_image`
      # image tokens followed by `tokens_per_action` action tokens.
      image_start_indices = [(tokens_per_image + tokens_per_action) * k
                             for k in range(time_sequence_length)]
      image_tokens = tf.stack(
          [all_tokens[i:i + tokens_per_image] for i in image_start_indices],
          axis=0)
      action_start_indices = [i + tokens_per_image for i in image_start_indices]
      action_tokens = [
          tf.stack([
              all_tokens[i:i + tokens_per_action] for i in action_start_indices
          ], 0)
      ]
      image_tokens = tf.one_hot(image_tokens, network._token_embedding_size)
      # Remove extra dimension before the end once b/254902773 is fixed.
      shape = image_tokens.shape
      # Add batch dimension.
      image_tokens = tf.reshape(image_tokens,
                                [1] + shape[:-1] + [1] + shape[-1:])
      return image_tokens, action_tokens

    # Generate some random tokens for image and actions.
    all_tokens = tf.random.uniform(
        shape=[time_sequence_length * (tokens_per_image + tokens_per_action)],
        dtype=tf.int32,
        maxval=10,
        minval=0)
    context_image_tokens, action_tokens = _split_image_and_action_tokens(
        all_tokens)
    # Get the output tokens without any zeroed out input tokens.
    output_tokens = network._transformer_call(
        context_image_tokens=context_image_tokens,
        action_tokens=action_tokens,
        attention_mask=network._default_attention_mask,
        batch_size=1,
        training=False)[0]

    for t in range(time_sequence_length *
                   (tokens_per_image + tokens_per_action)):
      # Zero out future input tokens.
      all_tokens_at_t = tf.concat(
          [all_tokens[:t + 1],
           tf.zeros_like(all_tokens[t + 1:])], 0)
      context_image_tokens, action_tokens = _split_image_and_action_tokens(
          all_tokens_at_t)
      # Get the output tokens with zeroed out input tokens after t.
      output_tokens_at_t = network._transformer_call(
          context_image_tokens=context_image_tokens,
          action_tokens=action_tokens,
          attention_mask=network._default_attention_mask,
          batch_size=1,
          training=False)[0]
      # The output token is unchanged if future input tokens are zeroed out.
      self.assertAllEqual(output_tokens[:t + 1], output_tokens_at_t[:t + 1])

  def testLossMasks(self):
    """Checks the action-token loss mask and that regenerating it is stable."""
    self._define_specs()
    self._create_agent()
    image_tokens = 3
    action_tokens = 2
    self._agent._actor_network._time_sequence_length = 2
    self._agent._actor_network._tokens_per_context_image = image_tokens
    self._agent._actor_network._tokens_per_action = action_tokens
    # With 2 timesteps, action tokens sit right after each timestep's image
    # tokens in the flat sequence. A single expected value is shared by both
    # assertions below (previously duplicated with inconsistent parentheses).
    expected_mask = tf.constant([
        image_tokens, image_tokens + 1, 2 * image_tokens + action_tokens,
        2 * image_tokens + action_tokens + 1
    ], tf.int32)
    self._agent._actor_network._generate_masks()
    self.assertAllEqual(
        self._agent._actor_network._action_tokens_mask, expected_mask)
    # Regenerating the masks must be idempotent.
    self._agent._actor_network._generate_masks()
    self.assertAllEqual(
        self._agent._actor_network._action_tokens_mask, expected_mask)
|
||
|
|
||
|
if __name__ == '__main__':
  # Useful to enable if running with ipdb.
  # NOTE: eager execution must be configured before tf.test.main() starts
  # running tests.
  tf.config.run_functions_eagerly(True)
  tf.test.main()
|