Update transformer_network.py

This commit is contained in:
P G Keerthana Gopalakrishnan 2022-12-12 14:02:35 -08:00 committed by GitHub
parent 356139043a
commit 042b9457a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 0 additions and 5 deletions

View File

@ -300,7 +300,6 @@ class TransformerNetwork(network.Network):
if outer_rank == 1: # This is an inference call if outer_rank == 1: # This is an inference call
# run transformer in loop to produce action tokens one-by-one # run transformer in loop to produce action tokens one-by-one
# TODO(b/231896343): Document/comment more on what the following mess is.
seq_idx = tf.reshape(network_state['seq_idx'], [1])[0] seq_idx = tf.reshape(network_state['seq_idx'], [1])[0]
action_t = tf.minimum(seq_idx, self._time_sequence_length - 1) action_t = tf.minimum(seq_idx, self._time_sequence_length - 1)
# Transformer shifts all to the left by one step by default (it's usually # Transformer shifts all to the left by one step by default (it's usually
@ -560,7 +559,6 @@ class TransformerNetwork(network.Network):
time_step = tf.minimum(seq_idx, self._time_sequence_length - 1) time_step = tf.minimum(seq_idx, self._time_sequence_length - 1)
image = tf.expand_dims(image, 1) image = tf.expand_dims(image, 1)
# TODO(b/255731285)
image_shape = tf.shape(image) image_shape = tf.shape(image)
b = image_shape[0] b = image_shape[0]
input_t = image_shape[1] input_t = image_shape[1]
@ -612,9 +610,6 @@ class TransformerNetwork(network.Network):
def _tokenize_actions(self, observations, network_state): def _tokenize_actions(self, observations, network_state):
outer_rank = self._get_outer_rank(observations) outer_rank = self._get_outer_rank(observations)
if outer_rank == 1: # This is an inference call if outer_rank == 1: # This is an inference call
# TODO(b/231896343): Clarify what is going on with the network state
# tensors, currently they all have to be the same n_dims so we have to
# add/remove dummy dims.
action_tokens = tf.squeeze(network_state['action_tokens'], [3, 4]) action_tokens = tf.squeeze(network_state['action_tokens'], [3, 4])
seq_idx = tf.reshape(network_state['seq_idx'], [1])[0] seq_idx = tf.reshape(network_state['seq_idx'], [1])[0]
# network_state as input for this call is the output from the last call. # network_state as input for this call is the output from the last call.