170 lines
6.0 KiB
Python
170 lines
6.0 KiB
Python
|
# Copyright 2022 Google LLC
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""RT1 decoder transformer.
|
||
|
|
||
|
Copied from:
|
||
|
https://www.tensorflow.org/text/tutorials/transformer#decoder
|
||
|
"""
|
||
|
from typing import Tuple, Union
|
||
|
|
||
|
import tensorflow as tf
|
||
|
|
||
|
|
||
|
class _TransformerLayer(tf.keras.layers.Layer):
|
||
|
"""A single transformer block."""
|
||
|
|
||
|
def __init__(self,
|
||
|
layer_size: int = 4096,
|
||
|
num_heads: int = 8,
|
||
|
feed_forward_size: int = 512,
|
||
|
dropout_rate: float = 0.1,
|
||
|
return_attention_scores: bool = False):
|
||
|
"""Creates a Transformer layer.
|
||
|
|
||
|
Args:
|
||
|
layer_size: Size of the multiple head attention layer.
|
||
|
num_heads: Number of heads for the multiple head attention layer.
|
||
|
feed_forward_size: Dimensionality of the feed_forward layer.
|
||
|
dropout_rate: Dropout rate.
|
||
|
return_attention_scores: Return attention scores.
|
||
|
"""
|
||
|
super(_TransformerLayer, self).__init__()
|
||
|
|
||
|
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
|
||
|
self.mha1 = tf.keras.layers.MultiHeadAttention(
|
||
|
key_dim=layer_size, num_heads=num_heads, dropout=dropout_rate)
|
||
|
self.ff = tf.keras.layers.Dense(feed_forward_size)
|
||
|
self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
|
||
|
self.dropout_ff = tf.keras.layers.Dropout(dropout_rate)
|
||
|
self._return_attention_scores = return_attention_scores
|
||
|
|
||
|
def call(self, x: tf.Tensor, attention_mask: tf.Tensor,
|
||
|
training: bool) -> Tuple[tf.Tensor, Union[tf.Tensor, None]]:
|
||
|
"""Calls the layer.
|
||
|
|
||
|
Args:
|
||
|
x: Input Tensor of shape `(B, T, dim)`.
|
||
|
attention_mask: a boolean mask of shape `(B, T, T)`, that prevents
|
||
|
attention to certain positions. The boolean mask specifies which query
|
||
|
elements can attend to which key elements, 1 indicates attention and 0
|
||
|
indicates no attention. Broadcasting can happen for the missing batch
|
||
|
dimensions and the head dimension.
|
||
|
training: Python boolean indicating whether the layer should behave in
|
||
|
training mode (adding dropout) or in inference mode (no dropout).
|
||
|
|
||
|
Returns:
|
||
|
y: Output Tensor of shape `(B, T, dim)`. Also return the attention scores
|
||
|
of shape `(B, T, dim)` or None.
|
||
|
"""
|
||
|
x1 = self.layernorm1(x)
|
||
|
mha_results = self.mha1(
|
||
|
query=x1,
|
||
|
key=x1,
|
||
|
value=x1,
|
||
|
attention_mask=attention_mask,
|
||
|
return_attention_scores=self._return_attention_scores,
|
||
|
training=training)
|
||
|
if self._return_attention_scores:
|
||
|
x1, score = mha_results
|
||
|
else:
|
||
|
x1, score = mha_results, None
|
||
|
|
||
|
x = x + x1
|
||
|
|
||
|
y = self.layernorm2(x)
|
||
|
ff_y = self.ff(y)
|
||
|
ff_y = self.dropout_ff(ff_y, training=training)
|
||
|
x = x + ff_y
|
||
|
return x, score
|
||
|
|
||
|
|
||
|
class Transformer(tf.keras.layers.Layer):
|
||
|
"""A decoder only transformer."""
|
||
|
|
||
|
def __init__(self,
|
||
|
num_layers: int = 1,
|
||
|
layer_size: int = 4096,
|
||
|
num_heads: int = 8,
|
||
|
feed_forward_size: int = 512,
|
||
|
dropout_rate: float = 0.1,
|
||
|
vocab_size: int = 256,
|
||
|
return_attention_scores: bool = False):
|
||
|
"""Creates a transformer.
|
||
|
|
||
|
Args:
|
||
|
num_layers: Number of transformer layers.
|
||
|
layer_size: Size of the multiple head attention layer.
|
||
|
num_heads: Number of heads for the multiple head attention layer.
|
||
|
feed_forward_size: Dimensionality of the feed_forward layer.
|
||
|
dropout_rate: Dropout rate.
|
||
|
vocab_size: Dimensionality of tokens from the output layer.
|
||
|
return_attention_scores: Return attention scores.
|
||
|
"""
|
||
|
super(Transformer, self).__init__()
|
||
|
|
||
|
self._layers = [
|
||
|
_TransformerLayer( # pylint: disable=g-complex-comprehension
|
||
|
layer_size=layer_size,
|
||
|
num_heads=num_heads,
|
||
|
feed_forward_size=feed_forward_size,
|
||
|
dropout_rate=dropout_rate,
|
||
|
return_attention_scores=return_attention_scores)
|
||
|
for _ in range(num_layers)
|
||
|
]
|
||
|
self._token_emb = tf.keras.layers.Dense(feed_forward_size)
|
||
|
self._position_emb = tf.keras.layers.Dense(feed_forward_size)
|
||
|
self._output_tokens = tf.keras.layers.Dense(vocab_size)
|
||
|
|
||
|
def call(
|
||
|
self,
|
||
|
x: tf.Tensor,
|
||
|
training: bool,
|
||
|
attention_mask: tf.Tensor,
|
||
|
) -> Union[tf.Tensor, Tuple[tf.Tensor, list[tf.Tensor]]]:
|
||
|
"""Calls the layer.
|
||
|
|
||
|
Args:
|
||
|
x: Input Tensor of shape `(B, T, dim)`.
|
||
|
training: Python boolean indicating whether the layer should behave in
|
||
|
training mode (adding dropout) or in inference mode (no dropout).
|
||
|
attention_mask: a boolean mask of shape `(B, T, T)`, that prevents
|
||
|
attention to certain positions. The boolean mask specifies which query
|
||
|
elements can attend to which key elements, 1 indicates attention and 0
|
||
|
indicates no attention. Broadcasting can happen for the missing batch
|
||
|
dimensions and the head dimension.
|
||
|
|
||
|
Returns:
|
||
|
x: Output Tensor of shape `(B, T, vocab_size)`. If
|
||
|
`return_attention_scores`, also return attention scores of
|
||
|
a list of `layer` of elements with shape `(B, T, dim)`.
|
||
|
"""
|
||
|
|
||
|
seq_len = tf.shape(x)[1]
|
||
|
batch_size = tf.shape(x)[0]
|
||
|
|
||
|
positions = tf.one_hot(
|
||
|
tf.tile(tf.expand_dims(tf.range(0, seq_len, 1), 0), [batch_size, 1]),
|
||
|
seq_len)
|
||
|
|
||
|
x = self._token_emb(x)
|
||
|
x += self._position_emb(positions)
|
||
|
scores = []
|
||
|
|
||
|
for layer in self._layers:
|
||
|
x, score = layer(x, attention_mask=attention_mask, training=training)
|
||
|
if score is not None:
|
||
|
scores.append(score)
|
||
|
x = self._output_tokens(x)
|
||
|
return x, scores
|