# File-viewer metadata (not part of the configuration): 28 lines, 743 B, plaintext.
# Gin configuration for the actor network and the optimizer used to train it.
#
# Dynamic registration is enabled, so the modules imported below are what this
# file's bindings and references resolve against.
from __gin__ import dynamic_registration

from robotics_transformer import transformer_network
# NOTE(review): image_tokenizer is not referenced by any binding visible in
# this file; the import is kept unchanged.
from robotics_transformer.tokenizers import image_tokenizer

import tensorflow as tf

# Macros; referenced further down with the `%NAME` syntax.
LEARNING_RATE_ACTOR = 0.0001
SEQUENCE_LENGTH = 6

# Constructor arguments bound on TransformerNetwork.
# NOTE(review): %CROP_SIZE and %ACTION_ORDER are not defined in this file, so
# they have to be provided by another config or binding -- confirm where.
transformer_network.TransformerNetwork:
  num_layers = 8
  layer_size = 128
  num_heads = 8
  feed_forward_size = 512
  dropout_rate = 0.1
  vocab_size = 256
  token_embedding_size = 512
  time_sequence_length = %SEQUENCE_LENGTH
  crop_size = %CROP_SIZE
  action_order = %ACTION_ORDER
  use_token_learner = True

# This binding is made under the `actor_optimizer` scope, so it applies to Adam
# only when Adam is referenced through that scope (as ACTOR_OPTIMIZER does).
actor_optimizer/tf.keras.optimizers.Adam:
  learning_rate = %LEARNING_RATE_ACTOR

# No trailing parentheses: the macro holds the configurable itself rather than
# an instance -- presumably the consumer constructs the network; verify there.
ACTOR_NETWORK = @transformer_network.TransformerNetwork
# Trailing parentheses: the scoped Adam configurable is called and the macro
# holds the resulting optimizer.
ACTOR_OPTIMIZER = @actor_optimizer/tf.keras.optimizers.Adam()