diff --git a/transformer_network.py b/transformer_network.py index e31bab8..6cd3244 100644 --- a/transformer_network.py +++ b/transformer_network.py @@ -300,7 +300,6 @@ class TransformerNetwork(network.Network): if outer_rank == 1: # This is an inference call # run transformer in loop to produce action tokens one-by-one - # TODO(b/231896343): Document/comment more on what the following mess is. seq_idx = tf.reshape(network_state['seq_idx'], [1])[0] action_t = tf.minimum(seq_idx, self._time_sequence_length - 1) # Transformer shifts all to the left by one step by default (it's usually @@ -560,7 +559,6 @@ class TransformerNetwork(network.Network): time_step = tf.minimum(seq_idx, self._time_sequence_length - 1) image = tf.expand_dims(image, 1) - # TODO(b/255731285) image_shape = tf.shape(image) b = image_shape[0] input_t = image_shape[1] @@ -612,9 +610,6 @@ class TransformerNetwork(network.Network): def _tokenize_actions(self, observations, network_state): outer_rank = self._get_outer_rank(observations) if outer_rank == 1: # This is an inference call - # TODO(b/231896343): Clarify what is going on with the network state - # tensors, currently they all have to be the same n_dims so we have to - # add/remove dummy dims. action_tokens = tf.squeeze(network_state['action_tokens'], [3, 4]) seq_idx = tf.reshape(network_state['seq_idx'], [1])[0] # network_state as input for this call is the output from the last call.