Update transformer_network.py
This commit is contained in:
parent
356139043a
commit
042b9457a3
|
@ -300,7 +300,6 @@ class TransformerNetwork(network.Network):
|
|||
|
||||
if outer_rank == 1: # This is an inference call
|
||||
# run transformer in loop to produce action tokens one-by-one
|
||||
# TODO(b/231896343): Document/comment more on what the following mess is.
|
||||
seq_idx = tf.reshape(network_state['seq_idx'], [1])[0]
|
||||
action_t = tf.minimum(seq_idx, self._time_sequence_length - 1)
|
||||
# Transformer shifts all to the left by one step by default (it's usually
|
||||
|
@ -560,7 +559,6 @@ class TransformerNetwork(network.Network):
|
|||
time_step = tf.minimum(seq_idx, self._time_sequence_length - 1)
|
||||
image = tf.expand_dims(image, 1)
|
||||
|
||||
# TODO(b/255731285)
|
||||
image_shape = tf.shape(image)
|
||||
b = image_shape[0]
|
||||
input_t = image_shape[1]
|
||||
|
@ -612,9 +610,6 @@ class TransformerNetwork(network.Network):
|
|||
def _tokenize_actions(self, observations, network_state):
|
||||
outer_rank = self._get_outer_rank(observations)
|
||||
if outer_rank == 1: # This is an inference call
|
||||
# TODO(b/231896343): Clarify what is going on with the network state
|
||||
# tensors, currently they all have to be the same n_dims so we have to
|
||||
# add/remove dummy dims.
|
||||
action_tokens = tf.squeeze(network_state['action_tokens'], [3, 4])
|
||||
seq_idx = tf.reshape(network_state['seq_idx'], [1])[0]
|
||||
# network_state as input for this call is the output from the last call.
|
||||
|
|
Loading…
Reference in New Issue