Update transformer_network.py
This commit is contained in:
parent
356139043a
commit
042b9457a3
|
@ -300,7 +300,6 @@ class TransformerNetwork(network.Network):
|
||||||
|
|
||||||
if outer_rank == 1: # This is an inference call
|
if outer_rank == 1: # This is an inference call
|
||||||
# run transformer in loop to produce action tokens one-by-one
|
# run transformer in loop to produce action tokens one-by-one
|
||||||
# TODO(b/231896343): Document/comment more on what the following mess is.
|
|
||||||
seq_idx = tf.reshape(network_state['seq_idx'], [1])[0]
|
seq_idx = tf.reshape(network_state['seq_idx'], [1])[0]
|
||||||
action_t = tf.minimum(seq_idx, self._time_sequence_length - 1)
|
action_t = tf.minimum(seq_idx, self._time_sequence_length - 1)
|
||||||
# Transformer shifts all to the left by one step by default (it's usually
|
# Transformer shifts all to the left by one step by default (it's usually
|
||||||
|
@ -560,7 +559,6 @@ class TransformerNetwork(network.Network):
|
||||||
time_step = tf.minimum(seq_idx, self._time_sequence_length - 1)
|
time_step = tf.minimum(seq_idx, self._time_sequence_length - 1)
|
||||||
image = tf.expand_dims(image, 1)
|
image = tf.expand_dims(image, 1)
|
||||||
|
|
||||||
# TODO(b/255731285)
|
|
||||||
image_shape = tf.shape(image)
|
image_shape = tf.shape(image)
|
||||||
b = image_shape[0]
|
b = image_shape[0]
|
||||||
input_t = image_shape[1]
|
input_t = image_shape[1]
|
||||||
|
@ -612,9 +610,6 @@ class TransformerNetwork(network.Network):
|
||||||
def _tokenize_actions(self, observations, network_state):
|
def _tokenize_actions(self, observations, network_state):
|
||||||
outer_rank = self._get_outer_rank(observations)
|
outer_rank = self._get_outer_rank(observations)
|
||||||
if outer_rank == 1: # This is an inference call
|
if outer_rank == 1: # This is an inference call
|
||||||
# TODO(b/231896343): Clarify what is going on with the network state
|
|
||||||
# tensors, currently they all have to be the same n_dims so we have to
|
|
||||||
# add/remove dummy dims.
|
|
||||||
action_tokens = tf.squeeze(network_state['action_tokens'], [3, 4])
|
action_tokens = tf.squeeze(network_state['action_tokens'], [3, 4])
|
||||||
seq_idx = tf.reshape(network_state['seq_idx'], [1])[0]
|
seq_idx = tf.reshape(network_state['seq_idx'], [1])[0]
|
||||||
# network_state as input for this call is the output from the last call.
|
# network_state as input for this call is the output from the last call.
|
||||||
|
|
Loading…
Reference in New Issue