robotics_transformer/configs/transformer_mixin.gin

28 lines
743 B
Plaintext

from __gin__ import dynamic_registration
from robotics_transformer import transformer_network
from robotics_transformer.tokenizers import image_tokenizer
import tensorflow as tf
# Learning rate handed to the actor optimizer; consumed by the
# `actor_optimizer`-scoped Adam binding in this file.
LEARNING_RATE_ACTOR = 0.0001
# Value bound to TransformerNetwork.time_sequence_length in this file.
SEQUENCE_LENGTH = 6
# Parameter bindings for transformer_network.TransformerNetwork.
# The bindings must be indented under the block header so that Gin attaches
# them to the configurable; unindented `name = value` lines would be read as
# standalone top-level macro definitions instead.
transformer_network.TransformerNetwork:
  num_layers = 8
  layer_size = 128
  num_heads = 8
  feed_forward_size = 512
  dropout_rate = 0.1
  vocab_size = 256
  token_embedding_size = 512
  time_sequence_length = %SEQUENCE_LENGTH
  # NOTE(review): CROP_SIZE and ACTION_ORDER are not defined in this file;
  # presumably the config that includes this mixin defines them -- confirm.
  crop_size = %CROP_SIZE
  action_order = %ACTION_ORDER
  use_token_learner = True
# Adam configuration that applies only inside the `actor_optimizer` scope.
# The binding is indented so that Gin attaches it to this scoped configurable
# rather than parsing it as a top-level macro named `learning_rate`.
actor_optimizer/tf.keras.optimizers.Adam:
  learning_rate = %LEARNING_RATE_ACTOR
# Reference to the TransformerNetwork configurable itself: there are no
# trailing parentheses, so Gin does not call it when resolving this macro.
ACTOR_NETWORK = @transformer_network.TransformerNetwork
# The trailing parentheses make Gin call the `actor_optimizer`-scoped Adam
# configurable, so this macro resolves to the result of that call.
ACTOR_OPTIMIZER = @actor_optimizer/tf.keras.optimizers.Adam()