# robotics_transformer/sequence_agent.py
#
# 172 lines
# 6.5 KiB
# Python

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sequence policy and agent that directly output actions via actor network.
These classes are not intended to change as they are generic enough for any
all-neural actor based agent+policy. All new features are intended to be
implemented in `actor_network` and `loss_fn`.
"""
from typing import Optional, Type
from absl import logging
import tensorflow as tf
from tf_agents.agents import data_converter
from tf_agents.agents import tf_agent
from tf_agents.networks import network
from tf_agents.policies import actor_policy
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types
from tf_agents.utils import nest_utils
class SequencePolicy(actor_policy.ActorPolicy):
  """Actor policy whose network output is used directly as the action.

  Besides acting, this policy is a thin facade over its actor network for the
  loss computation: `set_actions`, `get_actor_loss` and `get_aux_info` are
  forwarded to the network unchanged, and `set_training` overwrites
  `self._training` (presumably read by `ActorPolicy` when it applies the
  network -- confirm against the tf_agents version in use).
  """

  def __init__(self, **kwargs):
    """Initializes the policy; every keyword goes to `ActorPolicy`."""
    # Not read anywhere in this class; kept so the attribute still exists.
    self._actions = None
    super(SequencePolicy, self).__init__(**kwargs)

  def set_actions(self, actions):
    """Forwards `actions` to the actor network's `set_actions`."""
    actor = self._actor_network
    actor.set_actions(actions)

  def get_actor_loss(self):
    """Returns the value of the actor network's `get_actor_loss()`."""
    actor = self._actor_network
    return actor.get_actor_loss()

  def get_aux_info(self):
    """Returns the value of the actor network's `get_aux_info()`."""
    actor = self._actor_network
    return actor.get_aux_info()

  def set_training(self, training):
    """Overwrites the policy's `_training` flag with `training`."""
    self._training = training

  def _action(self,
              time_step: ts.TimeStep,
              policy_state: types.NestedTensor,
              seed: Optional[types.Seed] = None) -> policy_step.PolicyStep:
    """Applies the actor network and packs its output into a `PolicyStep`.

    Args:
      time_step: Time step whose `observation` and `step_type` are fed to the
        actor network.
      policy_state: Network state passed through to the actor network.
      seed: Ignored.

    Returns:
      A `PolicyStep` holding the network's action and new state, with an empty
      tuple as info.
    """
    del seed  # Unused by this policy.
    network_output = self._apply_actor_network(
        time_step.observation,
        step_type=time_step.step_type,
        policy_state=policy_state)
    action, next_state = network_output
    return policy_step.PolicyStep(action=action, state=next_state, info=())

  def _distribution(self, time_step, policy_state):
    """Defers entirely to `ActorPolicy._distribution`."""
    return super(SequencePolicy, self)._distribution(time_step, policy_state)
class SequenceAgent(tf_agent.TFAgent):
  """A sequence agent that directly outputs actions via an actor network.

  A single actor network instance is shared by a train policy (used only to
  compute the loss) and a collect policy, which also serves as the eval policy.
  """

  def __init__(self,
               time_step_spec: ts.TimeStep,
               action_spec: types.NestedTensorSpec,
               actor_network: Type[network.Network],
               actor_optimizer: tf.keras.optimizers.Optimizer,
               policy_cls: Type[actor_policy.ActorPolicy] = SequencePolicy,
               time_sequence_length: int = 6,
               debug_summaries: bool = False,
               **kwargs):
    """Builds the actor network, both policies and the data converter.

    Args:
      time_step_spec: Time step spec; its `observation` field is the actor
        network's input spec.
      action_spec: Spec of the actions produced by the actor network.
      actor_network: Network class (not an instance). It is instantiated here
        with the observation/action specs, an empty policy info spec, the train
        step counter and `time_sequence_length`.
      actor_optimizer: Optimizer used to apply the actor gradients.
      policy_cls: Policy class used for both the train and collect policies.
      time_sequence_length: Forwarded to the actor network and used as the
        agent's `train_sequence_length`.
      debug_summaries: Forwarded to the actor network's `add_summaries`.
      **kwargs: Forwarded to `tf_agent.TFAgent.__init__`. If
        `train_step_counter` is missing or None, the global step is created (or
        reused) so the actor network and the agent share the same counter.
    """
    train_step_counter = kwargs.get('train_step_counter')
    if train_step_counter is None:
      # Previously a missing key raised KeyError even though TFAgent treats the
      # argument as optional; use the same default TFAgent would.
      train_step_counter = tf.compat.v1.train.get_or_create_global_step()
      kwargs['train_step_counter'] = train_step_counter

    self._info_spec = ()
    self._actor_network = actor_network(  # pytype: disable=missing-parameter  # dynamic-method-lookup
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        policy_info_spec=self._info_spec,
        train_step_counter=train_step_counter,
        time_sequence_length=time_sequence_length)
    self._actor_optimizer = actor_optimizer
    # Train policy is only used for loss and never exported as saved_model.
    self._train_policy = policy_cls(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        info_spec=self._info_spec,
        actor_network=self._actor_network,
        training=True)
    collect_policy = policy_cls(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        info_spec=self._info_spec,
        actor_network=self._actor_network,
        training=False)
    super(SequenceAgent, self).__init__(
        time_step_spec,
        action_spec,
        collect_policy,  # We use the collect_policy as the eval policy.
        collect_policy,
        train_sequence_length=time_sequence_length,
        **kwargs)
    self._data_context = data_converter.DataContext(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        info_spec=collect_policy.info_spec,
        use_half_transition=True)
    # Keeps the time dimension so the actor sees whole sequences.
    self.as_transition = data_converter.AsHalfTransition(
        self._data_context, squeeze_time_dim=False)
    self._debug_summaries = debug_summaries
    self._log_param_counts()

  def _log_param_counts(self):
    """Logs each trainable actor weight's size and the total in millions."""
    # NOTE(review): if the network has not been built yet, trainable_weights
    # may be empty and this logs 0 params.
    num_params = 0
    for weight in self._actor_network.trainable_weights:
      weight_params = weight.shape.num_elements()
      logging.info('%s has %s params.', weight.name, weight_params)
      num_params += weight_params
    logging.info('Actor network has %sM params.',
                 round(num_params / 1000000., 2))

  def _train(self, experience: types.NestedTensor,
             weights: types.Tensor) -> tf_agent.LossInfo:
    """Increments the step counter and runs one optimization step.

    The loss is recorded on a `tf.GradientTape`, so training works both eagerly
    and inside a `tf.function`. `tf.gradients` only works in graph mode.

    Args:
      experience: Batched trajectories passed to `_loss`.
      weights: Passed to `_loss`, which currently ignores them.

    Returns:
      The `LossInfo` computed by `_loss`.
    """
    self.train_step_counter.assign_add(1)
    with tf.GradientTape() as tape:
      loss_info = self._loss(experience, weights, training=True)
    self._apply_gradients(loss_info.loss, tape=tape)
    return loss_info

  def _apply_gradients(self, loss: types.Tensor,
                       tape: Optional[tf.GradientTape] = None):
    """Applies the actor gradients of `loss`, zeroing non-finite entries.

    Args:
      loss: Scalar loss to differentiate w.r.t. the actor's trainable weights.
      tape: Tape that recorded the computation of `loss`. When None, falls back
        to `tf.gradients`, which only works in graph mode (e.g. inside a
        `tf.function`).
    """
    variables = self._actor_network.trainable_weights
    if tape is None:
      gradients = tf.gradients(loss, variables)
    else:
      gradients = tape.gradient(loss, variables)
    # Skip nan and inf gradients by replacing them with zeros. Variables that
    # do not influence the loss keep a None gradient.
    new_gradients = [self._zero_non_finite(g) for g in gradients]
    self._actor_optimizer.apply_gradients(list(zip(new_gradients, variables)))

  @staticmethod
  def _zero_non_finite(gradient):
    """Returns `gradient` with NaN/Inf entries replaced by 0; None is kept."""
    if gradient is None:
      return None
    if isinstance(gradient, tf.IndexedSlices):
      # Clean only the stored values so sparse gradients stay sparse.
      values = gradient.values
      return tf.IndexedSlices(
          tf.where(tf.math.is_finite(values), values, tf.zeros_like(values)),
          gradient.indices, gradient.dense_shape)
    return tf.where(
        tf.math.is_finite(gradient), gradient, tf.zeros_like(gradient))

  def _loss(self, experience: types.NestedTensor, weights: types.Tensor,
            training: bool) -> tf_agent.LossInfo:
    """Computes the masked actor loss for a batch of sequences.

    Args:
      experience: Batched trajectories, converted to half transitions with the
        time dimension kept.
      weights: Currently ignored by this agent.
      training: Stored on the train policy via `set_training` and forwarded to
        the actor network's `add_summaries`.

    Returns:
      A `LossInfo` whose `loss` and `extra` are both the scalar mean loss.
    """
    del weights  # NOTE(review): per-sample weights are not applied.
    time_steps, policy_steps, _ = self.as_transition(experience)
    batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
    policy = self._train_policy
    # The target actions are handed to the network before the forward pass.
    policy.set_actions(policy_steps.action)
    policy.set_training(training=training)
    with tf.name_scope('actor_loss'):
      try:
        policy_state = policy.get_initial_state(batch_size)
        # Run the policy only for its side effect of letting the network
        # compute its loss; the returned action is discarded.
        policy.action(time_steps, policy_state=policy_state)
        # Steps where `is_last()` is true contribute zero loss.
        # NOTE(review): assumes get_actor_loss() broadcasts against the
        # step-type-shaped mask (likely [batch, time]) -- confirm.
        valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
        loss = valid_mask * policy.get_actor_loss()
        # Mean over all elements, masked ones included.
        loss = tf.reduce_mean(loss)
      finally:
        # Always clear the target actions, even if the forward pass fails,
        # so they cannot leak into later calls on the shared policy.
        policy.set_actions(None)
      self._actor_network.add_summaries(time_steps.observation,
                                        policy.get_aux_info(),
                                        self._debug_summaries, training)
    return tf_agent.LossInfo(loss=loss, extra=loss)