robotics_transformer/sequence_agent_test_set_up.py

145 lines
4.9 KiB
Python
Raw Normal View History

2022-12-10 03:58:47 +08:00
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for sequence_agent."""
from typing import Type
import numpy as np
from robotics_transformer import sequence_agent
from tensor2robot.utils import tensorspec_utils
import tensorflow as tf
from tf_agents.networks import network
from tf_agents.policies import policy_saver
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts
class DummyActorNet(network.Network):
"""Used for testing SequenceAgent and its subclass."""
def __init__(self,
output_tensor_spec=None,
train_step_counter=None,
policy_info_spec=None,
time_sequence_length=1,
use_tcl=False,
**kwargs):
super().__init__(**kwargs)
@property
def tokens_per_action(self):
return 8
def set_actions(self, actions):
self._actions = actions
def get_actor_loss(self):
return self._actor_loss
def call(self,
observations,
step_type,
network_state,
actions=None,
training=False):
del step_type
image = observations['image']
tf.expand_dims(tf.reduce_mean(image, axis=-1), -1)
actions = tensorspec_utils.TensorSpecStruct(
world_vector=tf.constant(1., shape=[1, 3]),
rotation_delta=tf.constant(1., shape=[1, 3]),
terminate_episode=tf.constant(1, shape=[1, 2]),
gripper_closedness_action=tf.constant(1., shape=[1, 1]),
)
return actions, network_state
@property
def trainable_weights(self):
return [tf.Variable(1.0)]
class SequenceAgentTestSetUp(tf.test.TestCase):
"""Defines spec for testing SequenceAgent and its subclass, tests create."""
def setUp(self):
super().setUp()
self._action_spec = tensorspec_utils.TensorSpecStruct()
self._action_spec.world_vector = tensor_spec.BoundedTensorSpec(
(3,), dtype=tf.float32, minimum=-1., maximum=1., name='world_vector')
self._action_spec.rotation_delta = tensor_spec.BoundedTensorSpec(
(3,),
dtype=tf.float32,
minimum=-np.pi / 2,
maximum=np.pi / 2,
name='rotation_delta')
self._action_spec.gripper_closedness_action = tensor_spec.BoundedTensorSpec(
(1,),
dtype=tf.float32,
minimum=-1.,
maximum=1.,
name='gripper_closedness_action')
self._action_spec.terminate_episode = tensor_spec.BoundedTensorSpec(
(2,), dtype=tf.int32, minimum=0, maximum=1, name='terminate_episode')
state_spec = tensorspec_utils.TensorSpecStruct()
state_spec.image = tensor_spec.BoundedTensorSpec([256, 320, 3],
dtype=tf.float32,
name='image',
minimum=0.,
maximum=1.)
state_spec.natural_language_embedding = tensor_spec.TensorSpec(
shape=[512], dtype=tf.float32, name='natural_language_embedding')
self._time_step_spec = ts.time_step_spec(observation_spec=state_spec)
self.sequence_agent_cls = sequence_agent.SequenceAgent
def create_agent_and_initialize(self,
actor_network: Type[
network.Network] = DummyActorNet,
**kwargs):
"""Creates the agent and initialize it."""
agent = self.sequence_agent_cls(
time_step_spec=self._time_step_spec,
action_spec=self._action_spec,
actor_network=actor_network,
actor_optimizer=tf.keras.optimizers.Adam(),
train_step_counter=tf.compat.v1.train.get_or_create_global_step(),
**kwargs)
agent.initialize()
return agent
def testCreateAgent(self):
"""Creates the Agent and save the agent.policy."""
agent = self.create_agent_and_initialize()
self.assertIsNotNone(agent.policy)
policy_model_saver = policy_saver.PolicySaver(
agent.policy,
train_step=tf.compat.v2.Variable(
0,
trainable=False,
dtype=tf.int64,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
shape=()),
input_fn_and_spec=None)
save_options = tf.saved_model.SaveOptions(
experimental_io_device='/job:localhost',
experimental_custom_gradients=False)
policy_model_saver.save('/tmp/unittest/policy/0', options=save_options)
if __name__ == '__main__':
tf.test.main()