commit ccd9d2fe2d
add robotics transformer

@@ -0,0 +1,25 @@
# Compiled python modules.
*.pyc

# Byte-compiled
__pycache__/
.cache/

# Poetry, setuptools, PyPI distribution artifacts.
/*.egg-info
.eggs/
build/
dist/
poetry.lock

# Tests
.pytest_cache/

# Type checking
.pytype/

# Other
*.DS_Store

# PyCharm
.idea

@@ -0,0 +1,447 @@
# This Pylint rcfile contains a best-effort configuration to uphold the
# best-practices and style described in the Google Python style guide:
#   https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
#   https://google.github.io/styleguide/pylintrc

[MASTER]

# Add files or directories to the ignore list. They should be base names, not
# paths.
ignore=third_party

# Add files or directories matching the regex patterns to the ignore list. The
# regex matches against base names, not paths.
ignore-patterns=

# Pickle collected data for later comparisons.
persistent=no

# List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers.
load-plugins=

# Use multiple processes to speed up Pylint.
jobs=4

# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no

# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loaded into the active Python interpreter and may
# run arbitrary code.
extension-pkg-allow-list=


[MESSAGES CONTROL]

# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifiers separated by comma (,) or put this option
# multiple times (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=

# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W"
disable=abstract-method,
        apply-builtin,
        arguments-differ,
        attribute-defined-outside-init,
        backtick,
        bad-option-value,
        basestring-builtin,
        buffer-builtin,
        c-extension-no-member,
        consider-using-enumerate,
        cmp-builtin,
        cmp-method,
        coerce-builtin,
        coerce-method,
        delslice-method,
        div-method,
        duplicate-code,
        eq-without-hash,
        execfile-builtin,
        file-builtin,
        filter-builtin-not-iterating,
        fixme,
        getslice-method,
        global-statement,
        hex-method,
        idiv-method,
        implicit-str-concat-in-sequence,
        import-error,
        import-self,
        import-star-module-level,
        inconsistent-return-statements,
        input-builtin,
        intern-builtin,
        invalid-str-codec,
        locally-disabled,
        long-builtin,
        long-suffix,
        map-builtin-not-iterating,
        misplaced-comparison-constant,
        missing-function-docstring,
        metaclass-assignment,
        next-method-called,
        next-method-defined,
        no-absolute-import,
        no-else-break,
        no-else-continue,
        no-else-raise,
        no-else-return,
        no-init,  # added
        no-member,
        no-name-in-module,
        no-self-use,
        nonzero-method,
        oct-method,
        old-division,
        old-ne-operator,
        old-octal-literal,
        old-raise-syntax,
        parameter-unpacking,
        print-statement,
        raising-string,
        range-builtin-not-iterating,
        raw_input-builtin,
        rdiv-method,
        reduce-builtin,
        relative-import,
        reload-builtin,
        round-builtin,
        setslice-method,
        signature-differs,
        standarderror-builtin,
        suppressed-message,
        sys-max-int,
        too-few-public-methods,
        too-many-ancestors,
        too-many-arguments,
        too-many-boolean-expressions,
        too-many-branches,
        too-many-instance-attributes,
        too-many-locals,
        too-many-nested-blocks,
        too-many-public-methods,
        too-many-return-statements,
        too-many-statements,
        trailing-newlines,
        unichr-builtin,
        unicode-builtin,
        unnecessary-pass,
        unpacking-in-except,
        useless-else-on-loop,
        useless-object-inheritance,
        useless-suppression,
        using-cmp-argument,
        wrong-import-order,
        xrange-builtin,
        zip-builtin-not-iterating,


[REPORTS]

# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text

# Put messages in a separate file for each module / package specified on the
# command line instead of printing them on stdout. Reports (if any) will be
# written in a file named "pylint_global.[txt|html]". This option is deprecated
# and it will be removed in Pylint 2.0.
files-output=no

# Tells whether to display a full report or only the messages
reports=no

# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors, warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)

# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=


[BASIC]

# Good variable names which should always be accepted, separated by a comma
good-names=main,_

# Bad variable names which should always be refused, separated by a comma
bad-names=

# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=

# Include a hint for the correct naming format with invalid-name
include-naming-hint=no

# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl

# Regular expression matching correct function names
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$

# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$

# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$

# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$

# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$

# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$

# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$

# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$

# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$

# Regular expression matching correct method names
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$

# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$

# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10


[TYPECHECK]

# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager

# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes

# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis). It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=

# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=


[FORMAT]

# Maximum number of characters on a single line.
max-line-length=80

# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.

# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=(?x)(
  ^\s*(\#\ )?<?https?://\S+>?$|
  ^\s*(from\s+\S+\s+)?import\s+.+$)

# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes

# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=

# Maximum number of lines in a module
max-module-lines=99999

# String used as indentation unit. The internal Google style guide mandates 2
# spaces. Google's externally-published style guide says 4, consistent with
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
# projects (like TensorFlow).
indent-string='  '

# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4

# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=


[MISCELLANEOUS]

# List of note tags to take in consideration, separated by a comma.
notes=TODO


[STRING]

# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes


[VARIABLES]

# Tells whether we should check for unused import in __init__ files.
init-import=no

# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)

# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=

# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb

# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools


[LOGGING]

# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.io.logging


[SIMILARITIES]

# Minimum lines number of a similarity.
min-similarity-lines=4

# Ignore comments when computing similarities.
ignore-comments=yes

# Ignore docstrings when computing similarities.
ignore-docstrings=yes

# Ignore imports when computing similarities.
ignore-imports=no


[SPELLING]

# Spelling dictionary name. Available dictionaries: none. To make this work,
# install the python-enchant package.
spelling-dict=

# List of comma separated words that should not be checked.
spelling-ignore-words=

# A path to a file that contains a private dictionary; one word per line.
spelling-private-dict-file=

# Tells whether to store unknown words to the private dictionary indicated by
# the --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no


[IMPORTS]

# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,
                   TERMIOS,
                   Bastion,
                   rexec,
                   sets

# Create a graph of every (i.e. internal and external) dependency in the
# given file (report RP0402 must not be disabled)
import-graph=

# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=

# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=

# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=

# Force import order to recognize a module as part of a third party library.
known-third-party=enchant, absl

# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no


[CLASSES]

# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
                      __new__,
                      setUp

# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
                  _fields,
                  _replace,
                  _source,
                  _make

# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls,
                            class_

# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs


[EXCEPTIONS]

# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,
                       Exception,
                       BaseException

@@ -0,0 +1,30 @@
{
  "files.insertFinalNewline": true,
  "files.trimFinalNewlines": true,
  "files.trimTrailingWhitespace": true,
  "files.associations": {
    ".pylintrc": "ini"
  },
  "python.testing.unittestEnabled": false,
  "python.testing.nosetestsEnabled": false,
  "python.testing.pytestEnabled": true,
  "python.linting.pylintUseMinimalCheckers": false,
  "[python]": {
    "editor.rulers": [
      80
    ],
    "editor.tabSize": 2,
    "editor.formatOnSave": true,
    "editor.detectIndentation": false
  },
  "python.formatting.provider": "black",
  "python.formatting.blackPath": "pyink",
  "files.watcherExclude": {
    "**/.git/**": true
  },
  "files.exclude": {
    "**/__pycache__": true,
    "**/.pytest_cache": true,
    "**/*.egg-info": true
  }
}

@@ -0,0 +1,31 @@
# Changelog

<!--

The changelog follows the https://keepachangelog.com/ standard (at least the headers).

This allows:

* Auto-parsing of release notes during automated releases from github-action:
  https://github.com/marketplace/actions/pypi-github-auto-release
* Clickable headers in the rendered markdown

To release a new version (e.g. from `1.0.0` -> `2.0.0`):

* Create a new `## [2.0.0] - YYYY-MM-DD` header and add the current
  `[Unreleased]` notes.
* At the end of the file:
  * Define the new link url:
    `[2.0.0]: https://github.com/google-research/robotics_transformer/compare/v1.0.0...v2.0.0`
  * Update the `[Unreleased]` url: `v1.0.0...HEAD` -> `v2.0.0...HEAD`

-->

## [Unreleased]

## [0.1.0] - 2022-01-01

* Initial release

[Unreleased]: https://github.com/google-research/robotics_transformer/compare/v0.1.0...HEAD
[0.1.0]: https://github.com/google-research/robotics_transformer/releases/tag/v0.1.0

@@ -0,0 +1,29 @@
# How to Contribute

We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement (CLA). You (or your employer) retain the copyright to your
contribution; this simply gives us permission to use and redistribute your
contributions as part of the project. Head over to
<https://cla.developers.google.com/> to see your current agreements on file or
to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code Reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).

@@ -0,0 +1,202 @@

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

@@ -0,0 +1,134 @@
# go/google3metadata
# proto-file: devtools/metadata/metadata.proto
# proto-message: MetaData

name: "robotics_transformer"
description: "code and utilities to build and run RT-1 robotics transformer"

third_party {
  url {
    type: HOMEPAGE
    value: "http://go/robotics_transformer"
  }
  url {
    type: PIPER
    value: "http://google3/third_party/py/robotics_transformer"
  }
}

presubmit: {
  review_notify: "robotics_transformer-automated+reviews"

  check_tests: {
    failure_status: ERROR
    project: "robotics_transformer"
  }

  # Checks that files in the changelist do not contain tab characters.
  check_tabs: {
    failure_status: ERROR
  }

  check_trailing_whitespace: {
    failure_status: ERROR
  }

  # Presubmit applied during submit

  check_lint: {
    action: SUBMIT  # Do not activate by default to not block TAP.
    failure_status: ERROR
  }

  # Ensures that the string "do not submit" (in all caps) is not present.
  check_do_not_submit: {
    action: SUBMIT
  }
}

# Register the copy.bara.sky
exported: {
  copybara: {
    config_path: "//depot/google3/third_party/py/robotics_transformer/copy.bara.sky"
  }
  path_expression: "//depot/google3/third_party/py/robotics_transformer/..."
  remote_location: "https://github.com/google-research/robotics_transformer"
  reason: OPEN_SOURCE
  description: "Open source robotics_transformer"
  # request_url: "https://launch.corp.google.com/launch/4225970"
  owning_team_email: "robotics_transformer-automated@google.com"
}

# Copybara presubmit
# presubmit: {
#   path_expression: "//depot/google3/third_party/py/robotics_transformer/..."
#   # Do not trigger copybara for the following files
#   path_expression_exclusion: "//depot/.../METADATA"
#   path_expression_exclusion: "//depot/.../OWNERS"
#   path_expression_exclusion: "//depot/.../BUILD"
#   path_expression_exclusion: "//depot/.../*.bzl"
#   path_expression_exclusion: "//depot/.../google/..."

#   # Ensure that changes contain public notes for git commit messages.
#   check_description: {
#     base: {
#       id: "CopybaraDescription"
#       disable_tags: "GIT_ORIGIN_REV_ID"
#       disable_tags: "SKIP_COPYBARA"
#     }

#     required_regexp:
#       "("
#       "(^|\\n)\\s*BEGIN_PUBLIC\\s*?\\n"
#       "(.*\\n)*"
#       "\\s*\\S+.*(\\n.*)*\\n"
#       "\\s*END_PUBLIC\\s*?\\n"
#       "|"
#       "(^|\\n)\\s*PUBLIC:(?: )*\\S+"
#       ")"

#     failure_message:
#       "\n"
#       "By running presubmit, this cl will be exported as PR on github. "
#       "Please add a public commit message to the cl description:\n"
#       "\n"
#       "PUBLIC: my public commit msg\n"
#       "\n"
#       "OR\n"
#       "\n"
#       "BEGIN_PUBLIC\n"
#       "my public\n"
#       "commit msg\n"
#       "END_PUBLIC\n"
#       "\n"
#       "If you're certain your change does not produce public changes, the\n"
#       "message can say 'Internal'.\n"
#     failure_status: WARNING
#     required_for_cleanup: false
#   }

#   check_presubmit_service: {
#     base: { id: "Copybara-Review" disable_tags: "GIT_ORIGIN_REV_ID" }
#     action: REVIEW
#     streaming: true
#     timeout: 60
#     failure_status: WARNING
#     execution_mode: SECONDARY_EXECUTION
#     include_all_opened_files: true
#     include_deleted_files: true
#     address: "blade:copybara-streaming-presubmit-service-prod"
#     options: "depot_path=//depot/google3/third_party/py/robotics_transformer/copy.bara.sky;workflow=piper_to_github_presubmit;blocking=false"
#   }
#   check_presubmit_service: {
#     base: { id: "Copybara-Submit" disable_tags: "GIT_ORIGIN_REV_ID" }
#     action: SUBMIT
#     streaming: true
#     timeout: 600
#     failure_status: ERROR
#     execution_mode: SECONDARY_EXECUTION
#     include_all_opened_files: true
#     include_deleted_files: true
#     address: "blade:copybara-streaming-presubmit-service-prod"
#     options: "depot_path=//depot/google3/third_party/py/robotics_transformer/copy.bara.sky;workflow=piper_to_github_presubmit;blocking=true"
#   }
# }

@@ -0,0 +1,45 @@
# Robotics Transformer

*This is not an officially supported Google product.*


This repository is a collection of code files and artifacts for running
the Robotics Transformer, or RT-1.

## Features

* FiLM EfficientNet-based image tokenizer backbone
* TokenLearner-based compression of input tokens
* Transformer for end-to-end robotic control
* Testing utilities

## Getting Started

### Installation

Clone the repo:

```bash
git clone https://github.com/google-research/robotics_transformer.git
pip install -r robotics_transformer/requirements.txt
python -m robotics_transformer.tokenizers.action_tokenizer.test
```

### Running Tests

To run RT-1 tests, you can clone the git repo and run
[bazel](https://bazel.build/):

```bash
git clone https://github.com/google-research/robotics_transformer.git
cd robotics_transformer
bazel test ...
```

## Future Releases

The current repository includes an initial set of libraries for early adoption.
More components may come in future releases.

## License

The Robotics Transformer library is licensed under the terms of the Apache
2.0 license.

@@ -0,0 +1,18 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""robotics_transformer API."""

# A new PyPI release will be pushed every time `__version__` is increased.
# When changing this, also update the CHANGELOG.md.
__version__ = '0.1.0'

@@ -0,0 +1,27 @@
from __gin__ import dynamic_registration
from robotics_transformer import transformer_network
from robotics_transformer.tokenizers import image_tokenizer
import tensorflow as tf

LEARNING_RATE_ACTOR = 0.0001
SEQUENCE_LENGTH = 6


transformer_network.TransformerNetwork:
  num_layers = 8
  layer_size = 128
  num_heads = 8
  feed_forward_size = 512
  dropout_rate = 0.1
  vocab_size = 256
  token_embedding_size = 512
  time_sequence_length = %SEQUENCE_LENGTH
  crop_size = %CROP_SIZE
  action_order = %ACTION_ORDER
  use_token_learner = True

actor_optimizer/tf.keras.optimizers.Adam:
  learning_rate = %LEARNING_RATE_ACTOR

ACTOR_NETWORK = @transformer_network.TransformerNetwork
ACTOR_OPTIMIZER = @actor_optimizer/tf.keras.optimizers.Adam()
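For context, a gin config like the one above is typically consumed with `gin.parse_config_files_and_bindings`. Note that the `%CROP_SIZE` and `%ACTION_ORDER` macros are left unbound in the file, so a caller must supply them. A minimal sketch, assuming a hypothetical config path and placeholder binding values (neither is part of this commit):

```python
import gin

# Bind the macros the config leaves open; both values below are
# illustrative placeholders, not values taken from this commit.
gin.parse_config_files_and_bindings(
    ['robotics_transformer.gin'],  # hypothetical path to the file above
    bindings=[
        'CROP_SIZE = 236',
        'ACTION_ORDER = None',
    ])

# Any constructor argument bound above is now injected by gin whenever the
# class is instantiated; other required arguments (if any) must still be
# passed explicitly, e.g.:
#   transformer_network.TransformerNetwork(<remaining required args>)
```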

@@ -0,0 +1,13 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

@@ -0,0 +1,74 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ResNet variants model for Keras with Film-Conditioning.

Related papers/blogs:
- https://arxiv.org/abs/1512.03385
- https://arxiv.org/pdf/1603.05027v2.pdf
- http://torch.ch/blog/2016/02/04/resnets.html
- https://arxiv.org/abs/1709.07871
"""
import tensorflow.compat.v2 as tf

layers = tf.keras.layers


class FilmConditioning(tf.keras.layers.Layer):
  """Layer that adds FiLM conditioning.

  This is intended to be applied after a convolutional layer. It will learn a
  multiplicative and an additive factor to be applied to each channel of the
  convolution's output.

  Conv layer can be rank 2 or 4.

  For further details, see: https://arxiv.org/abs/1709.07871
  """

  def __init__(self, num_channels: int):
    """Constructs a FiLM conditioning layer.

    Args:
      num_channels: Number of filter channels to expect in the input.
    """
    super().__init__()
    # Note that we initialize with zeros because empirically we have found
    # this works better than initializing with glorot.
    self._projection_add = layers.Dense(
        num_channels,
        activation=None,
        kernel_initializer='zeros',
        bias_initializer='zeros')
    self._projection_mult = layers.Dense(
        num_channels,
        activation=None,
        kernel_initializer='zeros',
        bias_initializer='zeros')

  def call(self, conv_filters: tf.Tensor, conditioning: tf.Tensor):
    tf.debugging.assert_rank(conditioning, 2)
    projected_cond_add = self._projection_add(conditioning)
    projected_cond_mult = self._projection_mult(conditioning)

    if len(conv_filters.shape) == 4:
      # [B, D] -> [B, 1, 1, D]
      projected_cond_add = projected_cond_add[:, tf.newaxis, tf.newaxis]
      projected_cond_mult = projected_cond_mult[:, tf.newaxis, tf.newaxis]
    else:
      tf.debugging.assert_rank(conv_filters, 2)

    # Original FiLM paper argues that 1 + gamma centers the initialization at
    # identity transform.
    result = (1 + projected_cond_mult) * conv_filters + projected_cond_add
    return result
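Because both projections are zero-initialized, the layer starts out as an identity map: gamma and beta are 0, so `(1 + 0) * x + 0 = x`. A quick sketch confirming that behavior (the shapes below are arbitrary illustrative choices, not taken from this commit):

```python
import tensorflow as tf

from robotics_transformer.film_efficientnet import film_conditioning_layer

conv_out = tf.random.normal([2, 16, 16, 8])  # rank-4 feature map, 8 channels
context = tf.random.normal([2, 32])          # rank-2 conditioning vector

film = film_conditioning_layer.FilmConditioning(num_channels=8)
out = film(conv_out, context)

# With zero-initialized projections, FiLM reduces to the identity transform,
# so the output equals the input feature map exactly.
tf.debugging.assert_near(out, conv_out)
```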

@@ -0,0 +1,40 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for film_conditioning_layer."""
from absl.testing import parameterized
import numpy as np
from robotics_transformer.film_efficientnet import film_conditioning_layer
import tensorflow as tf


class FilmConditioningLayerTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters([2, 4])
  def test_film_conditioning_rank_two_and_four(self, conv_rank):
    batch = 2
    num_channels = 3
    if conv_rank == 2:
      conv_layer = np.random.randn(batch, num_channels)
    elif conv_rank == 4:
      conv_layer = np.random.randn(batch, 1, 1, num_channels)
    else:
      raise ValueError(f'Unexpected conv rank: {conv_rank}')
    context = np.random.rand(batch, num_channels)
    film_layer = film_conditioning_layer.FilmConditioning(num_channels)
    out = film_layer(conv_layer, context)
    tf.debugging.assert_rank(out, conv_rank)


if __name__ == '__main__':
  tf.test.main()
@ -0,0 +1,759 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# pytype: skip-file
|
||||
# pylint: skip-file
|
||||
"""EfficientNet models modified with added film layers.
|
||||
|
||||
Mostly copied from third_party/py/keras/applications/efficientnet.py
|
||||
"""
|
||||
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
import json
|
||||
|
||||
from absl import logging
|
||||
import tensorflow.compat.v2 as tf
|
||||
from tensorflow.keras import layers
|
||||
|
||||
from robotics_transformer.film_efficientnet.film_conditioning_layer import FilmConditioning
|
||||
|
||||
BASE_WEIGHTS_PATH = 'efficientnet_checkpoints/efficientnet'
|
||||
IMAGENET_JSON_PATH = 'efficientnet_checkpoints/imagenet_classes.json'
|
||||
CLASS_INDEX = None
|
||||
|
||||
WEIGHTS_PATHS = {
|
||||
'efficientnetb3': BASE_WEIGHTS_PATH + 'b3.h5',
|
||||
'efficientnetb3_notop': BASE_WEIGHTS_PATH + 'b3_notop.h5',
|
||||
}
|
||||
|
||||
DEFAULT_BLOCKS_ARGS = [{
|
||||
'kernel_size': 3,
|
||||
'repeats': 1,
|
||||
'filters_in': 32,
|
||||
'filters_out': 16,
|
||||
'expand_ratio': 1,
|
||||
'id_skip': True,
|
||||
'strides': 1,
|
||||
'se_ratio': 0.25
|
||||
}, {
|
||||
'kernel_size': 3,
|
||||
'repeats': 2,
|
||||
'filters_in': 16,
|
||||
'filters_out': 24,
|
||||
'expand_ratio': 6,
|
||||
'id_skip': True,
|
||||
'strides': 2,
|
||||
'se_ratio': 0.25
|
||||
}, {
|
||||
'kernel_size': 5,
|
||||
'repeats': 2,
|
||||
'filters_in': 24,
|
||||
'filters_out': 40,
|
||||
'expand_ratio': 6,
|
||||
'id_skip': True,
|
||||
'strides': 2,
|
||||
'se_ratio': 0.25
|
||||
}, {
|
||||
'kernel_size': 3,
|
||||
'repeats': 3,
|
||||
'filters_in': 40,
|
||||
'filters_out': 80,
|
||||
'expand_ratio': 6,
|
||||
'id_skip': True,
|
||||
'strides': 2,
|
||||
'se_ratio': 0.25
|
||||
}, {
|
||||
'kernel_size': 5,
|
||||
'repeats': 3,
|
||||
'filters_in': 80,
|
||||
'filters_out': 112,
|
||||
'expand_ratio': 6,
|
||||
'id_skip': True,
|
||||
'strides': 1,
|
||||
'se_ratio': 0.25
|
||||
}, {
|
||||
'kernel_size': 5,
|
||||
'repeats': 4,
|
||||
'filters_in': 112,
|
||||
'filters_out': 192,
|
||||
'expand_ratio': 6,
|
||||
'id_skip': True,
|
||||
'strides': 2,
|
||||
'se_ratio': 0.25
|
||||
}, {
|
||||
'kernel_size': 3,
|
||||
'repeats': 1,
|
||||
'filters_in': 192,
|
||||
'filters_out': 320,
|
||||
'expand_ratio': 6,
|
||||
'id_skip': True,
|
||||
'strides': 1,
|
||||
'se_ratio': 0.25
|
||||
}]
|
||||
|
||||
CONV_KERNEL_INITIALIZER = {
|
||||
'class_name': 'VarianceScaling',
|
||||
'config': {
|
||||
'scale': 2.0,
|
||||
'mode': 'fan_out',
|
||||
'distribution': 'truncated_normal'
|
||||
}
|
||||
}
|
||||
|
||||
DENSE_KERNEL_INITIALIZER = {
|
||||
'class_name': 'VarianceScaling',
|
||||
'config': {
|
||||
'scale': 1. / 3.,
|
||||
'mode': 'fan_out',
|
||||
'distribution': 'uniform'
|
||||
}
|
||||
}
|
||||
|
||||
BASE_DOCSTRING = """Instantiates the {name} architecture.
|
||||
|
||||
Reference:
|
||||
- [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](
|
||||
https://arxiv.org/abs/1905.11946) (ICML 2019)
|
||||
|
||||
This function returns a Keras image classification model,
|
||||
optionally loaded with weights pre-trained on ImageNet.
|
||||
|
||||
For image classification use cases, see
|
||||
[this page for detailed examples](
|
||||
https://keras.io/api/applications/#usage-examples-for-image-classification-models).
|
||||
|
||||
For transfer learning use cases, make sure to read the
|
||||
[guide to transfer learning & fine-tuning](
|
||||
https://keras.io/guides/transfer_learning/).
|
||||
|
||||
Note: each Keras Application expects a specific kind of input preprocessing.
|
||||
For EfficientNet, input preprocessing is included as part of the model
|
||||
(as a `Rescaling` layer), and thus
|
||||
`tf.keras.applications.efficientnet.preprocess_input` is actually a
|
||||
pass-through function. EfficientNet models expect their inputs to be float
|
||||
tensors of pixels with values in the [0-255] range.
|
||||
|
||||
Args:
|
||||
include_top: Whether to include the fully-connected
|
||||
layer at the top of the network. Defaults to True.
|
||||
weights: One of `None` (random initialization),
|
||||
'imagenet' (pre-training on ImageNet),
|
||||
or the path to the weights file to be loaded. Defaults to 'imagenet'.
|
||||
input_tensor: Optional Keras tensor
|
||||
(i.e. output of `layers.Input()`)
|
||||
to use as image input for the model.
|
||||
input_shape: Optional shape tuple, only to be specified
|
||||
if `include_top` is False.
|
||||
It should have exactly 3 inputs channels.
|
||||
pooling: Optional pooling mode for feature extraction
|
||||
when `include_top` is `False`. Defaults to None.
|
||||
- `None` means that the output of the model will be
|
||||
the 4D tensor output of the
|
||||
last convolutional layer.
|
||||
- `avg` means that global average pooling
|
||||
will be applied to the output of the
|
||||
last convolutional layer, and thus
|
||||
the output of the model will be a 2D tensor.
|
||||
- `max` means that global max pooling will
|
||||
be applied.
|
||||
classes: Optional number of classes to classify images
|
||||
into, only to be specified if `include_top` is True, and
|
||||
if no `weights` argument is specified. Defaults to 1000 (number of
|
||||
ImageNet classes).
|
||||
classifier_activation: A `str` or callable. The activation function to use
|
||||
on the "top" layer. Ignored unless `include_top=True`. Set
|
||||
`classifier_activation=None` to return the logits of the "top" layer.
|
||||
Defaults to 'softmax'.
|
||||
When loading pretrained weights, `classifier_activation` can only
|
||||
be `None` or `"softmax"`.
|
||||
|
||||
Returns:
|
||||
A `keras.Model` instance.
|
||||
"""
|
||||
|
||||
IMAGENET_STDDEV_RGB = [0.229, 0.224, 0.225]
|
||||
|
||||
|
||||
def validate_activation(classifier_activation, weights):
|
||||
"""validates that the classifier is compatible with the weights.
|
||||
|
||||
Args:
|
||||
classifier_activation: str or callable activation function
|
||||
weights: The pretrained weights to load.
|
||||
|
||||
Raises:
|
||||
ValueError: if an activation other than `None` or `softmax` are used with
|
||||
pretrained weights.
|
||||
"""
|
||||
if weights is None:
|
||||
return
|
||||
|
||||
classifier_activation = tf.keras.activations.get(classifier_activation)
|
||||
if classifier_activation not in {
|
||||
tf.keras.activations.get('softmax'),
|
||||
tf.keras.activations.get(None)
|
||||
}:
|
||||
raise ValueError('Only `None` and `softmax` activations are allowed '
|
||||
'for the `classifier_activation` argument when using '
|
||||
'pretrained weights, with `include_top=True`; Received: '
|
||||
f'classifier_activation={classifier_activation}')
|
||||
|
||||
|
||||
def correct_pad(inputs, kernel_size):
|
||||
"""Returns a tuple for zero-padding for 2D convolution with downsampling.
|
||||
|
||||
Args:
|
||||
inputs: Input tensor.
|
||||
kernel_size: An integer or tuple/list of 2 integers.
|
||||
|
||||
Returns:
|
||||
A tuple.
|
||||
"""
|
||||
img_dim = 2 if tf.keras.backend.image_data_format() == 'channels_first' else 1
|
||||
input_size = tf.keras.backend.int_shape(inputs)[img_dim:(img_dim + 2)]
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
if input_size[0] is None:
|
||||
adjust = (1, 1)
|
||||
else:
|
||||
adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)
|
||||
correct = (kernel_size[0] // 2, kernel_size[1] // 2)
|
||||
return ((correct[0] - adjust[0], correct[0]), (correct[1] - adjust[1],
|
||||
correct[1]))
|
||||
|
||||
|
||||
def obtain_input_shape(input_shape,
|
||||
default_size,
|
||||
min_size,
|
||||
data_format,
|
||||
require_flatten,
|
||||
weights=None):
|
||||
"""Internal utility to compute/validate a model's input shape.
|
||||
|
||||
Args:
|
||||
input_shape: Either None (will return the default network input shape), or a
|
||||
user-provided shape to be validated.
|
||||
default_size: Default input width/height for the model.
|
||||
min_size: Minimum input width/height accepted by the model.
|
||||
data_format: Image data format to use.
|
||||
require_flatten: Whether the model is expected to be linked to a classifier
|
||||
via a Flatten layer.
|
||||
weights: One of `None` (random initialization) or 'imagenet' (pre-training
|
||||
on ImageNet). If weights='imagenet' input channels must be equal to 3.
|
||||
|
||||
Returns:
|
||||
An integer shape tuple (may include None entries).
|
||||
|
||||
Raises:
|
||||
ValueError: In case of invalid argument values.
|
||||
"""
|
||||
if weights != 'imagenet' and input_shape and len(input_shape) == 3:
|
||||
if data_format == 'channels_first':
|
||||
if input_shape[0] not in {1, 3}:
|
||||
warnings.warn(
|
||||
'This model usually expects 1 or 3 input channels. '
|
||||
'However, it was passed an input_shape with ' +
|
||||
str(input_shape[0]) + ' input channels.',
|
||||
stacklevel=2)
|
||||
default_shape = (input_shape[0], default_size, default_size)
|
||||
else:
|
||||
if input_shape[-1] not in {1, 3}:
|
||||
warnings.warn(
|
||||
'This model usually expects 1 or 3 input channels. '
|
||||
'However, it was passed an input_shape with ' +
|
||||
str(input_shape[-1]) + ' input channels.',
|
||||
stacklevel=2)
|
||||
default_shape = (default_size, default_size, input_shape[-1])
|
||||
else:
|
||||
if data_format == 'channels_first':
|
||||
default_shape = (3, default_size, default_size)
|
||||
else:
|
||||
default_shape = (default_size, default_size, 3)
|
||||
if weights == 'imagenet' and require_flatten:
|
||||
if input_shape is not None:
|
||||
if input_shape != default_shape:
|
||||
raise ValueError('When setting `include_top=True` '
|
||||
'and loading `imagenet` weights, '
|
||||
f'`input_shape` should be {default_shape}. '
|
||||
f'Received: input_shape={input_shape}')
|
||||
return default_shape
|
||||
if input_shape:
|
||||
if data_format == 'channels_first':
|
||||
if input_shape is not None:
|
||||
if len(input_shape) != 3:
|
||||
raise ValueError('`input_shape` must be a tuple of three integers.')
|
||||
if input_shape[0] != 3 and weights == 'imagenet':
|
||||
raise ValueError('The input must have 3 channels; Received '
|
||||
f'`input_shape={input_shape}`')
|
||||
if ((input_shape[1] is not None and input_shape[1] < min_size) or
|
||||
(input_shape[2] is not None and input_shape[2] < min_size)):
|
||||
raise ValueError(f'Input size must be at least {min_size}'
|
||||
f'x{min_size}; Received: '
|
||||
f'input_shape={input_shape}')
|
||||
else:
|
||||
if input_shape is not None:
|
||||
if len(input_shape) != 3:
|
||||
raise ValueError('`input_shape` must be a tuple of three integers.')
|
||||
if input_shape[-1] != 3 and weights == 'imagenet':
|
||||
raise ValueError('The input must have 3 channels; Received '
|
||||
f'`input_shape={input_shape}`')
|
||||
if ((input_shape[0] is not None and input_shape[0] < min_size) or
|
||||
(input_shape[1] is not None and input_shape[1] < min_size)):
|
||||
raise ValueError('Input size must be at least '
|
||||
f'{min_size}x{min_size}; Received: '
|
||||
f'input_shape={input_shape}')
|
||||
else:
|
||||
if require_flatten:
|
||||
input_shape = default_shape
|
||||
else:
|
||||
if data_format == 'channels_first':
|
||||
input_shape = (3, None, None)
|
||||
else:
|
||||
input_shape = (None, None, 3)
|
||||
if require_flatten:
|
||||
if None in input_shape:
|
||||
raise ValueError('If `include_top` is True, '
|
||||
'you should specify a static `input_shape`. '
|
||||
f'Received: input_shape={input_shape}')
|
||||
return input_shape
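# Usage sketch (assumed call, not part of the original code): resolving the
# default shape for a 300px model without a classifier head:
#   obtain_input_shape(None, default_size=300, min_size=32,
#                      data_format='channels_last', require_flatten=False)
#   # -> (None, None, 3); a static shape is only enforced when flattening.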
|
||||
|
||||
|
||||
def EfficientNet(width_coefficient,
|
||||
depth_coefficient,
|
||||
default_size,
|
||||
dropout_rate=0.2,
|
||||
drop_connect_rate=0.2,
|
||||
depth_divisor=8,
|
||||
activation='swish',
|
||||
blocks_args='default',
|
||||
model_name='efficientnet',
|
||||
include_top=True,
|
||||
weights='imagenet',
|
||||
input_tensor=None,
|
||||
input_shape=None,
|
||||
pooling=None,
|
||||
classes=1000,
|
||||
classifier_activation='softmax',
|
||||
include_film=False):
|
||||
"""Instantiates the EfficientNet architecture using given scaling coefficients.
|
||||
|
||||
Args:
|
||||
width_coefficient: float, scaling coefficient for network width.
|
||||
depth_coefficient: float, scaling coefficient for network depth.
|
||||
default_size: integer, default input image size.
|
||||
dropout_rate: float, dropout rate before final classifier layer.
|
||||
drop_connect_rate: float, dropout rate at skip connections.
|
||||
depth_divisor: integer, a unit of network width.
|
||||
activation: activation function.
|
||||
blocks_args: list of dicts, parameters to construct block modules.
|
||||
model_name: string, model name.
|
||||
include_top: whether to include the fully-connected layer at the top of the
|
||||
network.
|
||||
weights: one of `None` (random initialization), 'imagenet' (pre-training on
|
||||
ImageNet), or the path to the weights file to be loaded.
|
||||
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use
|
||||
as image input for the model.
|
||||
input_shape: optional shape tuple, only to be specified if `include_top` is
|
||||
False. It should have exactly 3 input channels.
|
||||
pooling: optional pooling mode for feature extraction when `include_top` is
|
||||
`False`. - `None` means that the output of the model will be the 4D tensor
|
||||
output of the last convolutional layer. - `avg` means that global average
|
||||
pooling will be applied to the output of the last convolutional layer, and
|
||||
thus the output of the model will be a 2D tensor. - `max` means that
|
||||
global max pooling will be applied.
|
||||
classes: optional number of classes to classify images into, only to be
|
||||
specified if `include_top` is True, and if no `weights` argument is
|
||||
specified.
|
||||
classifier_activation: A `str` or callable. The activation function to use
|
||||
on the "top" layer. Ignored unless `include_top=True`. Set
|
||||
`classifier_activation=None` to return the logits of the "top" layer.
|
||||
include_film: bool, whether or not to insert film conditioning layers.
|
||||
|
||||
Returns:
|
||||
A `keras.Model` instance.
|
||||
|
||||
Raises:
|
||||
ValueError: in case of invalid argument for `weights`,
|
||||
or invalid input shape.
|
||||
ValueError: if `classifier_activation` is not `softmax` or `None` when
|
||||
using a pretrained top layer.
|
||||
"""
|
||||
if blocks_args == 'default':
|
||||
blocks_args = DEFAULT_BLOCKS_ARGS
|
||||
|
||||
if not (weights in {'imagenet', None} or tf.io.gfile.exists(weights)):
|
||||
raise ValueError('The `weights` argument should be either '
|
||||
'`None` (random initialization), `imagenet` '
|
||||
'(pre-training on ImageNet), '
|
||||
'or the path to the weights file to be loaded.')
|
||||
|
||||
if weights == 'imagenet' and include_top and classes != 1000:
|
||||
raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
|
||||
' as true, `classes` should be 1000')
|
||||
|
||||
# Determine proper input shape
|
||||
input_shape = obtain_input_shape(
|
||||
input_shape,
|
||||
default_size=default_size,
|
||||
min_size=32,
|
||||
data_format=tf.keras.backend.image_data_format(),
|
||||
require_flatten=include_top,
|
||||
weights=weights)
|
||||
|
||||
if include_film:
|
||||
with tf.compat.v1.variable_scope('context_input'):
|
||||
context_input = layers.Input(shape=512)
|
||||
if input_tensor is None:
|
||||
img_input = layers.Input(shape=input_shape)
|
||||
else:
|
||||
if not tf.keras.backend.is_keras_tensor(input_tensor):
|
||||
img_input = layers.Input(tensor=input_tensor, shape=input_shape)
|
||||
else:
|
||||
img_input = input_tensor
|
||||
|
||||
bn_axis = 3 if tf.keras.backend.image_data_format() == 'channels_last' else 1
|
||||
|
||||
def round_filters(filters, divisor=depth_divisor):
|
||||
"""Round number of filters based on depth multiplier."""
|
||||
filters *= width_coefficient
|
||||
new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_filters < 0.9 * filters:
|
||||
new_filters += divisor
|
||||
return int(new_filters)
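# Worked example (illustrative numbers, not part of the original code): with
# width_coefficient=1.2 and depth_divisor=8, round_filters(32) scales
# 32 -> 38.4, rounds to the nearest multiple of 8 (40), and since
# 40 >= 0.9 * 38.4 no correction is needed, so the result is 40.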
|
||||
|
||||
def round_repeats(repeats):
|
||||
"""Round number of repeats based on depth multiplier."""
|
||||
return int(math.ceil(depth_coefficient * repeats))
|
||||
|
||||
# Build stem
|
||||
x = img_input
|
||||
x = layers.Rescaling(1. / 255.)(x)
|
||||
x = layers.Normalization(axis=bn_axis)(x)
|
||||
# Note that the normalization layer uses the square of STDDEV as the
|
||||
# variance for the layer: result = (input - mean) / sqrt(var)
|
||||
# However, the original implementation uses (input - mean) / var to
|
||||
# normalize the input, so we need to divide by another sqrt(var) to match the
|
||||
# original implementation.
|
||||
# See https://github.com/tensorflow/tensorflow/issues/49930 for more details
|
||||
# We always apply this transformation, even when not using imagenet weights,
|
||||
# because it needs to be in the graph when grafting weights from imagenet
|
||||
# pretrained models.
|
||||
x = layers.Rescaling(1. / tf.math.sqrt(IMAGENET_STDDEV_RGB))(x)
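# Sanity check of the algebra (not part of the original code): the stem now
# computes ((x / 255 - mean) / sqrt(var)) / sqrt(var) == (x / 255 - mean) / var,
# which matches the original implementation's normalization.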
|
||||
|
||||
x = layers.ZeroPadding2D(padding=correct_pad(x, 3), name='stem_conv_pad')(x)
|
||||
x = layers.Conv2D(
|
||||
round_filters(32),
|
||||
3,
|
||||
strides=2,
|
||||
padding='valid',
|
||||
use_bias=False,
|
||||
kernel_initializer=CONV_KERNEL_INITIALIZER,
|
||||
name='stem_conv')(
|
||||
x)
|
||||
x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
|
||||
x = layers.Activation(activation, name='stem_activation')(x)
|
||||
|
||||
# Build blocks
|
||||
blocks_args = copy.deepcopy(blocks_args)
|
||||
|
||||
b = 0
|
||||
blocks = float(sum(round_repeats(args['repeats']) for args in blocks_args))
|
||||
for (i, args) in enumerate(blocks_args):
|
||||
assert args['repeats'] > 0
|
||||
# Update block input and output filters based on depth multiplier.
|
||||
args['filters_in'] = round_filters(args['filters_in'])
|
||||
args['filters_out'] = round_filters(args['filters_out'])
|
||||
|
||||
for j in range(round_repeats(args.pop('repeats'))):
|
||||
# The first block needs to take care of stride and filter size increase.
|
||||
if j > 0:
|
||||
args['strides'] = 1
|
||||
args['filters_in'] = args['filters_out']
|
||||
x = block(
|
||||
x,
|
||||
activation,
|
||||
drop_connect_rate * b / blocks,
|
||||
name='block{}{}_'.format(i + 1, chr(j + 97)),
|
||||
**args)
|
||||
if include_film:
|
||||
with tf.compat.v1.variable_scope('film_conditioning'):
|
||||
x = FilmConditioning(num_channels=x.shape[-1])(x, context_input)
|
||||
b += 1
|
||||
|
||||
# Build top
|
||||
x = layers.Conv2D(
|
||||
round_filters(1280),
|
||||
1,
|
||||
padding='same',
|
||||
use_bias=False,
|
||||
kernel_initializer=CONV_KERNEL_INITIALIZER,
|
||||
name='top_conv')(
|
||||
x)
|
||||
x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x)
|
||||
x = layers.Activation(activation, name='top_activation')(x)
|
||||
if include_top:
|
||||
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
|
||||
if dropout_rate > 0:
|
||||
x = layers.Dropout(dropout_rate, name='top_dropout')(x)
|
||||
validate_activation(classifier_activation, weights)
|
||||
x = layers.Dense(
|
||||
classes,
|
||||
activation=classifier_activation,
|
||||
kernel_initializer=DENSE_KERNEL_INITIALIZER,
|
||||
name='predictions')(
|
||||
x)
|
||||
else:
|
||||
if pooling == 'avg':
|
||||
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
|
||||
elif pooling == 'max':
|
||||
x = layers.GlobalMaxPooling2D(name='max_pool')(x)
|
||||
|
||||
# Ensure that the model takes into account
|
||||
# any potential predecessors of `input_tensor`.
|
||||
if input_tensor is not None:
|
||||
inputs = tf.keras.utils.get_source_inputs(input_tensor)
|
||||
else:
|
||||
inputs = img_input
|
||||
if include_film:
|
||||
inputs = (img_input, context_input)
|
||||
|
||||
# Create model.
|
||||
model = tf.keras.Model(inputs, x, name=model_name)
|
||||
|
||||
# Load weights.
|
||||
if weights == 'imagenet':
|
||||
if include_top:
|
||||
key = model_name
|
||||
else:
|
||||
key = model_name + '_notop'
|
||||
weights_path = os.path.join(os.path.dirname(__file__), WEIGHTS_PATHS[key])
|
||||
model.load_weights(weights_path, skip_mismatch=False, by_name=False)
|
||||
elif weights is not None:
|
||||
model.load_weights(weights, skip_mismatch=False, by_name=False)
|
||||
return model
|
||||
|
||||
|
||||
def block(inputs,
|
||||
activation='swish',
|
||||
drop_rate=0.,
|
||||
name='',
|
||||
filters_in=32,
|
||||
filters_out=16,
|
||||
kernel_size=3,
|
||||
strides=1,
|
||||
expand_ratio=1,
|
||||
se_ratio=0.,
|
||||
id_skip=True):
|
||||
"""An inverted residual block.
|
||||
|
||||
Args:
|
||||
inputs: input tensor.
|
||||
activation: activation function.
|
||||
drop_rate: float between 0 and 1, fraction of the input units to drop.
|
||||
name: string, block label.
|
||||
filters_in: integer, the number of input filters.
|
||||
filters_out: integer, the number of output filters.
|
||||
kernel_size: integer, the dimension of the convolution window.
|
||||
strides: integer, the stride of the convolution.
|
||||
expand_ratio: integer, scaling coefficient for the input filters.
|
||||
se_ratio: float between 0 and 1, fraction to squeeze the input filters.
|
||||
id_skip: boolean, whether to add the identity skip connection when the
      input and output shapes match.
|
||||
|
||||
Returns:
|
||||
output tensor for the block.
|
||||
"""
|
||||
bn_axis = 3 if tf.keras.backend.image_data_format() == 'channels_last' else 1
|
||||
|
||||
# Expansion phase
|
||||
filters = filters_in * expand_ratio
|
||||
if expand_ratio != 1:
|
||||
x = layers.Conv2D(
|
||||
filters,
|
||||
1,
|
||||
padding='same',
|
||||
use_bias=False,
|
||||
kernel_initializer=CONV_KERNEL_INITIALIZER,
|
||||
name=name + 'expand_conv')(
|
||||
inputs)
|
||||
x = layers.BatchNormalization(axis=bn_axis, name=name + 'expand_bn')(x)
|
||||
x = layers.Activation(activation, name=name + 'expand_activation')(x)
|
||||
else:
|
||||
x = inputs
|
||||
|
||||
# Depthwise Convolution
|
||||
if strides == 2:
|
||||
x = layers.ZeroPadding2D(
|
||||
padding=correct_pad(x, kernel_size), name=name + 'dwconv_pad')(
|
||||
x)
|
||||
conv_pad = 'valid'
|
||||
else:
|
||||
conv_pad = 'same'
|
||||
x = layers.DepthwiseConv2D(
|
||||
kernel_size,
|
||||
strides=strides,
|
||||
padding=conv_pad,
|
||||
use_bias=False,
|
||||
depthwise_initializer=CONV_KERNEL_INITIALIZER,
|
||||
name=name + 'dwconv')(
|
||||
x)
|
||||
x = layers.BatchNormalization(axis=bn_axis, name=name + 'bn')(x)
|
||||
x = layers.Activation(activation, name=name + 'activation')(x)
|
||||
|
||||
# Squeeze and Excitation phase
|
||||
if 0 < se_ratio <= 1:
|
||||
filters_se = max(1, int(filters_in * se_ratio))
|
||||
se = layers.GlobalAveragePooling2D(name=name + 'se_squeeze')(x)
|
||||
if bn_axis == 1:
|
||||
se_shape = (filters, 1, 1)
|
||||
else:
|
||||
se_shape = (1, 1, filters)
|
||||
se = layers.Reshape(se_shape, name=name + 'se_reshape')(se)
|
||||
se = layers.Conv2D(
|
||||
filters_se,
|
||||
1,
|
||||
padding='same',
|
||||
activation=activation,
|
||||
kernel_initializer=CONV_KERNEL_INITIALIZER,
|
||||
name=name + 'se_reduce')(
|
||||
se)
|
||||
se = layers.Conv2D(
|
||||
filters,
|
||||
1,
|
||||
padding='same',
|
||||
activation='sigmoid',
|
||||
kernel_initializer=CONV_KERNEL_INITIALIZER,
|
||||
name=name + 'se_expand')(
|
||||
se)
|
||||
x = layers.multiply([x, se], name=name + 'se_excite')
|
||||
|
||||
# Output phase
|
||||
x = layers.Conv2D(
|
||||
filters_out,
|
||||
1,
|
||||
padding='same',
|
||||
use_bias=False,
|
||||
kernel_initializer=CONV_KERNEL_INITIALIZER,
|
||||
name=name + 'project_conv')(
|
||||
x)
|
||||
x = layers.BatchNormalization(axis=bn_axis, name=name + 'project_bn')(x)
|
||||
if id_skip and strides == 1 and filters_in == filters_out:
|
||||
if drop_rate > 0:
|
||||
x = layers.Dropout(
|
||||
drop_rate, noise_shape=(None, 1, 1, 1), name=name + 'drop')(
|
||||
x)
|
||||
x = layers.add([x, inputs], name=name + 'add')
|
||||
return x
|
||||
|
||||
|
||||
def maybe_restore_with_film(
|
||||
*args,
|
||||
weights='imagenet',
|
||||
include_film=False,
|
||||
**kwargs,
|
||||
):
|
||||
n1 = EfficientNet(*args, weights=weights, include_film=False, **kwargs)
|
||||
if not include_film:
|
||||
return n1
|
||||
# Copy the model weights over to a new model. This is necessary
|
||||
# in case we have inserted early film layers. In this case,
|
||||
# the pretrained weights will fail to restore properly
|
||||
# unless we do this trick.
|
||||
n2 = EfficientNet(*args, weights=None, include_film=True, **kwargs)
|
||||
# The layers without the film layers.
|
||||
l1 = {l.name: l for l in n1.layers}
|
||||
# The layers with the film layers.
|
||||
l2 = {l.name: l for l in n2.layers}
|
||||
for layer_name, layer in l2.items():
|
||||
if layer_name in l1:
|
||||
layer.set_weights(l1[layer_name].get_weights())
|
||||
# Annoyingly, the rescaling and normalization layers get different names
|
||||
# in each graph.
|
||||
elif 'rescaling' in layer_name:
|
||||
_, num = layer_name.split('_')
|
||||
l1_layer_name = 'rescaling_' + str(int(num) - 2 or '')
|
||||
l1_layer_name = l1_layer_name.rstrip('_')
|
||||
layer.set_weights(l1[l1_layer_name].get_weights())
|
||||
elif 'normalization' in layer_name:
|
||||
_, num = layer_name.split('_')
|
||||
l1_layer_name = 'normalization_' + str(int(num) - 1 or '')
|
||||
l1_layer_name = l1_layer_name.rstrip('_')
|
||||
layer.set_weights(l1[l1_layer_name].get_weights())
|
||||
return n2
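# Minimal usage sketch (args mirror EfficientNetB3 below; not part of the
# original code):
#   net = maybe_restore_with_film(
#       1.2, 1.4, 300, 0.3,
#       model_name='efficientnetb3',
#       weights='imagenet',
#       include_top=False,
#       include_film=True)
# This restores pretrained weights into a film-less model first, then copies
# them layer-by-layer into a second model that contains the film layers.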
|
||||
|
||||
|
||||
def EfficientNetB3(include_top=True,
|
||||
weights='imagenet',
|
||||
input_tensor=None,
|
||||
input_shape=None,
|
||||
pooling=None,
|
||||
classes=1000,
|
||||
classifier_activation='softmax',
|
||||
include_film=False,
|
||||
**kwargs):
|
||||
return maybe_restore_with_film(
|
||||
1.2,
|
||||
1.4,
|
||||
300,
|
||||
0.3,
|
||||
model_name='efficientnetb3',
|
||||
include_top=include_top,
|
||||
weights=weights,
|
||||
input_tensor=input_tensor,
|
||||
input_shape=input_shape,
|
||||
pooling=pooling,
|
||||
classes=classes,
|
||||
classifier_activation=classifier_activation,
|
||||
include_film=include_film,
|
||||
**kwargs)
|
||||
|
||||
|
||||
EfficientNetB3.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB3')
|
||||
|
||||
|
||||
def preprocess_input(x, data_format=None): # pylint: disable=unused-argument
|
||||
"""A placeholder method for backward compatibility.
|
||||
|
||||
The preprocessing logic has been included in the efficientnet model
|
||||
implementation. Users are no longer required to call this method to normalize
|
||||
the input data. This method does nothing and is kept only as a placeholder to
|
||||
align the API surface between the old and new versions of the model.
|
||||
|
||||
Args:
|
||||
x: A floating point `numpy.array` or a `tf.Tensor`.
|
||||
data_format: Optional data format of the image tensor/array. Defaults to
|
||||
None, in which case the global setting `tf.keras.image_data_format()` is
|
||||
used (unless you changed it, it defaults to "channels_last").
|
||||
|
||||
Returns:
|
||||
Unchanged `numpy.array` or `tf.Tensor`.
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
def decode_predictions(preds, top=5):
|
||||
global CLASS_INDEX
|
||||
if CLASS_INDEX is None:
|
||||
with open(os.path.join(os.path.dirname(__file__), IMAGENET_JSON_PATH)) as f:
|
||||
CLASS_INDEX = json.load(f)
|
||||
results = []
|
||||
for pred in preds:
|
||||
top_indices = pred.argsort()[-top:][::-1]
|
||||
result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices]
|
||||
result.sort(key=lambda x: x[2], reverse=True)
|
||||
results.append(result)
|
||||
return results
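# Hedged usage sketch (assumes an ImageNet classification head; not part of
# the original code):
#   preds = model(images).numpy()  # [batch, 1000] class scores.
#   for class_id, name, score in decode_predictions(preds, top=5)[0]:
#     print(name, score)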
|
|
@@ -0,0 +1,73 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests that film_efficientnet can detect an image of a cat."""
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from robotics_transformer.film_efficientnet import film_efficientnet_encoder
|
||||
from skimage import data
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class FilmEfficientnetTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def _helper(self, include_film, model_variant):
|
||||
if model_variant == 'b0':
|
||||
size = 224
|
||||
fe = film_efficientnet_encoder.EfficientNetB0
|
||||
elif model_variant == 'b1':
|
||||
size = 240
|
||||
fe = film_efficientnet_encoder.EfficientNetB1
|
||||
elif model_variant == 'b2':
|
||||
size = 260
|
||||
fe = film_efficientnet_encoder.EfficientNetB2
|
||||
elif model_variant == 'b3':
|
||||
size = 300
|
||||
fe = film_efficientnet_encoder.EfficientNetB3
|
||||
elif model_variant == 'b4':
|
||||
size = 380
|
||||
fe = film_efficientnet_encoder.EfficientNetB4
|
||||
elif model_variant == 'b5':
|
||||
size = 456
|
||||
fe = film_efficientnet_encoder.EfficientNetB5
|
||||
elif model_variant == 'b6':
|
||||
size = 528
|
||||
fe = film_efficientnet_encoder.EfficientNetB6
|
||||
elif model_variant == 'b7':
|
||||
size = 600
|
||||
fe = film_efficientnet_encoder.EfficientNetB7
|
||||
else:
|
||||
raise ValueError(f'Unknown variant: {model_variant}')
|
||||
fe = fe(include_top=True, weights='imagenet', include_film=include_film)
|
||||
image = np.expand_dims(data.chelsea(), axis=0)
|
||||
image = tf.image.resize(image, (size, size))
|
||||
context = np.random.randn(1, 512)
|
||||
if include_film:
|
||||
eff_output = fe(
|
||||
(film_efficientnet_encoder.preprocess_input(image), context),
|
||||
training=False)
|
||||
else:
|
||||
eff_output = fe(
|
||||
film_efficientnet_encoder.preprocess_input(image), training=False)
|
||||
film_preds = film_efficientnet_encoder.decode_predictions(
|
||||
eff_output.numpy(), top=10)
|
||||
self.assertIn('tabby', [f[1] for f in film_preds[0]])
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_keras_equivalence_b3(self, include_film):
|
||||
self._helper(include_film, 'b3')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@@ -0,0 +1,108 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Preprocessing functions for transforming the image for training."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import gin
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
CROP_SIZE = 472
|
||||
|
||||
|
||||
@gin.configurable(
|
||||
denylist=['images', 'crop_size', 'training', 'convert_dtype', 'seed'])
|
||||
def convert_dtype_and_crop_images(images,
|
||||
crop_size: int = CROP_SIZE,
|
||||
training: bool = True,
|
||||
pad_then_crop: bool = False,
|
||||
convert_dtype: bool = True,
|
||||
seed: Optional[tf.Tensor] = None):
|
||||
"""Convert uint8 [512, 640, 3] images to float32 and square crop.
|
||||
|
||||
Args:
|
||||
images: [B, H, W, 3] uint8 tensor of images.
|
||||
crop_size: Width of the square crop.
|
||||
training: Whether we are training (random crop) or evaluating (fixed center
|
||||
crop). pad_then_crop: If True, pads the image and then crops to the original
|
||||
image size. This allows the full field of view to be retained.
|
||||
convert_dtype: whether or not to convert the image to float32 in the range
|
||||
of (0, 1).
|
||||
seed: Optional seed of shape (2,) to pass to tf.random.stateless_uniform.
|
||||
|
||||
Returns:
|
||||
[B, crop_size, crop_size, 3] images of dtype float32.
|
||||
"""
|
||||
|
||||
if seed is None:
|
||||
seed = tf.random.uniform(shape=(2,), maxval=2**30, dtype=tf.int32)
|
||||
|
||||
seed2 = tf.random.experimental.stateless_split(seed, num=1)[0]
|
||||
|
||||
if convert_dtype:
|
||||
images = tf.image.convert_image_dtype(images, tf.float32)
|
||||
image_height = images.get_shape().as_list()[-3]
|
||||
image_width = images.get_shape().as_list()[-2]
|
||||
|
||||
if pad_then_crop:
|
||||
|
||||
if training:
|
||||
if image_height == 512:
|
||||
ud_pad = 40
|
||||
lr_pad = 100
|
||||
elif image_height == 256:
|
||||
ud_pad = 20
|
||||
lr_pad = 50
|
||||
else:
|
||||
raise ValueError(
|
||||
'convert_dtype_and_crop_images only supports image height 512 or '
|
||||
'256.')
|
||||
max_y = 2 * ud_pad
|
||||
max_x = 2 * lr_pad
|
||||
images = tf.image.pad_to_bounding_box(
|
||||
images,
|
||||
offset_height=ud_pad,
|
||||
offset_width=lr_pad,
|
||||
target_height=image_height + 2 * ud_pad,
|
||||
target_width=image_width + 2 * lr_pad)
|
||||
offset_y = tf.random.stateless_uniform((),
|
||||
maxval=max_y + 1,
|
||||
dtype=tf.int32,
|
||||
seed=seed)
|
||||
offset_x = tf.random.stateless_uniform((),
|
||||
maxval=max_x + 1,
|
||||
dtype=tf.int32,
|
||||
seed=seed2)
|
||||
images = tf.image.crop_to_bounding_box(images, offset_y, offset_x,
|
||||
image_height, image_width)
|
||||
else:
|
||||
# Standard cropping.
|
||||
max_y = image_height - crop_size
|
||||
max_x = image_width - crop_size
|
||||
|
||||
if training:
|
||||
offset_y = tf.random.stateless_uniform((),
|
||||
maxval=max_y + 1,
|
||||
dtype=tf.int32,
|
||||
seed=seed)
|
||||
offset_x = tf.random.stateless_uniform((),
|
||||
maxval=max_x + 1,
|
||||
dtype=tf.int32,
|
||||
seed=seed2)
|
||||
images = tf.image.crop_to_bounding_box(images, offset_y, offset_x,
|
||||
crop_size, crop_size)
|
||||
else:
|
||||
images = tf.image.crop_to_bounding_box(images, max_y // 2, max_x // 2,
|
||||
crop_size, crop_size)
|
||||
return images
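# Minimal usage sketch (batch shape assumed; not part of the original code):
#   images = tf.zeros([2, 512, 640, 3], tf.uint8)
#   crops = convert_dtype_and_crop_images(images, training=True)
#   # -> float32 [2, 472, 472, 3]; with pad_then_crop=True the original
#   # [2, 512, 640, 3] spatial size is kept instead.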
|
|
@@ -0,0 +1,83 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for preprocessors."""
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from robotics_transformer.film_efficientnet import preprocessors
|
||||
from tensor2robot.utils import tensorspec_utils
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
|
||||
def _random_image(shape):
|
||||
images = tf.random.uniform(
|
||||
shape, minval=0, maxval=255, dtype=tf.dtypes.int32, seed=42)
|
||||
return tf.cast(images, tf.uint8)
|
||||
|
||||
|
||||
def _get_features(
|
||||
image_shape=(2, 512, 640, 3), use_task_image=False, use_goal_image=False):
|
||||
# Time-dimension stacking occurs during training but not eval.
|
||||
state = tensorspec_utils.TensorSpecStruct(image=_random_image(image_shape))
|
||||
if use_task_image:
|
||||
state.task_image = _random_image(image_shape)
|
||||
if use_goal_image:
|
||||
state.goal_image = _random_image(image_shape)
|
||||
return state
|
||||
|
||||
|
||||
class PreprocessorsTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters((True, False, False), (False, True, False),
|
||||
(True, False, True), (False, True, True))
|
||||
def testConvertDtypeAndCropImages(self, training, pad_then_crop,
|
||||
convert_dtype):
|
||||
features = _get_features()
|
||||
images = preprocessors.convert_dtype_and_crop_images(
|
||||
features.image,
|
||||
training=training,
|
||||
pad_then_crop=pad_then_crop,
|
||||
convert_dtype=convert_dtype)
|
||||
expected_cropped_shape = ([2, 512, 640, 3]
|
||||
if pad_then_crop else [2, 472, 472, 3])
|
||||
tf.ensure_shape(images, expected_cropped_shape)
|
||||
if convert_dtype:
|
||||
self.assertEqual(images.dtype, tf.float32)
|
||||
self.assertLessEqual(images.numpy().max(), 1.)
|
||||
self.assertGreaterEqual(images.numpy().min(), 0.)
|
||||
else:
|
||||
self.assertEqual(images.dtype, tf.uint8)
|
||||
self.assertLessEqual(images.numpy().max(), 255)
|
||||
self.assertGreaterEqual(images.numpy().min(), 0)
|
||||
self.assertGreater(images.numpy().max(), 1)
|
||||
|
||||
def testConvertDtypeAndCropImagesSeeded(self):
|
||||
features = _get_features()
|
||||
seed = tf.constant([1, 2], tf.int32)
|
||||
images1 = preprocessors.convert_dtype_and_crop_images(
|
||||
features.image, training=True, pad_then_crop=True, seed=seed)
|
||||
images2 = preprocessors.convert_dtype_and_crop_images(
|
||||
features.image, training=True, pad_then_crop=True, seed=seed)
|
||||
diff = np.sum(np.abs(images1.numpy() - images2.numpy()))
|
||||
self.assertAlmostEqual(diff, 0)
|
||||
|
||||
def testConvertDtypeAndCropImagesUnseeded(self):
|
||||
features = _get_features()
|
||||
seed1 = tf.constant([1, 2], tf.int32)
|
||||
images1 = preprocessors.convert_dtype_and_crop_images(
|
||||
features.image, training=True, pad_then_crop=True, seed=seed1)
|
||||
seed2 = tf.constant([2, 3], tf.int32)
|
||||
images2 = preprocessors.convert_dtype_and_crop_images(
|
||||
features.image, training=True, pad_then_crop=True, seed=seed2)
|
||||
diff = np.sum(np.abs(images1.numpy() - images2.numpy()))
|
||||
self.assertNotAlmostEqual(diff, 0)
|
|
@@ -0,0 +1,122 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Encoder based on Efficientnet."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import gin
|
||||
from robotics_transformer.film_efficientnet import film_conditioning_layer
|
||||
from robotics_transformer.film_efficientnet import film_efficientnet_encoder
|
||||
import tensorflow as tf
|
||||
|
||||
_MODELS = {
|
||||
'b3': film_efficientnet_encoder.EfficientNetB3,
|
||||
}
|
||||
|
||||
_SIZES = {
|
||||
'b3': 300,
|
||||
}
|
||||
|
||||
|
||||
@gin.configurable
|
||||
class EfficientNetEncoder(tf.keras.layers.Layer):
|
||||
"""Applies a pretrained Efficientnet based encoder."""
|
||||
|
||||
def __init__(self,
|
||||
model_variant: str = 'b3',
|
||||
freeze: bool = False,
|
||||
early_film: bool = True,
|
||||
weights: Optional[str] = 'imagenet',
|
||||
include_top: bool = False,
|
||||
pooling: bool = True,
|
||||
**kwargs):
|
||||
"""Initialize the model.
|
||||
|
||||
Args:
|
||||
model_variant: One of 'b0-b7' of the efficient encoders. See
|
||||
https://arxiv.org/abs/1905.11946 to understand the variants.
|
||||
freeze: Whether or not to freeze the pretrained weights (seems to not work
|
||||
well).
|
||||
early_film: Whether to inject film layers into the efficientnet encoder
|
||||
(seems to be essential to getting strong performance).
|
||||
weights: Which pretrained weights to use. Either 'imagenet', a path to the
|
||||
pretrained weights, or None for from scratch.
|
||||
include_top: Whether to add the top fully connected layer. If True, this
|
||||
will cause encoding to fail and is used only for unit testing purposes.
|
||||
pooling: If False, returns the feature map before global average pooling.
|
||||
**kwargs: Keras specific layer kwargs.
|
||||
"""
|
||||
super(EfficientNetEncoder, self).__init__(**kwargs)
|
||||
if model_variant not in _MODELS:
|
||||
raise ValueError(f'Unknown variant {model_variant}')
|
||||
self.model_variant = model_variant
|
||||
self.early_film = early_film
|
||||
self.freeze = freeze
|
||||
self.conv1x1 = tf.keras.layers.Conv2D(
|
||||
filters=512,
|
||||
kernel_size=(1, 1),
|
||||
strides=(1, 1),
|
||||
padding='SAME',
|
||||
use_bias=False,
|
||||
kernel_initializer=tf.keras.initializers.VarianceScaling())
|
||||
self.net = _MODELS[model_variant](
|
||||
include_top=include_top,
|
||||
weights=weights,
|
||||
include_film=early_film,
|
||||
)
|
||||
self.film_layer = film_conditioning_layer.FilmConditioning(num_channels=512)
|
||||
self._pooling = pooling
|
||||
|
||||
def _prepare_image(self, image: tf.Tensor) -> tf.Tensor:
|
||||
"""Resize the input image and check that the range is correct."""
|
||||
if len(image.shape) != 4 or image.shape[-1] != 3:
|
||||
raise ValueError('Provided image should have shape (b, h, w, 3).')
|
||||
size = _SIZES[self.model_variant]
|
||||
if image.shape[1] < size / 4 or image.shape[2] < size / 4:
|
||||
raise ValueError('Provided image is too small.')
|
||||
if image.shape[1] > size * 4 or image.shape[2] > size * 4:
|
||||
raise ValueError('Provided image is too large.')
|
||||
image = tf.image.resize(image, (size, size))
|
||||
c1 = tf.Assert(tf.reduce_max(image) <= 1, data=[tf.reduce_max(image)])
|
||||
c2 = tf.Assert(tf.reduce_min(image) >= 0, data=[tf.reduce_min(image)])
|
||||
with tf.control_dependencies([c1, c2]):
|
||||
image *= 255 # The image is expected to be in range(0, 255).
|
||||
image = film_efficientnet_encoder.preprocess_input(image)
|
||||
return image
|
||||
|
||||
def _encode(self, image: tf.Tensor, context: tf.Tensor,
|
||||
training: bool) -> tf.Tensor:
|
||||
"""Run the image through the efficientnet encoder."""
|
||||
image = self._prepare_image(image)
|
||||
if self.early_film:
|
||||
return self.net((image, context), training=training)
|
||||
return self.net(image, training=training)
|
||||
|
||||
def call(self,
|
||||
image: tf.Tensor,
|
||||
context: Optional[tf.Tensor] = None,
|
||||
training: bool = True) -> tf.Tensor:
|
||||
if self.freeze:
|
||||
features = tf.stop_gradient(self._encode(image, context, training))
|
||||
else:
|
||||
features = self._encode(image, context, training)
|
||||
if context is not None:
|
||||
features = self.conv1x1(features)
|
||||
features = self.film_layer(features, context)
|
||||
|
||||
if not self._pooling:
|
||||
return features
|
||||
|
||||
# Global average pool.
|
||||
return tf.reduce_mean(features, [1, 2])
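# Hedged usage sketch (input ranges as checked in _prepare_image; not part of
# the original code):
#   encoder = EfficientNetEncoder()
#   image = tf.random.uniform((1, 300, 300, 3))      # values in [0, 1).
#   context = tf.random.uniform((1, 512), -1., 1.)   # e.g. a text embedding.
#   features = encoder(image, context, training=False)  # -> [1, 512].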
|
|
@@ -0,0 +1,49 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for pretrained_efficientnet_encoder."""
|
||||
|
||||
import numpy as np
|
||||
from robotics_transformer.film_efficientnet import film_efficientnet_encoder
|
||||
from robotics_transformer.film_efficientnet import pretrained_efficientnet_encoder as eff
|
||||
from skimage import data
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class PretrainedEfficientnetEncoderTest(tf.test.TestCase):
|
||||
|
||||
def test_encoding(self):
|
||||
"""Test that we get a correctly shaped decoding."""
|
||||
state = np.random.RandomState(0)
|
||||
context = state.uniform(-1, 1, (10, 512))
|
||||
model = eff.EfficientNetEncoder()
|
||||
image = np.expand_dims(data.chelsea(), axis=0) / 255
|
||||
preds = model(image, context, training=False).numpy()
|
||||
self.assertEqual(preds.shape, (10, 512))
|
||||
|
||||
def test_imagenet_classification(self):
|
||||
"""Test that we can correctly classify an image of a cat."""
|
||||
state = np.random.RandomState(0)
|
||||
context = state.uniform(-1, 1, (10, 512))
|
||||
model = eff.EfficientNetEncoder(include_top=True)
|
||||
image = np.expand_dims(data.chelsea(), axis=0) / 255
|
||||
preds = model._encode(image, context, training=False).numpy()
|
||||
predicted_names = [
|
||||
n[1]
|
||||
for n in film_efficientnet_encoder.decode_predictions(preds, top=3)[0]
|
||||
]
|
||||
self.assertIn('tabby', predicted_names)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@@ -0,0 +1,44 @@
|
|||
[project]
|
||||
name = "robotics_transformer"
|
||||
description = ""
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.7"
|
||||
license = {file = "LICENSE"}
|
||||
authors = [{name = "robotics_transformer authors", email="robotics_transformer@google.com"}]
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Intended Audience :: Science/Research",
|
||||
]
|
||||
keywords = []
|
||||
|
||||
# pip dependencies of the project
|
||||
dependencies = []
|
||||
|
||||
# This is set automatically by flit using `robotics_transformer.__version__`
|
||||
dynamic = ["version"]
|
||||
|
||||
[project.urls]
|
||||
homepage = "https://github.com/google-research/robotics_transformer"
|
||||
repository = "https://github.com/google-research/robotics_transformer"
|
||||
# Other: `documentation`, `changelog`
|
||||
|
||||
[project.optional-dependencies]
|
||||
# Development deps (unittest, linting, formatting, ...)
|
||||
# Installed through `pip install .[dev]`
|
||||
dev = [
|
||||
"pytest",
|
||||
"pytest-xdist",
|
||||
"pylint>=2.6.0",
|
||||
"pyink",
|
||||
]
|
||||
|
||||
[tool.pyink]
|
||||
# Formatting configuration to follow Google style-guide
|
||||
pyink-indentation = 2
|
||||
pyink-use-majority-quotes = true
|
||||
|
||||
[build-system]
|
||||
requires = ["flit_core >=3.5,<4"]
|
||||
build-backend = "flit_core.buildapi"
|
|
@@ -0,0 +1,9 @@
|
|||
absl-py>=0.5.0
|
||||
numpy>=1.13.3
|
||||
tensorflow>=1.13.0
|
||||
tensorflow-serving-api>=1.13.0
|
||||
gin-config>=0.1.4
|
||||
tensorflow-probability>=0.6.0
|
||||
tf-agents>=0.3.0
|
||||
tf-slim>=1.0
|
||||
git+https://github.com/google-research/tensor2robot#tensor2robot
|
|
@@ -0,0 +1,19 @@
|
|||
include "devtools/blueprint/bluze/public/bluze.ncl";
|
||||
include bytes "third_party/py/robotics_transformer/bluze.textproto" as textproto;
|
||||
|
||||
// See go/bluze/guide before editing. To check the generated final blueprint run
|
||||
// rncl third_party/py/robotics_transformer/robotics_transformer.blueprint printproto blueprint_file
|
||||
|
||||
blueprint_file = ::bluze::BlueprintFile(
|
||||
textproto,
|
||||
|
||||
project_name = "robotics_transformer",
|
||||
teams_product_id = 9019942154,
|
||||
tech_lead = ["keerthanapg"],
|
||||
dev_mailing_list = "robotics_transformer-automated@google.com",
|
||||
mdb_groups = ["robotics"],
|
||||
buganizer_component_ids = [1150225],
|
||||
metadata_path = "//depot/google3/third_party/py/robotics_transformer/METADATA",
|
||||
|
||||
// Customize your blueprint here: go/blueprint/howto-write.
|
||||
);
|
|
@@ -0,0 +1,173 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Sequence policy and agent that directly output actions via actor network.
|
||||
|
||||
These classes are not intended to change as they are generic enough for any
|
||||
all-neural actor based agent+policy. All new features are intended to be
|
||||
implemented in `actor_network` and `loss_fn`.
|
||||
|
||||
# TODO(b/231896343): Update litred docs on how to use these.
|
||||
"""
|
||||
from typing import Optional, Type
|
||||
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
from tf_agents.agents import data_converter
|
||||
from tf_agents.agents import tf_agent
|
||||
from tf_agents.networks import network
|
||||
from tf_agents.policies import actor_policy
|
||||
from tf_agents.trajectories import policy_step
|
||||
from tf_agents.trajectories import time_step as ts
|
||||
from tf_agents.typing import types
|
||||
from tf_agents.utils import nest_utils
|
||||
|
||||
|
||||
class SequencePolicy(actor_policy.ActorPolicy):
|
||||
"""A policy that directly outputs actions via an actor network."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._actions = None
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_actions(self, actions):
|
||||
self._actor_network.set_actions(actions)
|
||||
|
||||
def get_actor_loss(self):
|
||||
return self._actor_network.get_actor_loss()
|
||||
|
||||
def get_aux_info(self):
|
||||
return self._actor_network.get_aux_info()
|
||||
|
||||
def set_training(self, training):
|
||||
self._training = training
|
||||
|
||||
def _action(self,
|
||||
time_step: ts.TimeStep,
|
||||
policy_state: types.NestedTensor,
|
||||
seed: Optional[types.Seed] = None) -> policy_step.PolicyStep:
|
||||
del seed
|
||||
action, policy_state = self._apply_actor_network(
|
||||
time_step.observation,
|
||||
step_type=time_step.step_type,
|
||||
policy_state=policy_state)
|
||||
info = ()
|
||||
return policy_step.PolicyStep(action, policy_state, info)
|
||||
|
||||
def _distribution(self, time_step, policy_state):
|
||||
current_step = super()._distribution(time_step, policy_state)
|
||||
return current_step
|
||||
|
||||
|
||||
class SequenceAgent(tf_agent.TFAgent):
|
||||
"""A sequence agent that directly outputs actions via an actor network."""
|
||||
|
||||
def __init__(self,
|
||||
time_step_spec: ts.TimeStep,
|
||||
action_spec: types.NestedTensorSpec,
|
||||
actor_network: Type[network.Network],
|
||||
actor_optimizer: tf.keras.optimizers.Optimizer,
|
||||
policy_cls: Type[actor_policy.ActorPolicy] = SequencePolicy,
|
||||
time_sequence_length: int = 6,
|
||||
debug_summaries: bool = False,
|
||||
**kwargs):
|
||||
self._info_spec = ()
|
||||
self._actor_network = actor_network( # pytype: disable=missing-parameter # dynamic-method-lookup
|
||||
input_tensor_spec=time_step_spec.observation,
|
||||
output_tensor_spec=action_spec,
|
||||
policy_info_spec=self._info_spec,
|
||||
train_step_counter=kwargs['train_step_counter'],
|
||||
time_sequence_length=time_sequence_length)
|
||||
|
||||
self._actor_optimizer = actor_optimizer
|
||||
# Train policy is only used for loss and never exported as saved_model.
|
||||
self._train_policy = policy_cls(
|
||||
time_step_spec=time_step_spec,
|
||||
action_spec=action_spec,
|
||||
info_spec=self._info_spec,
|
||||
actor_network=self._actor_network,
|
||||
training=True)
|
||||
collect_policy = policy_cls(
|
||||
time_step_spec=time_step_spec,
|
||||
action_spec=action_spec,
|
||||
info_spec=self._info_spec,
|
||||
actor_network=self._actor_network,
|
||||
training=False)
|
||||
super(SequenceAgent, self).__init__(
|
||||
time_step_spec,
|
||||
action_spec,
|
||||
collect_policy, # We use the collect_policy as the eval policy.
|
||||
collect_policy,
|
||||
train_sequence_length=time_sequence_length,
|
||||
**kwargs)
|
||||
self._data_context = data_converter.DataContext(
|
||||
time_step_spec=time_step_spec,
|
||||
action_spec=action_spec,
|
||||
info_spec=collect_policy.info_spec,
|
||||
use_half_transition=True)
|
||||
self.as_transition = data_converter.AsHalfTransition(
|
||||
self._data_context, squeeze_time_dim=False)
|
||||
self._debug_summaries = debug_summaries
|
||||
|
||||
num_params = 0
|
||||
for weight in self._actor_network.trainable_weights:
|
||||
weight_params = 1
|
||||
for dim in weight.shape:
|
||||
weight_params *= dim
|
||||
logging.info('%s has %s params.', weight.name, weight_params)
|
||||
num_params += weight_params
|
||||
logging.info('Actor network has %sM params.', round(num_params / 1000000.,
|
||||
2))
|
||||
|
||||
def _train(self, experience: types.NestedTensor,
|
||||
weights: types.Tensor) -> tf_agent.LossInfo:
|
||||
self.train_step_counter.assign_add(1)
|
||||
loss_info = self._loss(experience, weights, training=True)
|
||||
self._apply_gradients(loss_info.loss)
|
||||
return loss_info
|
||||
|
||||
def _apply_gradients(self, loss: types.Tensor):
|
||||
variables = self._actor_network.trainable_weights
|
||||
gradients = tf.gradients(loss, variables)
|
||||
# Skip nan and inf gradients.
|
||||
new_gradients = []
|
||||
for g in gradients:
|
||||
if g is not None:
|
||||
new_g = tf.where(
|
||||
tf.math.logical_or(tf.math.is_inf(g), tf.math.is_nan(g)),
|
||||
tf.zeros_like(g), g)
|
||||
new_gradients.append(new_g)
|
||||
else:
|
||||
new_gradients.append(g)
|
||||
grads_and_vars = list(zip(new_gradients, variables))
|
||||
self._actor_optimizer.apply_gradients(grads_and_vars)
|
||||
|
||||
def _loss(self, experience: types.NestedTensor, weights: types.Tensor,
|
||||
training: bool) -> tf_agent.LossInfo:
|
||||
transition = self.as_transition(experience)
|
||||
time_steps, policy_steps, _ = transition
|
||||
batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
|
||||
policy = self._train_policy
|
||||
policy.set_actions(policy_steps.action)
|
||||
policy.set_training(training=training)
|
||||
with tf.name_scope('actor_loss'):
|
||||
policy_state = policy.get_initial_state(batch_size)
|
||||
policy.action(time_steps, policy_state=policy_state)
|
||||
valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
|
||||
loss = valid_mask * policy.get_actor_loss()
|
||||
loss = tf.reduce_mean(loss)
|
||||
policy.set_actions(None)
|
||||
self._actor_network.add_summaries(time_steps.observation,
|
||||
policy.get_aux_info(),
|
||||
self._debug_summaries, training)
|
||||
return tf_agent.LossInfo(loss=loss, extra=loss)
|
|
@@ -0,0 +1,28 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for sequence_agent."""
|
||||
from robotics_transformer.sequence_agent_test_set_up import SequenceAgentTestSetUp
|
||||
import tensorflow as tf
|
||||
from tf_agents.agents import data_converter
|
||||
|
||||
|
||||
class SequenceAgentTest(SequenceAgentTestSetUp):
|
||||
|
||||
def testAsTransitionType(self):
|
||||
agent = self.create_agent_and_initialize()
|
||||
self.assertIsInstance(agent.as_transition, data_converter.AsHalfTransition)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@@ -0,0 +1,144 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for sequence_agent."""
|
||||
from typing import Type
|
||||
|
||||
import numpy as np
|
||||
from robotics_transformer import sequence_agent
|
||||
from tensor2robot.utils import tensorspec_utils
|
||||
import tensorflow as tf
|
||||
from tf_agents.networks import network
|
||||
from tf_agents.policies import policy_saver
|
||||
from tf_agents.specs import tensor_spec
|
||||
from tf_agents.trajectories import time_step as ts
|
||||
|
||||
|
||||
class DummyActorNet(network.Network):
|
||||
"""Used for testing SequenceAgent and its subclass."""
|
||||
|
||||
def __init__(self,
|
||||
output_tensor_spec=None,
|
||||
train_step_counter=None,
|
||||
policy_info_spec=None,
|
||||
time_sequence_length=1,
|
||||
use_tcl=False,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def tokens_per_action(self):
|
||||
return 8
|
||||
|
||||
def set_actions(self, actions):
|
||||
self._actions = actions
|
||||
|
||||
def get_actor_loss(self):
|
||||
return self._actor_loss
|
||||
|
||||
def call(self,
|
||||
observations,
|
||||
step_type,
|
||||
network_state,
|
||||
actions=None,
|
||||
training=False):
|
||||
del step_type
|
||||
image = observations['image']
|
||||
tf.expand_dims(tf.reduce_mean(image, axis=-1), -1)
|
||||
actions = tensorspec_utils.TensorSpecStruct(
|
||||
world_vector=tf.constant(1., shape=[1, 3]),
|
||||
rotation_delta=tf.constant(1., shape=[1, 3]),
|
||||
terminate_episode=tf.constant(1, shape=[1, 2]),
|
||||
gripper_closedness_action=tf.constant(1., shape=[1, 1]),
|
||||
)
|
||||
return actions, network_state
|
||||
|
||||
@property
|
||||
def trainable_weights(self):
|
||||
return [tf.Variable(1.0)]
|
||||
|
||||
|
||||
class SequenceAgentTestSetUp(tf.test.TestCase):
|
||||
"""Defines spec for testing SequenceAgent and its subclass, tests create."""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._action_spec = tensorspec_utils.TensorSpecStruct()
|
||||
self._action_spec.world_vector = tensor_spec.BoundedTensorSpec(
|
||||
(3,), dtype=tf.float32, minimum=-1., maximum=1., name='world_vector')
|
||||
|
||||
self._action_spec.rotation_delta = tensor_spec.BoundedTensorSpec(
|
||||
(3,),
|
||||
dtype=tf.float32,
|
||||
minimum=-np.pi / 2,
|
||||
maximum=np.pi / 2,
|
||||
name='rotation_delta')
|
||||
|
||||
self._action_spec.gripper_closedness_action = tensor_spec.BoundedTensorSpec(
|
||||
(1,),
|
||||
dtype=tf.float32,
|
||||
minimum=-1.,
|
||||
maximum=1.,
|
||||
name='gripper_closedness_action')
|
||||
self._action_spec.terminate_episode = tensor_spec.BoundedTensorSpec(
|
||||
(2,), dtype=tf.int32, minimum=0, maximum=1, name='terminate_episode')
|
||||
|
||||
state_spec = tensorspec_utils.TensorSpecStruct()
|
||||
state_spec.image = tensor_spec.BoundedTensorSpec([256, 320, 3],
|
||||
dtype=tf.float32,
|
||||
name='image',
|
||||
minimum=0.,
|
||||
maximum=1.)
|
||||
state_spec.natural_language_embedding = tensor_spec.TensorSpec(
|
||||
shape=[512], dtype=tf.float32, name='natural_language_embedding')
|
||||
self._time_step_spec = ts.time_step_spec(observation_spec=state_spec)
|
||||
|
||||
self.sequence_agent_cls = sequence_agent.SequenceAgent
|
||||
|
||||
def create_agent_and_initialize(self,
|
||||
actor_network: Type[
|
||||
network.Network] = DummyActorNet,
|
||||
**kwargs):
|
||||
"""Creates the agent and initialize it."""
|
||||
agent = self.sequence_agent_cls(
|
||||
time_step_spec=self._time_step_spec,
|
||||
action_spec=self._action_spec,
|
||||
actor_network=actor_network,
|
||||
actor_optimizer=tf.keras.optimizers.Adam(),
|
||||
train_step_counter=tf.compat.v1.train.get_or_create_global_step(),
|
||||
**kwargs)
|
||||
agent.initialize()
|
||||
return agent
|
||||
|
||||
def testCreateAgent(self):
|
||||
"""Creates the Agent and save the agent.policy."""
|
||||
agent = self.create_agent_and_initialize()
|
||||
self.assertIsNotNone(agent.policy)
|
||||
|
||||
policy_model_saver = policy_saver.PolicySaver(
|
||||
agent.policy,
|
||||
train_step=tf.compat.v2.Variable(
|
||||
0,
|
||||
trainable=False,
|
||||
dtype=tf.int64,
|
||||
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
|
||||
shape=()),
|
||||
input_fn_and_spec=None)
|
||||
save_options = tf.saved_model.SaveOptions(
|
||||
experimental_io_device='/job:localhost',
|
||||
experimental_custom_gradients=False)
|
||||
policy_model_saver.save('/tmp/unittest/policy/0', options=save_options)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@@ -0,0 +1,13 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@@ -0,0 +1,157 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A simple action tokenizer used with Robotics Transformer 1.
|
||||
|
||||
As an example, if an action is:
|
||||
terminate = [0, 1]
|
||||
world_vector = [0.9, 0.8, -0.3]
|
||||
rotation_delta = [-0.1, 0.2, .6]
|
||||
gripper_closedness = 0.9
|
||||
|
||||
Then we build a sequence of tokens of length 8 [one for each dimension].
|
||||
The int32 type action dimensions are already assumed discrete and tokenized,
|
||||
the float dimensions are bucketed according to the specs min and max. Each
|
||||
dimension has 'vocab_size' buckets.
|
||||
|
||||
Currently, this tokenizer assumes one action spec and it is highly recommended
|
||||
to specify the 'action_order', eg [terminate, world_vector, rotation_delta,
|
||||
gripper_closedness]. Since after tokenization you lose that information, this
|
||||
will be useful for debugging. Actions may also be subselected for prediction,
|
||||
since not all actions are needed in the action_order.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from tensor2robot.utils import tensorspec_utils
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class RT1ActionTokenizer:
|
||||
"""Tokenizes based on vocab size."""
|
||||
|
||||
def __init__(self,
|
||||
action_spec: tensorspec_utils.TensorSpecStruct,
|
||||
vocab_size: int,
|
||||
action_order: Optional[list[str]] = None):
|
||||
"""Instantiates an RT1ActionTokenizer.
|
||||
|
||||
Args:
|
||||
action_spec: Tensor spec of the expected action tensor.
|
||||
vocab_size: Number of buckets to discretize action to.
|
||||
action_order: Order of the action names, used to discern the order of
|
||||
tokenized actions to detokenize and assemble back into an action tensor.
|
||||
"""
|
||||
self._action_spec = action_spec
|
||||
self._vocab_size = vocab_size
|
||||
if action_order is None:
|
||||
self._action_order = self._action_spec.keys()
|
||||
else:
|
||||
for action in action_order:
|
||||
if action not in self._action_spec.keys():
|
||||
raise ValueError('actions: %s not found in action_spec: %s' %
|
||||
(action, action_spec.keys()))
|
||||
assert action in self._action_spec.keys()
|
||||
self._action_order = action_order
|
||||
self._tokens_per_action = 0
|
||||
for action in self._action_order:
|
||||
action_shape = self._action_spec[action].shape
|
||||
if len(action_shape) != 1:
|
||||
raise ValueError(
|
||||
'Only action shapes with single dimension supported, got %s' %
|
||||
action_shape)
|
||||
if self._action_spec[action].dtype == tf.int32:
|
||||
# Int32 actions are already assumed to be tokens.
|
||||
self._tokens_per_action += 1
|
||||
else:
|
||||
self._tokens_per_action += action_shape[0]
|
||||
|
||||
# We measure the # of action tokens in two different ways. One is by checking
|
||||
# from action_order (above) and the other is by looping through the
|
||||
# action spec (below). We assert that the # of action tokens is the same
|
||||
# when calculated in these two ways. This ensures action_order is correctly
|
||||
# configured; otherwise, the assert will raise an error.
|
||||
num_action_token = 0
|
||||
for spec in self._action_spec.values():
|
||||
if spec.dtype == tf.int32:
|
||||
num_action_token += 1
|
||||
else:
|
||||
num_action_token += spec.shape[-1]
|
||||
tf.debugging.assert_equal(num_action_token, self._tokens_per_action)
|
||||
|
||||
@property
|
||||
def tokens_per_action(self) -> int:
|
||||
return self._tokens_per_action
|
||||
|
||||
@property
|
||||
def action_spec(self) -> tensorspec_utils.TensorSpecStruct:
|
||||
return self._action_spec
|
||||
|
||||
@property
|
||||
def action_order(self) -> list[str]:
|
||||
return self._action_order
|
||||
|
||||
def tokenize(self, action: tensorspec_utils.TensorSpecStruct) -> tf.Tensor:
|
||||
"""Tokenizes an action."""
|
||||
action_tokens = []
|
||||
for k in self._action_order:
|
||||
a = action[k] # a is [batch, actions_size]
|
||||
spec = self._action_spec[k]
|
||||
if spec.dtype == tf.int32:
|
||||
# Int32 actions are assumed to be one-hot tokens whose index is smaller
|
||||
# than the vocab size, so we only need to extract the token index.
|
||||
tf.debugging.assert_equal(1, tf.reduce_sum(a, axis=-1))
|
||||
# extract the token [batch, 1]
|
||||
token = tf.argmax(a, axis=-1, output_type=tf.int32)
|
||||
tf.debugging.assert_less(token, self._vocab_size)
|
||||
# Add a seq dimension [batch, 1]
|
||||
token = tf.expand_dims(token, axis=-1)
|
||||
else:
|
||||
a = tf.clip_by_value(a, spec.minimum, spec.maximum)
|
||||
# Normalize the action [batch, actions_size]
|
||||
token = (a - spec.minimum) / (spec.maximum - spec.minimum)
|
||||
# Bucket and discretize the action to vocab_size, [batch, actions_size]
|
||||
token = tf.cast(token * (self._vocab_size - 1), tf.int32)
|
||||
action_tokens.append(token)
|
||||
# Append all actions, [batch, all_actions_size]
|
||||
action_tokens = tf.concat(action_tokens, axis=-1)
|
||||
return action_tokens
|
||||
|
||||
def detokenize(self,
|
||||
action_tokens: tf.Tensor) -> tensorspec_utils.TensorSpecStruct:
|
||||
"""Detokenizes an action."""
|
||||
action = tensorspec_utils.TensorSpecStruct()
|
||||
token_index = 0
|
||||
for k in self._action_order:
|
||||
spec = self._action_spec[k]
|
||||
action_dim = spec.shape[0]
|
||||
if spec.dtype == tf.int32:
|
||||
# Int32 actions are already assumed to be tokens.
|
||||
action[k] = action_tokens[..., token_index]
|
||||
# A poor model may output tokens outside the allowed range; in that case,
|
||||
# set them to a default value (the 0 token here).
|
||||
outside_range = tf.greater_equal(action[k], action_dim)
|
||||
action[k] = tf.where(outside_range, tf.zeros_like(action[k]), action[k])
|
||||
action[k] = tf.one_hot(
|
||||
action[k], depth=action_dim, axis=-1, dtype=tf.int32)
|
||||
token_index += 1
|
||||
else:
|
||||
actions = []
|
||||
for _ in range(action_dim):
|
||||
a = action_tokens[..., token_index:token_index + 1]
|
||||
a = tf.cast(a, tf.float32)
|
||||
a = a / (self._vocab_size - 1)
|
||||
a = (a * (spec.maximum - spec.minimum)) + spec.minimum
|
||||
actions.append(a)
|
||||
token_index += 1
|
||||
action[k] = tf.concat(actions, axis=-1)
|
||||
return action
|
|
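A minimal usage sketch of the tokenizer above (an editor's illustration, not
part of the commit; the spec and values are assumed and mirror the unit tests
below):

# Hypothetical example.
from tensor2robot.utils import tensorspec_utils
from tf_agents.specs import tensor_spec
import tensorflow as tf

action_spec = tensorspec_utils.TensorSpecStruct()
action_spec.world_vector = tensor_spec.BoundedTensorSpec(
    (3,), dtype=tf.float32, minimum=-1., maximum=1., name='world_vector')
tokenizer = RT1ActionTokenizer(action_spec, vocab_size=256)

action = tensorspec_utils.TensorSpecStruct(world_vector=[0.1, 0.5, -0.8])
tokens = tokenizer.tokenize(action)       # int32 tensor of shape [3]
recovered = tokenizer.detokenize(tokens)  # input recovered up to bucketing error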
@ -0,0 +1,191 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for action_tokenizer."""
|
||||
import numpy as np
|
||||
from robotics_transformer.tokenizers import action_tokenizer
|
||||
from tensor2robot.utils import tensorspec_utils
|
||||
import tensorflow as tf
|
||||
from tf_agents.specs import tensor_spec
|
||||
|
||||
|
||||
class ActionTokenizerTest(tf.test.TestCase):
|
||||
|
||||
def testTokenize_int32(self):
|
||||
action_spec = tensorspec_utils.TensorSpecStruct()
|
||||
action_spec.terminate_episode = tensor_spec.BoundedTensorSpec(
|
||||
(2,), dtype=tf.int32, minimum=0, maximum=1, name='terminate_episode')
|
||||
tokenizer = action_tokenizer.RT1ActionTokenizer(action_spec, vocab_size=10)
|
||||
self.assertEqual(1, tokenizer.tokens_per_action)
|
||||
action = tensorspec_utils.TensorSpecStruct(terminate_episode=[0, 1])
|
||||
action_tokens = tokenizer.tokenize(action)
|
||||
self.assertEqual([1], action_tokens.numpy())
|
||||
|
||||
def testTokenize_int32_not_one_hot(self):
|
||||
action_spec = tensorspec_utils.TensorSpecStruct()
|
||||
action_spec.terminate_episode = tensor_spec.BoundedTensorSpec(
|
||||
(2,), dtype=tf.int32, minimum=0, maximum=1, name='terminate_episode')
|
||||
tokenizer = action_tokenizer.RT1ActionTokenizer(action_spec, vocab_size=10)
|
||||
self.assertEqual(1, tokenizer.tokens_per_action)
|
||||
action = tensorspec_utils.TensorSpecStruct(terminate_episode=[1, 8])
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
tokenizer.tokenize(action)
|
||||
|
||||
def testDetokenize_int32(self):
|
||||
action_spec = tensorspec_utils.TensorSpecStruct()
|
||||
action_spec.terminate_episode = tensor_spec.BoundedTensorSpec(
|
||||
(2,), dtype=tf.int32, minimum=0, maximum=1, name='terminate_episode')
|
||||
tokenizer = action_tokenizer.RT1ActionTokenizer(action_spec, vocab_size=10)
|
||||
# 0 token should become a one hot: [1, 0]
|
||||
action = tokenizer.detokenize(tf.constant([0], dtype=tf.int32))
|
||||
self.assertSequenceEqual([1, 0], list(action['terminate_episode'].numpy()))
|
||||
# 1 token should become a one hot: [0, 1]
|
||||
action = tokenizer.detokenize(tf.constant([1], dtype=tf.int32))
|
||||
self.assertSequenceEqual([0, 1], list(action['terminate_episode'].numpy()))
|
||||
# OOV 3 token should become a default one hot: [1, 0]
|
||||
action = tokenizer.detokenize(tf.constant([3], dtype=tf.int32))
|
||||
self.assertSequenceEqual([1, 0], list(action['terminate_episode'].numpy()))
|
||||
|
||||
def testTokenize_float(self):
|
||||
action_spec = tensorspec_utils.TensorSpecStruct()
|
||||
action_spec.world_vector = tensor_spec.BoundedTensorSpec(
|
||||
(3,), dtype=tf.float32, minimum=-1., maximum=1., name='world_vector')
|
||||
tokenizer = action_tokenizer.RT1ActionTokenizer(action_spec, vocab_size=10)
|
||||
self.assertEqual(3, tokenizer.tokens_per_action)
|
||||
action = tensorspec_utils.TensorSpecStruct(world_vector=[0.1, 0.5, -0.8])
|
||||
action_tokens = tokenizer.tokenize(action)
|
||||
self.assertSequenceEqual([4, 6, 0], list(action_tokens.numpy()))
|
||||
|
||||
def testTokenize_float_with_time_dimension(self):
|
||||
action_spec = tensorspec_utils.TensorSpecStruct()
|
||||
action_spec.world_vector = tensor_spec.BoundedTensorSpec(
|
||||
(3,), dtype=tf.float32, minimum=-1., maximum=1., name='world_vector')
|
||||
tokenizer = action_tokenizer.RT1ActionTokenizer(action_spec, vocab_size=10)
|
||||
self.assertEqual(3, tokenizer.tokens_per_action)
|
||||
batch_size = 2
|
||||
time_dimension = 3
|
||||
action = tensorspec_utils.TensorSpecStruct(
|
||||
world_vector=tf.constant(
|
||||
[[0.1, 0.5, -0.8], [0.1, 0.5, -0.8], [0.1, 0.5, -0.8],
|
||||
[0.1, 0.5, -0.8], [0.1, 0.5, -0.8], [0.1, 0.5, -0.8]],
|
||||
shape=[batch_size, time_dimension, tokenizer.tokens_per_action]))
|
||||
action_tokens = tokenizer.tokenize(action)
|
||||
self.assertSequenceEqual(
|
||||
[batch_size, time_dimension, tokenizer.tokens_per_action],
|
||||
action_tokens.shape.as_list())
|
||||
|
||||
def testTokenize_float_at_limits(self):
|
||||
minimum = -1.
|
||||
maximum = 1.
|
||||
vocab_size = 10
|
||||
action_spec = tensorspec_utils.TensorSpecStruct()
|
||||
action_spec.world_vector = tensor_spec.BoundedTensorSpec(
|
||||
(2,),
|
||||
dtype=tf.float32,
|
||||
minimum=minimum,
|
||||
maximum=maximum,
|
||||
name='world_vector')
|
||||
tokenizer = action_tokenizer.RT1ActionTokenizer(
|
||||
action_spec, vocab_size=vocab_size)
|
||||
self.assertEqual(2, tokenizer.tokens_per_action)
|
||||
action = tensorspec_utils.TensorSpecStruct(world_vector=[minimum, maximum])
|
||||
action_tokens = tokenizer.tokenize(action)
|
||||
# The minimum value will go to 0.
|
||||
# The maximum value will go to vocab_size - 1.
|
||||
self.assertSequenceEqual([0, vocab_size - 1], list(action_tokens.numpy()))
|
||||
|
||||
def testTokenize_invalid_action_spec_shape(self):
|
||||
action_spec = tensorspec_utils.TensorSpecStruct()
|
||||
action_spec.world_vector = tensor_spec.BoundedTensorSpec(
|
||||
(2, 2), dtype=tf.float32, minimum=-1, maximum=1, name='world_vector')
|
||||
with self.assertRaises(ValueError):
|
||||
action_tokenizer.RT1ActionTokenizer(action_spec, vocab_size=10)
|
||||
|
||||
def testTokenizeAndDetokenizeIsEqual(self):
|
||||
action_spec = tensorspec_utils.TensorSpecStruct()
|
||||
|
||||
action_spec.world_vector = tensor_spec.BoundedTensorSpec(
|
||||
(3,), dtype=tf.float32, minimum=-1., maximum=1., name='world_vector')
|
||||
|
||||
action_spec.rotation_delta = tensor_spec.BoundedTensorSpec(
|
||||
(3,),
|
||||
dtype=tf.float32,
|
||||
minimum=-np.pi / 2.,
|
||||
maximum=np.pi / 2.,
|
||||
name='rotation_delta')
|
||||
|
||||
action_spec.gripper_closedness_action = tensor_spec.BoundedTensorSpec(
|
||||
(1,),
|
||||
dtype=tf.float32,
|
||||
minimum=-1.,
|
||||
maximum=1.,
|
||||
name='gripper_closedness_action')
|
||||
|
||||
num_sub_action_space = 2
|
||||
action_spec.terminate_episode = tensor_spec.BoundedTensorSpec(
|
||||
(num_sub_action_space,),
|
||||
dtype=tf.int32,
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
name='terminate_episode')
|
||||
|
||||
tokenizer = action_tokenizer.RT1ActionTokenizer(
|
||||
action_spec,
|
||||
vocab_size=1024,
|
||||
action_order=[
|
||||
'terminate_episode', 'world_vector', 'rotation_delta',
|
||||
'gripper_closedness_action'
|
||||
])
|
||||
self.assertEqual(8, tokenizer.tokens_per_action)
|
||||
|
||||
# Repeat the following test N times with fuzzy inputs.
|
||||
n_repeat = 10
|
||||
for _ in range(n_repeat):
|
||||
action = tensorspec_utils.TensorSpecStruct(
|
||||
world_vector=np.random.uniform(low=-1., high=1.0, size=3),
|
||||
rotation_delta=np.random.uniform(
|
||||
low=-np.pi / 2., high=np.pi / 2., size=3),
|
||||
gripper_closedness_action=np.random.uniform(low=0., high=1.0, size=1),
|
||||
terminate_episode=[0, 1])
|
||||
action_tokens = tokenizer.tokenize(action)
|
||||
policy_action = tokenizer.detokenize(action_tokens)
|
||||
|
||||
for k in action:
|
||||
self.assertSequenceAlmostEqual(
|
||||
action[k], policy_action[k].numpy(), places=2)
|
||||
|
||||
# Repeat the test with batched actions
|
||||
batched_action = tensorspec_utils.TensorSpecStruct(
|
||||
world_vector=[
|
||||
np.random.uniform(low=-1., high=1.0, size=3),
|
||||
np.random.uniform(low=-1., high=1.0, size=3)
|
||||
],
|
||||
rotation_delta=[
|
||||
np.random.uniform(low=-np.pi / 2., high=np.pi / 2., size=3),
|
||||
np.random.uniform(low=-np.pi / 2., high=np.pi / 2., size=3)
|
||||
],
|
||||
gripper_closedness_action=[
|
||||
np.random.uniform(low=0., high=1.0, size=1),
|
||||
np.random.uniform(low=0., high=1.0, size=1)
|
||||
],
|
||||
terminate_episode=[[0, 1], [1, 0]])
|
||||
action_tokens = tokenizer.tokenize(batched_action)
|
||||
policy_action = tokenizer.detokenize(action_tokens)
|
||||
|
||||
for k in batched_action:
|
||||
for a, policy_a in zip(batched_action[k], policy_action[k].numpy()):
|
||||
self.assertSequenceAlmostEqual(a, policy_a, places=2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1,112 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A FiLM Efficientnet contextual image tokenizer used in Robotics Transformer 1.
|
||||
"""
|
||||
from typing import Optional
|
||||
from robotics_transformer.film_efficientnet import pretrained_efficientnet_encoder
|
||||
from robotics_transformer.tokenizers import token_learner
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class RT1ImageTokenizer(tf.keras.layers.Layer):
|
||||
"""Tokenizes based on vocab size."""
|
||||
|
||||
def __init__(self,
|
||||
embedding_output_dim: int,
|
||||
use_token_learner: bool = False,
|
||||
num_tokens: int = 8,
|
||||
**kwargs):
|
||||
"""Instantiates a RT1ImageTokenizer.
|
||||
|
||||
Args:
|
||||
embedding_output_dim: The output size of the tokens.
|
||||
use_token_learner: Whether to use token learner. See
|
||||
https://arxiv.org/abs/2106.11297
|
||||
num_tokens: Relevant only for token learner - the number of learned
|
||||
tokens.
|
||||
**kwargs: Keyword arguments to base class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._embedding_output_dim = embedding_output_dim
|
||||
|
||||
self._tokenizer = pretrained_efficientnet_encoder.EfficientNetEncoder(
|
||||
pooling=False, early_film=True)
|
||||
|
||||
self._use_token_learner = use_token_learner
|
||||
if self._use_token_learner:
|
||||
self._num_tokens = num_tokens
|
||||
self._token_learner = token_learner.TokenLearnerModule(
|
||||
num_tokens=self._num_tokens)
|
||||
|
||||
@property
|
||||
def tokens_per_context_image(self) -> int:
|
||||
if self._use_token_learner:
|
||||
num_tokens = self._num_tokens
|
||||
else:
|
||||
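# Editor's note: 81 corresponds to the 9x9 spatial grid produced by the
# EfficientNet encoder (see get_image_embeddings below).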
num_tokens = 81
|
||||
return num_tokens
|
||||
|
||||
def __call__(self,
|
||||
image: tf.Tensor,
|
||||
context: Optional[tf.Tensor] = None,
|
||||
training: bool = False) -> tf.Tensor:
|
||||
"""Gets image tokens.
|
||||
|
||||
Args:
|
||||
image: Images of shape (b, t, h, w, 3) to tokenize.
|
||||
context: An optional context vector (e.g., a natural language embedding).
|
||||
Expected to have shape (b, t, embedding_dim).
|
||||
training: Whether or not we are in training mode.
|
||||
|
||||
Returns:
|
||||
tokens: has shape (batch, t, num_tokens_per_timestep, embedding_dim)
|
||||
"""
|
||||
image_shape = tf.shape(image)
|
||||
b = image_shape[0]
|
||||
t = image_shape[1]
|
||||
h = image_shape[2]
|
||||
w = image_shape[3]
|
||||
c = image_shape[4]
|
||||
|
||||
# Fold the time axis into the batch axis.
|
||||
image = tf.reshape(image, [b * t, h, w, c])
|
||||
if context is not None:
|
||||
context_rank = tf.rank(context)
|
||||
assertion = tf.Assert(context_rank == 3, data=[context_rank])
|
||||
with tf.control_dependencies([assertion]):
|
||||
context = tf.reshape(context, [b * t, tf.shape(context)[-1]])
|
||||
tokens = self.get_image_embeddings(image, context, training)
|
||||
if self._use_token_learner:
|
||||
tokens = self._token_learner(tokens, training)
|
||||
# Unflatten the time axis, which was previously flattened into the batch.
|
||||
tokens = tf.reshape(tokens, [b, t, tf.shape(tokens)[1], -1])
|
||||
return tokens
|
||||
|
||||
def get_image_embeddings(self,
|
||||
image: tf.Tensor,
|
||||
context: Optional[tf.Tensor],
|
||||
training: bool = False) -> tf.Tensor:
|
||||
"""Gets embeddings from image.
|
||||
|
||||
Args:
|
||||
image: Expected to be float32 in range [0, 1] with shape (b, h, w, 3).
|
||||
context: Expected to be float32 with shape (b, embedding_dim)
|
||||
training: Whether or not we are in training mode.
|
||||
|
||||
Returns:
|
||||
tokens of shape (b, num_tokens, embedding_dim)
|
||||
"""
|
||||
image_tokens = self._tokenizer(image, context=context, training=training)
|
||||
image_tokens = tf.reshape(image_tokens, [-1, 81, 512])
|
||||
return image_tokens
|
|
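A minimal usage sketch (editor's illustration; the sizes are assumed and
mirror the unit test below):

# Hypothetical example.
tokenizer = RT1ImageTokenizer(embedding_output_dim=512)
images = tf.random.uniform((1, 2, 224, 224, 3))  # (b, t, h, w, 3) in [0, 1]
context = tf.random.uniform((1, 2, 512))         # (b, t, embedding_dim)
tokens = tokenizer(images, context)              # shape (1, 2, 81, 512)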
@ -0,0 +1,46 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for image_tokenizer."""
|
||||
from absl.testing import parameterized
|
||||
from robotics_transformer.tokenizers import image_tokenizer
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class ImageTokenizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('sample_image', 512, 224, False, 8),
|
||||
('sample_image_token_learner', 512, 224, True, 8))
|
||||
def testTokenize(self, output_dim, image_resolution, use_token_learner,
|
||||
num_tokens):
|
||||
batch = 1
|
||||
seq = 2
|
||||
tokenizer = image_tokenizer.RT1ImageTokenizer(
|
||||
embedding_output_dim=output_dim,
|
||||
use_token_learner=use_token_learner,
|
||||
num_tokens=num_tokens)
|
||||
|
||||
image = tf.random.normal(
|
||||
shape=(batch, seq, image_resolution, image_resolution, 3))
|
||||
image = tf.clip_by_value(image, 0.0, 1.0)
|
||||
context_vector = tf.random.uniform((batch, seq, 512))
|
||||
image_tokens = tokenizer(image, context_vector)
|
||||
if use_token_learner:
|
||||
self.assertEqual(image_tokens.shape, [batch, seq, num_tokens, 512])
|
||||
else:
|
||||
self.assertEqual(image_tokens.shape, [batch, seq, 81, 512])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""TF implementation of Token Learner(Ryoo et al 2021)."""
|
||||
|
||||
import functools
|
||||
from typing import Optional, Sequence, Union
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def gelu(x: float) -> float:
|
||||
return 0.5 * x * (1 +
|
||||
tf.tanh(tf.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))
|
||||
|
||||
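# Editor's note: this is the tanh approximation of GELU; it should match
# tf.nn.gelu(x, approximate=True) up to numerical precision.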
|
||||
def _maybe_dropout(rate: float = 0.0, name: str = "dropout"):
|
||||
"""Helper function to return dropout layer if rate is non zero."""
|
||||
if rate:
|
||||
return tf.keras.layers.Dropout(rate, name=name)
|
||||
return lambda x, *args: x # Does nothing to x.
|
||||
|
||||
|
||||
class MlpBlock(tf.keras.layers.Layer):
|
||||
"""Transformer MLP / feed-forward block."""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
mlp_dim: int,
|
||||
out_dim: Optional[int] = None,
|
||||
kernel_init: Optional[tf.keras.initializers.Initializer] = tf
|
||||
.keras.initializers.glorot_uniform(),
|
||||
bias_init: Optional[tf.keras.initializers.Initializer] = tf.keras
|
||||
.initializers.RandomNormal(stddev=1e-6),
|
||||
dropout_rate: float = 0.1,
|
||||
**kwargs):
|
||||
"""Initializer for the MLP Block.
|
||||
|
||||
This computes outer_dense(gelu(hidden_dense(input))), with dropout
|
||||
applied as necessary.
|
||||
|
||||
Note: Especially outside a Keras workflow, make sure to call layer.build
before the first use so that the output layer is created.
|
||||
|
||||
Args:
|
||||
mlp_dim: The dimension of the inner representation (output of hidden
|
||||
layer). Usually larger than the input/output dim.
|
||||
out_dim: The output dimension of the block. If None, the output dim
|
||||
equals the input dim (usually the desired behavior).
|
||||
kernel_init: Initializer for dense kernels, used for both dense layers.
|
||||
bias_init: Initializer for dense biases, used for both dense layers.
|
||||
dropout_rate: Dropout rate applied after each dense (and activation) layer.
|
||||
**kwargs: Other keyword args passed to the tf.keras.layers.Layer
|
||||
constructor e.g. the name
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._out_dim = out_dim
|
||||
self._hidden_dropout = _maybe_dropout(dropout_rate)
|
||||
self._output_dropout = _maybe_dropout(dropout_rate)
|
||||
self._hidden_layer = tf.keras.layers.Dense(
|
||||
mlp_dim,
|
||||
activation=gelu,
|
||||
kernel_initializer=kernel_init,
|
||||
bias_initializer=bias_init,
|
||||
name="hidden_dense")
|
||||
|
||||
# If out_dim is None, infer out_dim = input_dim at self.build()
|
||||
self._output_layer = functools.partial(
|
||||
tf.keras.layers.Dense,
|
||||
kernel_initializer=kernel_init,
|
||||
bias_initializer=bias_init,
|
||||
name="final_dense")
|
||||
|
||||
def build(self, input_shape: Sequence[int]):
|
||||
out_dim = self._out_dim or input_shape[-1]
|
||||
self._output_layer = self._output_layer(units=out_dim)
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self,
|
||||
inputs: tf.Tensor,
|
||||
*,
|
||||
is_training: Union[bool, tf.Tensor] = False) -> tf.Tensor:
|
||||
"""Applies Transformer MlpBlock module."""
|
||||
x = self._hidden_layer(inputs)
|
||||
x = self._hidden_dropout(x, is_training)
|
||||
x = self._output_layer(x)
|
||||
x = self._output_dropout(x, is_training)
|
||||
return x
|
||||
|
||||
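A minimal usage sketch (editor's illustration with assumed sizes); when
out_dim is None it is inferred from the input at build time:

# Hypothetical example.
block = MlpBlock(mlp_dim=64)
x = tf.random.normal([2, 81, 512])
y = block(x)  # shape [2, 81, 512]; out_dim defaulted to the input dim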
|
||||
class TokenLearnerModule(tf.keras.layers.Layer):
|
||||
"""TokenLearner module V1.1 (https://arxiv.org/abs/2106.11297)."""
|
||||
|
||||
def __init__(self,
|
||||
num_tokens: int,
|
||||
bottleneck_dim: int = 64,
|
||||
dropout_rate: float = 0.):
|
||||
super().__init__()
|
||||
|
||||
self.mlp = MlpBlock(
|
||||
mlp_dim=bottleneck_dim, out_dim=num_tokens, dropout_rate=dropout_rate)
|
||||
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
|
||||
|
||||
def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
|
||||
if len(inputs.shape) == 4:
|
||||
bs, h, w, c = inputs.shape
|
||||
inputs = tf.reshape(inputs, [bs, h * w, c])
|
||||
|
||||
selected = self.layernorm(inputs)
|
||||
|
||||
selected = self.mlp(
|
||||
selected, is_training=training) # Shape: [bs, h*w, n_token].
|
||||
|
||||
selected = tf.transpose(selected, [0, 2, 1]) # Shape: [bs, n_token, h*w].
|
||||
selected = tf.nn.softmax(selected, axis=-1)
|
||||
|
||||
feat = tf.einsum("...si,...id->...sd", selected, inputs)
|
||||
|
||||
return feat # Shape: [bs, n_token, c]
|
|
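In short, the module learns num_tokens attention maps over the h*w input
positions (softmax-normalized) and pools the input features under each map,
reducing e.g. 81 spatial tokens to 8 learned tokens. A minimal sketch
(editor's illustration with assumed sizes):

# Hypothetical example.
layer = TokenLearnerModule(num_tokens=8)
feat = layer(tf.random.normal([2, 81, 512]))  # shape [2, 8, 512]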
@ -0,0 +1,37 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for token_learner."""
|
||||
from absl.testing import parameterized
|
||||
from robotics_transformer.tokenizers import token_learner
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class TokenLearnerTest(parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(('sample_input', 512, 8))
|
||||
def testTokenLearner(self, embedding_dim, num_tokens):
|
||||
batch = 1
|
||||
seq = 2
|
||||
token_learner_layer = token_learner.TokenLearnerModule(
|
||||
num_tokens=num_tokens)
|
||||
|
||||
inputvec = tf.random.normal(shape=(batch * seq, 81, embedding_dim))
|
||||
|
||||
learnedtokens = token_learner_layer(inputvec)
|
||||
self.assertEqual(learnedtokens.shape,
|
||||
[batch * seq, num_tokens, embedding_dim])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1,169 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RT1 decoder transformer.
|
||||
|
||||
Copied from:
|
||||
https://www.tensorflow.org/text/tutorials/transformer#decoder
|
||||
"""
|
||||
from typing import Tuple, Union
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class _TransformerLayer(tf.keras.layers.Layer):
|
||||
"""A single transformer block."""
|
||||
|
||||
def __init__(self,
|
||||
layer_size: int = 4096,
|
||||
num_heads: int = 8,
|
||||
feed_forward_size: int = 512,
|
||||
dropout_rate: float = 0.1,
|
||||
return_attention_scores: bool = False):
|
||||
"""Creates a Transformer layer.
|
||||
|
||||
Args:
|
||||
layer_size: Size of the multi-head attention layer.
|
||||
num_heads: Number of heads for the multi-head attention layer.
|
||||
feed_forward_size: Dimensionality of the feed_forward layer.
|
||||
dropout_rate: Dropout rate.
|
||||
return_attention_scores: Return attention scores.
|
||||
"""
|
||||
super(_TransformerLayer, self).__init__()
|
||||
|
||||
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
|
||||
self.mha1 = tf.keras.layers.MultiHeadAttention(
|
||||
key_dim=layer_size, num_heads=num_heads, dropout=dropout_rate)
|
||||
self.ff = tf.keras.layers.Dense(feed_forward_size)
|
||||
self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
|
||||
self.dropout_ff = tf.keras.layers.Dropout(dropout_rate)
|
||||
self._return_attention_scores = return_attention_scores
|
||||
|
||||
def call(self, x: tf.Tensor, attention_mask: tf.Tensor,
|
||||
training: bool) -> Tuple[tf.Tensor, Union[tf.Tensor, None]]:
|
||||
"""Calls the layer.
|
||||
|
||||
Args:
|
||||
x: Input Tensor of shape `(B, T, dim)`.
|
||||
attention_mask: a boolean mask of shape `(B, T, T)`, that prevents
|
||||
attention to certain positions. The boolean mask specifies which query
|
||||
elements can attend to which key elements, 1 indicates attention and 0
|
||||
indicates no attention. Broadcasting can happen for the missing batch
|
||||
dimensions and the head dimension.
|
||||
training: Python boolean indicating whether the layer should behave in
|
||||
training mode (adding dropout) or in inference mode (no dropout).
|
||||
|
||||
Returns:
|
||||
y: Output Tensor of shape `(B, T, dim)`. Also return the attention scores
|
||||
of shape `(B, T, dim)` or None.
|
||||
"""
|
||||
x1 = self.layernorm1(x)
|
||||
mha_results = self.mha1(
|
||||
query=x1,
|
||||
key=x1,
|
||||
value=x1,
|
||||
attention_mask=attention_mask,
|
||||
return_attention_scores=self._return_attention_scores,
|
||||
training=training)
|
||||
if self._return_attention_scores:
|
||||
x1, score = mha_results
|
||||
else:
|
||||
x1, score = mha_results, None
|
||||
|
||||
x = x + x1
|
||||
|
||||
y = self.layernorm2(x)
|
||||
ff_y = self.ff(y)
|
||||
ff_y = self.dropout_ff(ff_y, training=training)
|
||||
x = x + ff_y
|
||||
return x, score
|
||||
|
||||
|
||||
class Transformer(tf.keras.layers.Layer):
|
||||
"""A decoder only transformer."""
|
||||
|
||||
def __init__(self,
|
||||
num_layers: int = 1,
|
||||
layer_size: int = 4096,
|
||||
num_heads: int = 8,
|
||||
feed_forward_size: int = 512,
|
||||
dropout_rate: float = 0.1,
|
||||
vocab_size: int = 256,
|
||||
return_attention_scores: bool = False):
|
||||
"""Creates a transformer.
|
||||
|
||||
Args:
|
||||
num_layers: Number of transformer layers.
|
||||
layer_size: Size of the multi-head attention layer.
|
||||
num_heads: Number of heads for the multi-head attention layer.
|
||||
feed_forward_size: Dimensionality of the feed_forward layer.
|
||||
dropout_rate: Dropout rate.
|
||||
vocab_size: Dimensionality of tokens from the output layer.
|
||||
return_attention_scores: Return attention scores.
|
||||
"""
|
||||
super(Transformer, self).__init__()
|
||||
|
||||
self._layers = [
|
||||
_TransformerLayer( # pylint: disable=g-complex-comprehension
|
||||
layer_size=layer_size,
|
||||
num_heads=num_heads,
|
||||
feed_forward_size=feed_forward_size,
|
||||
dropout_rate=dropout_rate,
|
||||
return_attention_scores=return_attention_scores)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
self._token_emb = tf.keras.layers.Dense(feed_forward_size)
|
||||
self._position_emb = tf.keras.layers.Dense(feed_forward_size)
|
||||
self._output_tokens = tf.keras.layers.Dense(vocab_size)
|
||||
|
||||
def call(
|
||||
self,
|
||||
x: tf.Tensor,
|
||||
training: bool,
|
||||
attention_mask: tf.Tensor,
|
||||
) -> Union[tf.Tensor, Tuple[tf.Tensor, list[tf.Tensor]]]:
|
||||
"""Calls the layer.
|
||||
|
||||
Args:
|
||||
x: Input Tensor of shape `(B, T, dim)`.
|
||||
training: Python boolean indicating whether the layer should behave in
|
||||
training mode (adding dropout) or in inference mode (no dropout).
|
||||
attention_mask: a boolean mask of shape `(B, T, T)`, that prevents
|
||||
attention to certain positions. The boolean mask specifies which query
|
||||
elements can attend to which key elements, 1 indicates attention and 0
|
||||
indicates no attention. Broadcasting can happen for the missing batch
|
||||
dimensions and the head dimension.
|
||||
|
||||
Returns:
|
||||
x: Output Tensor of shape `(B, T, vocab_size)`. If
|
||||
`return_attention_scores` is set, also returns the attention scores as a
|
||||
list with one entry per layer, each of shape `(B, T, dim)`.
|
||||
"""
|
||||
|
||||
seq_len = tf.shape(x)[1]
|
||||
batch_size = tf.shape(x)[0]
|
||||
|
||||
positions = tf.one_hot(
|
||||
tf.tile(tf.expand_dims(tf.range(0, seq_len, 1), 0), [batch_size, 1]),
|
||||
seq_len)
|
||||
|
||||
x = self._token_emb(x)
|
||||
x += self._position_emb(positions)
|
||||
scores = []
|
||||
|
||||
for layer in self._layers:
|
||||
x, score = layer(x, attention_mask=attention_mask, training=training)
|
||||
if score is not None:
|
||||
scores.append(score)
|
||||
x = self._output_tokens(x)
|
||||
return x, scores
|
|
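A minimal usage sketch (editor's illustration; the sizes are assumed). The
lower-triangular causal mask matches what transformer_network.py constructs
in _generate_masks below:

# Hypothetical example.
batch, seq_len, dim = 2, 16, 512
net = Transformer(num_layers=1, vocab_size=256)
causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
x = tf.random.normal([batch, seq_len, dim])
logits, scores = net(x, training=False, attention_mask=causal_mask)
# logits: [batch, seq_len, 256]; scores stays empty unless
# return_attention_scores=True.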
@ -0,0 +1,689 @@
|
|||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tensorflow based methods for sequence agents."""
|
||||
from typing import Optional, Tuple, Union, Any
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
from robotics_transformer import transformer
|
||||
from robotics_transformer.film_efficientnet import preprocessors
|
||||
from robotics_transformer.tokenizers import action_tokenizer
|
||||
from robotics_transformer.tokenizers import image_tokenizer
|
||||
|
||||
from tensor2robot.utils import tensorspec_utils
|
||||
import tensorflow as tf
|
||||
from tf_agents.networks import network
|
||||
from tf_agents.specs import tensor_spec
|
||||
from tf_agents.utils import nest_utils
|
||||
|
||||
|
||||
class TransformerNetwork(network.Network):
|
||||
"""A transformer based actor network."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_tensor_spec: tensorspec_utils.TensorSpecStruct,
|
||||
output_tensor_spec: tensorspec_utils.TensorSpecStruct,
|
||||
train_step_counter: int = 0,
|
||||
vocab_size: int = 256,
|
||||
token_embedding_size: int = 512,
|
||||
num_layers: int = 1,
|
||||
layer_size: int = 4096,
|
||||
num_heads: int = 8,
|
||||
feed_forward_size: int = 512,
|
||||
dropout_rate: float = 0.1,
|
||||
time_sequence_length: int = 1,
|
||||
crop_size: int = 236,
|
||||
policy_info_spec: Optional[dict[Any,
|
||||
tensor_spec.BoundedTensorSpec]] = None,
|
||||
action_order: Optional[list[str]] = None,
|
||||
use_token_learner: Optional[bool] = True,
|
||||
return_attention_scores: bool = False,
|
||||
**kwargs):
|
||||
"""Creates a transformer network.
|
||||
|
||||
Args:
|
||||
input_tensor_spec: Nested list/tuple/dict of TensorSpecs, describing the
|
||||
shape of input tensor.
|
||||
output_tensor_spec: Nested list/tuple/dict of TensorSpecs, describing the
|
||||
shape of output tensor.
|
||||
train_step_counter: Counter for number of steps.
|
||||
vocab_size: Dimensionality of tokens from the output layer.
|
||||
token_embedding_size: Dimensionality of tokens from the embedding layer.
|
||||
num_layers: Number of transformer layers.
|
||||
layer_size: Size of the multi-head attention layer.
|
||||
num_heads: Number of heads for the multi-head attention layer.
|
||||
feed_forward_size: Dimensionality of the feed_forward layer.
|
||||
dropout_rate: Dropout rate.
|
||||
time_sequence_length: Length of the time sequence.
|
||||
crop_size: Height and width of the square crop; the original image is
|
||||
padded so that the full field of view can be extracted.
|
||||
policy_info_spec: Spec on return value given return type of the return
|
||||
tokenizer.
|
||||
action_order: Order of actions for the action tokenizer.
|
||||
use_token_learner: Whether to use token learner. See
|
||||
https://arxiv.org/abs/2106.11297
|
||||
return_attention_scores: Whether to show attention scores in TensorBoard.
|
||||
**kwargs: Keyword parameter arguments.
|
||||
"""
|
||||
self._input_tensor_spec = input_tensor_spec
|
||||
self._output_tensor_spec = output_tensor_spec
|
||||
self._train_step_counter = train_step_counter
|
||||
self._actions = None
|
||||
self._returns = None
|
||||
self._vocab_size = vocab_size
|
||||
self._token_embedding_size = token_embedding_size
|
||||
self._time_sequence_length = time_sequence_length
|
||||
self._crop_size = crop_size
|
||||
|
||||
self._transformer = transformer.Transformer(
|
||||
num_layers=num_layers,
|
||||
layer_size=layer_size,
|
||||
num_heads=num_heads,
|
||||
feed_forward_size=feed_forward_size,
|
||||
dropout_rate=dropout_rate,
|
||||
vocab_size=self._vocab_size,
|
||||
return_attention_scores=return_attention_scores)
|
||||
|
||||
# create tokenizers
|
||||
self._image_tokenizer = image_tokenizer.RT1ImageTokenizer(
|
||||
embedding_output_dim=self._token_embedding_size,
|
||||
use_token_learner=use_token_learner)
|
||||
self._action_tokenizer = action_tokenizer.RT1ActionTokenizer(
|
||||
output_tensor_spec,
|
||||
vocab_size=self._vocab_size,
|
||||
action_order=action_order)
|
||||
|
||||
self._tokens_per_action = self._action_tokenizer.tokens_per_action
|
||||
self._tokens_per_context_image = self._image_tokenizer.tokens_per_context_image
|
||||
# generate loss and attention masks
|
||||
self._generate_masks()
|
||||
|
||||
# define mappings to token embedding size
|
||||
self._action_token_emb = tf.keras.layers.Dense(self._token_embedding_size)
|
||||
|
||||
# define loss function
|
||||
self._loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
|
||||
self._attention_scores = []
|
||||
self._use_token_learner = use_token_learner
|
||||
|
||||
super(TransformerNetwork, self).__init__(
|
||||
input_tensor_spec=input_tensor_spec, **kwargs)
|
||||
self._state_spec = {
|
||||
# Force this to be 4-dimensional due to b/254902773.
|
||||
# Otherwise it could be 3-dimensional.
|
||||
'context_image_tokens':
|
||||
tensor_spec.TensorSpec(
|
||||
shape=(time_sequence_length, self._tokens_per_context_image, 1,
|
||||
token_embedding_size),
|
||||
dtype=tf.float32,
|
||||
name='context_image_tokens'),
|
||||
'action_tokens':
|
||||
tensor_spec.TensorSpec(
|
||||
shape=(time_sequence_length, self._tokens_per_action, 1, 1),
|
||||
dtype=tf.int32,
|
||||
name='action_tokens'),
|
||||
# Stores where in the window we are.
|
||||
# This value is within range [0, time_sequence_length + 1].
|
||||
# When seq_idx == time_sequence_length, context_image_tokens and
|
||||
# action_tokens need to be shifted to the left.
|
||||
'seq_idx':
|
||||
tensor_spec.TensorSpec(
|
||||
shape=(1, 1, 1, 1), dtype=tf.int32, name='seq_idx')
|
||||
}
|
||||
|
||||
@property
|
||||
def attention_scores(self) -> list[tf.Tensor]:
|
||||
"""Return attention score. This is for debugging/visualization purpose."""
|
||||
return self._attention_scores
|
||||
|
||||
def _get_action_index_for_token(self, k):
|
||||
"""Returns action associated with the token at given position `k`.
|
||||
|
||||
If k is not an action token then it returns -1.
|
||||
If k is part of the first action in the sequence then returns 0 etc.
|
||||
|
||||
Args:
|
||||
k: an int that represents the position in the sequence.
|
||||
|
||||
Returns:
|
||||
The index of the action that this position belongs to, or -1 if this
|
||||
position is part of an image token.
|
||||
"""
|
||||
if (k < 0 or k >= self._all_num_tokens):
|
||||
return -1
|
||||
|
||||
n = k
|
||||
if n % self._single_time_step_num_tokens < self._tokens_per_context_image:
|
||||
return -1
|
||||
return int(n / self._single_time_step_num_tokens)
|
||||
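# Editor's worked example (assumed sizes): with 8 image tokens and 2 action
# tokens per step, single_time_step_num_tokens is 10; positions 0-7 map to
# -1 (image tokens), positions 8-9 map to action 0, and positions 18-19 of
# the next step map to action 1.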
|
||||
def _generate_masks(self):
|
||||
"""Generate mask for action prediction loss and attention visualization."""
|
||||
# each time step = [image, action]
|
||||
self._single_time_step_num_tokens = (
|
||||
self._tokens_per_action + self._tokens_per_context_image)
|
||||
|
||||
# full sequence = [prefix context + N x timestep + postfix context]
|
||||
self._all_num_tokens = (
|
||||
self._time_sequence_length * self._single_time_step_num_tokens)
|
||||
|
||||
# create mask for action prediction loss
|
||||
self._action_tokens_mask = []
|
||||
for n in range(0, self._all_num_tokens, self._single_time_step_num_tokens):
|
||||
for x in range(0, self._tokens_per_action, 1):
|
||||
self._action_tokens_mask.append(x + n + self._tokens_per_context_image)
|
||||
self._action_tokens_mask = tf.constant(
|
||||
self._action_tokens_mask, dtype=tf.int32)
|
||||
|
||||
# The look ahead mask ensures causality.
|
||||
self._default_attention_mask = tf.linalg.band_part(
|
||||
tf.ones((self._all_num_tokens, self._all_num_tokens)), -1, 0)
|
||||
|
||||
action_mask = np.ndarray(
|
||||
shape=(self._all_num_tokens, self._all_num_tokens), dtype=int)
|
||||
for i in range(self._all_num_tokens):
|
||||
for j in range(self._all_num_tokens):
|
||||
action_i = self._get_action_index_for_token(i)
|
||||
action_j = self._get_action_index_for_token(j)
|
||||
mask = 0
|
||||
if action_i != -1 and action_j != -1:
|
||||
# Ignore actions of previous steps.
|
||||
if action_j < action_i:
|
||||
mask = 1
|
||||
# Since we are not autoregressive within a step, also mask attention to
|
||||
# the action dimensions of the current step.
|
||||
if (action_j == action_i and j <= i):
|
||||
mask = 1
|
||||
action_mask[i, j] = mask
|
||||
self._default_attention_mask -= action_mask
|
||||
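# Editor's note: subtracting action_mask removes, from the causal band, all
# attention from action tokens to action tokens of the current or earlier
# steps, so action predictions attend to image tokens only.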
|
||||
def _transformer_call(
|
||||
self,
|
||||
context_image_tokens: tf.Tensor,
|
||||
action_tokens: tf.Tensor,
|
||||
batch_size: int,
|
||||
training: bool,
|
||||
attention_mask: tf.Tensor,
|
||||
) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
|
||||
"""Calls the transformer.
|
||||
|
||||
Args:
|
||||
context_image_tokens: Tokenized context and image in Tensor of shape `(B,
|
||||
T, num token, -1)`.
|
||||
action_tokens: Discrete action token sequence of size [8, 256].
|
||||
batch_size: Batch size used when reshaping all tokens.
|
||||
training: Whether to run the transformer in training mode.
|
||||
attention_mask: Optional bool tensor for masking transformer's attention.
|
||||
|
||||
Returns:
|
||||
Output tokens in Tensor of shape `(B, T, dim)`. If
|
||||
return_attention_scores, also return the attention scores of
|
||||
shape `(B, T, dim)`.
|
||||
"""
|
||||
input_token_sequence = self._assemble_input_token_sequence(
|
||||
context_image_tokens, action_tokens, batch_size)
|
||||
|
||||
# run transformer
|
||||
output_tokens, self._attention_scores = self._transformer(
|
||||
input_token_sequence, training, attention_mask)
|
||||
return output_tokens
|
||||
|
||||
def _get_tokens_and_mask(self,
|
||||
observations: dict[str, tf.Tensor],
|
||||
network_state: dict[str, tf.Tensor],
|
||||
training: bool = False):
|
||||
# tokenize all inputs
|
||||
context_image_tokens, network_state = self._tokenize_images(
|
||||
observations, network_state, training)
|
||||
action_tokens = self._tokenize_actions(observations, network_state)
|
||||
|
||||
# generate transformer attention mask
|
||||
attention_mask = self._default_attention_mask
|
||||
|
||||
return (context_image_tokens, action_tokens, attention_mask)
|
||||
|
||||
def _transformer_call_and_slice(self,
|
||||
*args,
|
||||
slice_start: int = 0,
|
||||
slice_length: int = 1,
|
||||
**kwargs) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||
output_tokens = self._transformer_call(*args, **kwargs)
|
||||
|
||||
slice_end = slice_start + slice_length
|
||||
token_logits = output_tokens[:, slice_start:slice_end, :]
|
||||
token = tf.argmax(token_logits, axis=-1, output_type=tf.int32)
|
||||
|
||||
return token, token_logits
|
||||
|
||||
def call(self,
|
||||
observations: dict[str, tf.Tensor],
|
||||
network_state: dict[str, tf.Tensor],
|
||||
training: bool = False):
|
||||
"""Calls the transformer network.
|
||||
|
||||
Args:
|
||||
observations: Observation data including image and natural language
|
||||
embedding in dict of Tensors.
|
||||
network_state: Network state data including time step, image, action
|
||||
tokens, step number in dict of Tensors.
|
||||
training: Whether to call transformer network in training mode.
|
||||
|
||||
Returns:
|
||||
A tuple `(Detokenized output actions, network state)`.
|
||||
"""
|
||||
# used to determine training vs inference call
|
||||
# outer_rank will be 2 -> [b, t] during training and
|
||||
# outer_rank will be 1 -> [b] during inference
|
||||
outer_rank = self._get_outer_rank(observations)
|
||||
assert outer_rank in (1, 2)
|
||||
|
||||
b, t = self._get_batch_size_and_seq_len(network_state)
|
||||
|
||||
context_image_tokens, action_tokens, attention_mask = self._get_tokens_and_mask(
|
||||
observations, network_state, training)
|
||||
|
||||
self._aux_info = {'action_labels': action_tokens}
|
||||
|
||||
if outer_rank == 1: # This is an inference call
|
||||
# run transformer in loop to produce action tokens one-by-one
|
||||
# TODO(b/231896343): Document/comment more on what the following mess is.
|
||||
seq_idx = tf.reshape(network_state['seq_idx'], [1])[0]
|
||||
action_t = tf.minimum(seq_idx, self._time_sequence_length - 1)
|
||||
# The transformer output is shifted one step to the left relative to its
|
||||
# input, since its default training task is next-token prediction.
|
||||
transformer_shift = -1
|
||||
# We only want to get the action predicted at time_step.
|
||||
start_index = (
|
||||
transformer_shift + self._tokens_per_context_image + action_t *
|
||||
(self._single_time_step_num_tokens))
|
||||
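# Editor's worked example (assumed sizes): with 8 image tokens, 2 action
# tokens per step and action_t = 0, start_index = -1 + 8 + 0 = 7, i.e. the
# output at the last image token predicts the step's first action token.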
current_action_tokens = []
|
||||
action_predictions_logits = []
|
||||
for k in range(self._tokens_per_action):
|
||||
action_index = start_index + k
|
||||
token, token_logits = self._transformer_call_and_slice(
|
||||
context_image_tokens,
|
||||
action_tokens,
|
||||
attention_mask=attention_mask,
|
||||
batch_size=b,
|
||||
training=training,
|
||||
slice_start=action_index # slicing single action dimension
|
||||
)
|
||||
action_predictions_logits.append(token_logits)
|
||||
current_action_tokens.append(token)
|
||||
# action_tokens is [b, t * self._tokens_per_action]
|
||||
action_tokens = tf.reshape(action_tokens, [b, -1])
|
||||
action_start_index = (action_t * self._tokens_per_action) + k
|
||||
action_tokens = tf.concat([
|
||||
action_tokens[:, :action_start_index], token,
|
||||
action_tokens[:, action_start_index + 1:]
|
||||
],
|
||||
axis=1)
|
||||
# action_tokens is [b, t, self._tokens_per_action]
|
||||
action_tokens = tf.reshape(action_tokens,
|
||||
[b, t, self._tokens_per_action])
|
||||
self._aux_info.update({
|
||||
# action_predictions_logits is
|
||||
# [b, self._tokens_per_action, self._vocab_size]
|
||||
'action_predictions_logits': tf.concat(action_predictions_logits, 1)
|
||||
})
|
||||
# predicted_tokens_for_output is [b, self._tokens_per_action]
|
||||
predicted_tokens_for_output = tf.concat(current_action_tokens, 1)
|
||||
# one_state_action_tokens is [b, 1, self._tokens_per_action, 1, 1]
|
||||
one_state_action_tokens = predicted_tokens_for_output[:, tf.newaxis, :,
|
||||
tf.newaxis,
|
||||
tf.newaxis]
|
||||
|
||||
state_action_tokens = network_state['action_tokens']
|
||||
network_state['action_tokens'] = tf.concat([
|
||||
state_action_tokens[:, :action_t, ...], one_state_action_tokens,
|
||||
state_action_tokens[:, action_t + 1:, ...]
|
||||
],
|
||||
axis=1)
|
||||
# Increment the time_step for the next inference call.
|
||||
network_state['seq_idx'] = tf.reshape(
|
||||
tf.minimum(seq_idx + 1, self._time_sequence_length), [-1, 1, 1, 1, 1])
|
||||
|
||||
self._loss = tf.constant(0.0)
|
||||
else:
|
||||
# training call --> simply run one transformer forward pass
|
||||
output_tokens = self._transformer_call(
|
||||
context_image_tokens,
|
||||
action_tokens,
|
||||
attention_mask=attention_mask,
|
||||
batch_size=b,
|
||||
training=training)
|
||||
|
||||
# Gather all predicted actions for the action loss.
|
||||
action_logits = tf.gather(
|
||||
output_tokens, self._action_tokens_mask - 1, axis=1)
|
||||
action_logits_for_training = tf.reshape(
|
||||
action_logits, [b, t, self._tokens_per_action, -1])
|
||||
|
||||
# Only take the last action as the action.
|
||||
# action_logits_for_output is [b, self._tokens_per_action, emb]
|
||||
action_logits_for_output = action_logits_for_training[:, -1]
|
||||
|
||||
# predicted_tokens_for_output is [b, self._tokens_per_action]
|
||||
predicted_tokens_for_output = tf.argmax(
|
||||
action_logits_for_output, axis=-1, output_type=tf.int32)
|
||||
|
||||
num_items = (
|
||||
tf.cast(b * t, tf.float32) * self._single_time_step_num_tokens)
|
||||
action_loss = tf.reduce_mean(
|
||||
self._loss_object(action_tokens, action_logits_for_training) /
|
||||
num_items,
|
||||
axis=-1)
|
||||
|
||||
self._loss = action_loss
|
||||
|
||||
# store action labels and predictions for visualization
|
||||
self._aux_info.update({
|
||||
'action_predictions':
|
||||
tf.argmax(
|
||||
action_logits_for_training, axis=-1, output_type=tf.int32),
|
||||
'action_loss':
|
||||
action_loss,
|
||||
'actor_loss_mask':
|
||||
tf.ones([b], dtype=tf.float32)
|
||||
})
|
||||
|
||||
output_actions = self._action_tokenizer.detokenize(
|
||||
predicted_tokens_for_output)
|
||||
return output_actions, network_state
|
||||
|
||||
def add_summaries(self, observations: dict[str, tf.Tensor],
|
||||
logging_info: dict[str, tf.Tensor], debug_summaries: bool,
|
||||
training: bool) -> None:
|
||||
"""Adds summaries.
|
||||
|
||||
Args:
|
||||
observations: Observation data including image and natural language
|
||||
instruction in dict of Tensors.
|
||||
logging_info: Dict with all data stored for logging during training pass.
|
||||
debug_summaries: Whether to include debug summaries.
|
||||
training: Whether this function is called during training or inference.
|
||||
"""
|
||||
num_params = 0
|
||||
for weight in self.trainable_weights:
|
||||
weight_params = 1
|
||||
for dim in weight.shape:
|
||||
weight_params *= dim
|
||||
num_params += weight_params
|
||||
tf.compat.v2.summary.scalar(name='num_params', data=num_params)
|
||||
# debug_summaries are for the non-tpu worker, train_summary.
|
||||
if debug_summaries:
|
||||
image = observations['image'] # [b, t, h, w, c]
|
||||
image_h = image.shape[2]
|
||||
image_w = image.shape[3]
|
||||
batch_size = image.shape[0]
|
||||
num_ts = image.shape[1]
|
||||
logging.info('image shape %s', image.shape)
|
||||
# Concat images for different timesteps across width.
|
||||
image = tf.concat(tf.unstack(image, axis=1), 2)
|
||||
# Concat images for different batches (up to 8) across height.
|
||||
image = tf.expand_dims(tf.concat(tf.unstack(image, axis=0)[0:8], 0), 0)
|
||||
tf.summary.image(
|
||||
'observations/image',
|
||||
image,
|
||||
step=self._train_step_counter,
|
||||
# Single output since we have concatenated images along batch.
|
||||
max_outputs=1)
|
||||
|
||||
# [b, t], strings
|
||||
if 'natural_language_instruction' in observations:
|
||||
task = observations['natural_language_instruction'][:, 0]
|
||||
tf.summary.text(
|
||||
'natural_language_instruction', task, step=self._train_step_counter)
|
||||
if self.attention_scores and not self._use_token_learner:
|
||||
for l_idx, layer_attention_score in enumerate(self.attention_scores):
|
||||
logging.info('Attention score shape: %s, %s', l_idx,
|
||||
layer_attention_score.shape)
|
||||
for head_idx in range(layer_attention_score.shape[1]):
|
||||
pairwise_attention = tf.expand_dims(
|
||||
layer_attention_score[:, head_idx], -1)
|
||||
# pairwise attention shape (16, 552, 552, 1)
|
||||
# make attention from different time steps comparable
|
||||
pairwise_attention = pairwise_attention * np.arange(
|
||||
1, pairwise_attention.shape[1] + 1)[None, :, None, None]
|
||||
|
||||
# visualize spatial attention, note this only supports
|
||||
# mk1_500tasks_transformer pipeline with no token learner
|
||||
img_tf_ts = tf.reshape(
|
||||
tf.transpose(
|
||||
tf.reshape(
|
||||
tf.reduce_sum(pairwise_attention, axis=1) / np.arange(
|
||||
pairwise_attention.shape[1], 0, -1)[None, :, None],
|
||||
[batch_size, num_ts, -1]),
|
||||
[0, 2, 1])[:, :-self._tokens_per_action, :],
|
||||
[-1, 9, 9, num_ts])
|
||||
|
||||
img_tf_ts = tf.image.resize(
|
||||
img_tf_ts, [image_h, image_w],
|
||||
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
|
||||
img_tf_ts_concat = tf.concat(tf.unstack(img_tf_ts, axis=3), 2)
|
||||
img_tf_ts_concat_min = tf.reduce_min(
|
||||
img_tf_ts_concat, axis=[1, 2], keepdims=True)
|
||||
img_tf_ts_concat = (img_tf_ts_concat - img_tf_ts_concat_min) / (
|
||||
tf.reduce_max(img_tf_ts_concat, axis=[1, 2], keepdims=True) -
|
||||
img_tf_ts_concat_min)
|
||||
img_tf_ts_concat = tf.concat(
|
||||
tf.unstack(img_tf_ts_concat, axis=0)[:8], 0)
|
||||
img_tf_ts_concat = tf.expand_dims(
|
||||
tf.expand_dims(img_tf_ts_concat, 0), -1)
|
||||
tf.summary.image(
|
||||
'attention/layer_{}/head_{}'.format(l_idx, head_idx),
|
||||
img_tf_ts_concat,
|
||||
step=self._train_step_counter,
|
||||
# Single output since we have concatenated images along batch.
|
||||
max_outputs=1)
|
||||
|
||||
if img_tf_ts_concat.shape[1] == image.shape[
|
||||
1] and img_tf_ts_concat.shape[2] == image.shape[2]:
|
||||
# can overlay
|
||||
overlay_viz = tf.cast(
|
||||
(tf.cast(image, tf.float32) * (0.2 + img_tf_ts_concat) / 1.2),
|
||||
tf.uint8)
|
||||
tf.summary.image(
|
||||
'overlay_attention/layer_{}/head_{}'.format(l_idx, head_idx),
|
||||
overlay_viz,
|
||||
step=self._train_step_counter,
|
||||
# Single output since we have concatenated images along batch.
|
||||
max_outputs=1)
|
||||
|
||||
# log action info
|
||||
action_labels = tf.boolean_mask(logging_info['action_labels'],
|
||||
logging_info['actor_loss_mask'])
|
||||
action_predictions = tf.boolean_mask(logging_info['action_predictions'],
|
||||
logging_info['actor_loss_mask'])
|
||||
with tf.name_scope('ActionTokens'):
|
||||
token_accuracy = (
|
||||
tf.cast(tf.equal(action_labels, action_predictions), tf.float32))
|
||||
accuracy = tf.reduce_mean(token_accuracy)
|
||||
tf.compat.v2.summary.scalar(
|
||||
name='accuracy', data=accuracy, step=self._train_step_counter)
|
||||
# Accuracy across timesteps
|
||||
for t in range(self._time_sequence_length):
|
||||
tf.compat.v2.summary.scalar(
|
||||
name='accuracy/time_step/{}'.format(t),
|
||||
data=tf.reduce_mean(token_accuracy[:, t, :]),
|
||||
step=self._train_step_counter)
|
||||
token_index = 0
|
||||
for k in self._action_tokenizer.action_order:
|
||||
spec = self._action_tokenizer.action_spec[k]
|
||||
if spec.dtype == tf.int32:
|
||||
n_tokens = 1
|
||||
else:
|
||||
n_tokens = spec.shape[0]
|
||||
action_token_accuracy = tf.reduce_mean(
|
||||
token_accuracy[:, :, token_index:token_index + n_tokens])
|
||||
tf.compat.v2.summary.scalar(
|
||||
name='accuracy/action_type/{}'.format(k),
|
||||
data=action_token_accuracy,
|
||||
step=self._train_step_counter)
|
||||
for n in range(n_tokens):
|
||||
tf.summary.histogram(
|
||||
'tokens/{}_{}/labels'.format(k, n + 1),
|
||||
action_labels[:, :, token_index],
|
||||
step=self._train_step_counter)
|
||||
tf.summary.histogram(
|
||||
'tokens/{}_{}/predictions'.format(k, n + 1),
|
||||
action_predictions[:, :, token_index],
|
||||
step=self._train_step_counter)
|
||||
token_index += 1
|
||||
|
||||
# log loss components
|
||||
with tf.name_scope('TokenLosses'):
|
||||
tf.compat.v2.summary.scalar(
|
||||
name='action_loss',
|
||||
data=tf.reduce_mean(logging_info['action_loss']),
|
||||
step=self._train_step_counter)
|
||||
|
||||
def _tokenize_images(self, observations, network_state, training):
|
||||
image = observations['image'] # [b, t, h, w, c]
|
||||
outer_rank = self._get_outer_rank(observations)
|
||||
if outer_rank == 1: # This is an inference call
|
||||
seq_idx = tf.reshape(network_state['seq_idx'], [1])[0]
|
||||
time_step = tf.minimum(seq_idx, self._time_sequence_length - 1)
|
||||
image = tf.expand_dims(image, 1)
|
||||
|
||||
# TODO(b/255731285)
|
||||
image_shape = tf.shape(image)
|
||||
b = image_shape[0]
|
||||
input_t = image_shape[1]
|
||||
h = image_shape[2]
|
||||
w = image_shape[3]
|
||||
c = image_shape[4]
|
||||
|
||||
context = self._extract_context_from_observation(observations, input_t)
|
||||
|
||||
image = tf.reshape(image, [b * input_t, h, w, c])
|
||||
seed = tf.random.uniform(shape=(2,), maxval=2**30, dtype=tf.int32)
|
||||
image = preprocessors.convert_dtype_and_crop_images(
|
||||
image,
|
||||
crop_size=self._crop_size,
|
||||
training=training,
|
||||
pad_then_crop=True,
|
||||
convert_dtype=True,
|
||||
seed=seed)
|
||||
image = tf.reshape(image, [b, input_t, h, w, c])
|
||||
context_image_tokens = self._image_tokenizer(
|
||||
image, context=context, training=training)
|
||||
num_tokens = tf.shape(context_image_tokens)[2]
|
||||
context_image_tokens = tf.reshape(context_image_tokens,
|
||||
[b, input_t, num_tokens, 1, -1])
|
||||
if outer_rank == 1: # This is an inference call
|
||||
network_state['context_image_tokens'] = tf.reshape(
|
||||
network_state['context_image_tokens'], [
|
||||
b, self._time_sequence_length, self._tokens_per_context_image, 1,
|
||||
-1
|
||||
])
|
||||
state_image_tokens = network_state['context_image_tokens']
|
||||
# network_state as input for this call is the output from the last call.
|
||||
# Therefore, we need to shift all images to the left by 1 in the time axis
|
||||
# to align w/ the time dim in this call.
|
||||
state_image_tokens = tf.cond(
|
||||
seq_idx == self._time_sequence_length,
|
||||
lambda: tf.roll(state_image_tokens, -1, axis=1),
|
||||
lambda: state_image_tokens)
|
||||
|
||||
context_image_tokens = tf.concat([
|
||||
state_image_tokens[:, :time_step, ...], context_image_tokens,
|
||||
state_image_tokens[:, time_step + 1:, ...]
|
||||
],
|
||||
axis=1)
|
||||
network_state['context_image_tokens'] = context_image_tokens
|
||||
|
||||
return context_image_tokens, network_state
|
||||
|
||||
def _tokenize_actions(self, observations, network_state):
|
||||
outer_rank = self._get_outer_rank(observations)
|
||||
if outer_rank == 1: # This is an inference call
|
||||
# TODO(b/231896343): Clarify what is going on with the network state
|
||||
# tensors, currently they all have to be the same n_dims so we have to
|
||||
# add/remove dummy dims.
|
||||
action_tokens = tf.squeeze(network_state['action_tokens'], [3, 4])
|
||||
seq_idx = tf.reshape(network_state['seq_idx'], [1])[0]
|
||||
# network_state as input for this call is the output from the last call.
|
||||
# Therefore, we need to shift all actions by 1 to the left.
|
||||
action_tokens = tf.cond(seq_idx == self._time_sequence_length,
|
||||
lambda: tf.roll(action_tokens, -1, axis=1),
|
||||
lambda: action_tokens)
|
||||
else:
|
||||
assert outer_rank == 2
|
||||
if self._actions is None:
|
||||
b, t = self._get_batch_size_and_seq_len(network_state)
|
||||
action_tokens = tf.zeros(
|
||||
shape=[b, t, self._tokens_per_action], dtype=tf.int32)
|
||||
else:
|
||||
action_tokens = self._action_tokenizer.tokenize(self._actions)
|
||||
return action_tokens
|
||||
|
||||
def _assemble_input_token_sequence(self, context_image_tokens, action_tokens,
|
||||
batch_size):
|
||||
# embed action tokens
|
||||
action_tokens = tf.one_hot(action_tokens, self._vocab_size)
|
||||
action_tokens = self._action_token_emb(action_tokens)
|
||||
action_tokens = tf.zeros_like(action_tokens) # b/260260205
|
||||
|
||||
# Because of b/254902773, we need to add 1 extra dimension.
|
||||
action_tokens = tf.expand_dims(action_tokens, axis=-2)
|
||||
|
||||
# assemble token sequence
|
||||
input_token_sequence = tf.concat([context_image_tokens, action_tokens],
|
||||
axis=2)
|
||||
|
||||
input_token_sequence = tf.reshape(
|
||||
input_token_sequence, [batch_size, -1, self._token_embedding_size])
|
||||
return input_token_sequence

  def _extract_context_from_observation(self, observations, seq_len):
    """Extracts the language context from the observation, if present."""
    context = None
    if 'natural_language_embedding' in observations:
      outer_rank = self._get_outer_rank(observations)
      context = observations['natural_language_embedding']  # [b, t, emb-size]
      if outer_rank == 1:
        context = tf.tile(context[:, None], [1, seq_len, 1])
    return context
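
  # During inference the embedding arrives without a time axis, so the
  # tf.tile above repeats it across every step. With assumed sizes:
  #   context = tf.ones([1, 512])           # [b, emb]
  #   tf.tile(context[:, None], [1, 6, 1])  # -> [1, 6, 512] == [b, t, emb]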

  def set_actions(self, actions: tensorspec_utils.TensorSpecStruct):
    """Sets actions that will be tokenized and used in the transformer network.

    Args:
      actions: actions to be tokenized and used in the transformer network.
        Example actions: terminate = [0, 1], world_vector = [0.9, 0.8, -0.3],
        rotation_delta = [-0.1, 0.2, 0.6], gripper_closedness = 0.9.
    """
    self._actions = actions

  def _get_outer_rank(self, observations):
    # Used to determine whether this is a training or an inference call:
    # outer_rank will be 2 -> [b, t] during training and
    # outer_rank will be 1 -> [b] during inference.
    return nest_utils.get_outer_rank(observations, self._input_tensor_spec)
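
  # For example (assumed shapes, given an image spec of [256, 320, 3]):
  #   a training batch  [2, 3, 256, 320, 3] -> outer rank 2 ([b, t])
  #   an inference batch   [1, 256, 320, 3] -> outer rank 1 ([b])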

  def _get_batch_size_and_seq_len(self, network_state):
    image_shape = tf.shape(network_state['context_image_tokens'])
    b = image_shape[0]
    t = image_shape[1]
    return b, t

  def get_actor_loss(self) -> tf.Tensor:
    return self._loss

  def get_aux_info(self) -> dict[str, Any]:
    return self._aux_info

@@ -0,0 +1,229 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for networks."""

from absl.testing import parameterized

from robotics_transformer import transformer_network
from robotics_transformer.transformer_network_test_set_up import BATCH_SIZE
from robotics_transformer.transformer_network_test_set_up import NAME_TO_INF_OBSERVATIONS
from robotics_transformer.transformer_network_test_set_up import NAME_TO_STATE_SPECS
from robotics_transformer.transformer_network_test_set_up import observations_list
from robotics_transformer.transformer_network_test_set_up import spec_names_list
from robotics_transformer.transformer_network_test_set_up import state_spec_list
from robotics_transformer.transformer_network_test_set_up import TIME_SEQUENCE_LENGTH
from robotics_transformer.transformer_network_test_set_up import TransformerNetworkTestUtils

import tensorflow as tf
from tf_agents.specs import tensor_spec


class TransformerNetworkTest(TransformerNetworkTestUtils):

  # pylint:disable=g-complex-comprehension
  @parameterized.named_parameters([{
      'testcase_name': '_' + name,
      'state_spec': spec,
      'train_observation': obs,
  } for (name, spec, obs) in zip(spec_names_list(), state_spec_list(),
                                 observations_list())])
  # pylint:enable=g-complex-comprehension
  def testTransformerTrainLossCall(self, state_spec, train_observation):
    network = transformer_network.TransformerNetwork(
        input_tensor_spec=state_spec,
        output_tensor_spec=self._action_spec,
        time_sequence_length=TIME_SEQUENCE_LENGTH)

    network.create_variables()
    self.assertNotEmpty(network.variables)

    network.set_actions(self._train_action)
    network_state = tensor_spec.sample_spec_nest(
        network.state_spec, outer_dims=[BATCH_SIZE])
    output_actions, network_state = network(
        train_observation, step_type=None, network_state=network_state)
    expected_shape = [2, 3]
    self.assertEqual(network.get_actor_loss().shape,
                     tf.TensorShape(expected_shape))
    self.assertCountEqual(self._train_action.keys(), output_actions.keys())

  # pylint:disable=g-complex-comprehension
  @parameterized.named_parameters([{
      'testcase_name': '_' + name,
      'spec_name': name,
  } for name in spec_names_list()])
  # pylint:enable=g-complex-comprehension
  def testTransformerInferenceLossCall(self, spec_name):
    state_spec = NAME_TO_STATE_SPECS[spec_name]
    observation = NAME_TO_INF_OBSERVATIONS[spec_name]

    network = transformer_network.TransformerNetwork(
        input_tensor_spec=state_spec,
        output_tensor_spec=self._action_spec,
        time_sequence_length=TIME_SEQUENCE_LENGTH,
        action_order=[
            'terminate_episode', 'world_vector', 'rotation_delta',
            'gripper_closedness_action'
        ])
    network.create_variables()
    self.assertNotEmpty(network.variables)

    network.set_actions(self._inference_action)
    # Inference currently only supports a batch size of 1.
    network_state = tensor_spec.sample_spec_nest(
        network.state_spec, outer_dims=[1])

    output_actions, network_state = network(
        observation, step_type=None, network_state=network_state)

    tf.debugging.assert_equal(network.get_actor_loss(), 0.0)
    self.assertCountEqual(self._inference_action.keys(), output_actions.keys())

  # pylint:disable=g-complex-comprehension
  @parameterized.named_parameters([{
      'testcase_name': '_' + name,
      'state_spec': spec,
      'train_observation': obs,
  } for name, spec, obs in zip(spec_names_list(), state_spec_list(),
                               observations_list())])
  # pylint:enable=g-complex-comprehension
  def testTransformerLogging(self, state_spec, train_observation):
    network = transformer_network.TransformerNetwork(
        input_tensor_spec=state_spec,
        output_tensor_spec=self._action_spec,
        time_sequence_length=TIME_SEQUENCE_LENGTH,
        action_order=[
            'terminate_episode', 'world_vector', 'rotation_delta',
            'gripper_closedness_action'
        ])

    network.create_variables()
    self.assertNotEmpty(network.variables)

    network.set_actions(self._train_action)
    network_state = tensor_spec.sample_spec_nest(
        network.state_spec, outer_dims=[BATCH_SIZE])
    _ = network(train_observation, step_type=None, network_state=network_state)
    network.add_summaries(
        train_observation,
        network.get_aux_info(),
        debug_summaries=True,
        training=True)

  # pylint:disable=g-complex-comprehension
  @parameterized.named_parameters([{
      'testcase_name': '_' + name,
      'state_spec': spec,
  } for name, spec in zip(spec_names_list(), state_spec_list())])
  # pylint:enable=g-complex-comprehension
  def testTransformerCausality(self, state_spec):
    """Tests the causality of the transformer.

    Args:
      state_spec: which state spec to test the transformer with.
    """
    network = transformer_network.TransformerNetwork(
        input_tensor_spec=state_spec,
        output_tensor_spec=self._action_spec,
        time_sequence_length=TIME_SEQUENCE_LENGTH)
    network.create_variables()
    self.assertNotEmpty(network.variables)

    time_sequence_length = network._time_sequence_length
    tokens_per_image = network._tokens_per_context_image
    tokens_per_action = network._tokens_per_action

    def _split_image_and_action_tokens(all_tokens):
      image_start_indices = [(tokens_per_image + tokens_per_action) * k
                             for k in range(time_sequence_length)]
      image_tokens = tf.stack(
          [all_tokens[i:i + tokens_per_image] for i in image_start_indices],
          axis=0)
      action_start_indices = [
          i + tokens_per_image for i in image_start_indices
      ]
      action_tokens = [
          tf.stack([
              all_tokens[i:i + tokens_per_action] for i in action_start_indices
          ], 0)
      ]
      image_tokens = tf.one_hot(image_tokens, network._token_embedding_size)
      # Remove the extra dimension before the end once b/254902773 is fixed.
      shape = image_tokens.shape
      # Add a batch dimension.
      image_tokens = tf.reshape(image_tokens,
                                [1] + shape[:-1] + [1] + shape[-1:])
      return image_tokens, action_tokens

    # Generate some random tokens for images and actions.
    all_tokens = tf.random.uniform(
        shape=[time_sequence_length * (tokens_per_image + tokens_per_action)],
        dtype=tf.int32,
        maxval=10,
        minval=0)
    context_image_tokens, action_tokens = _split_image_and_action_tokens(
        all_tokens)
    # Get the output tokens without any zeroed-out input tokens.
    output_tokens = network._transformer_call(
        context_image_tokens=context_image_tokens,
        action_tokens=action_tokens,
        attention_mask=network._default_attention_mask,
        batch_size=1,
        training=False)[0]

    for t in range(time_sequence_length *
                   (tokens_per_image + tokens_per_action)):
      # Zero out future input tokens.
      all_tokens_at_t = tf.concat(
          [all_tokens[:t + 1],
           tf.zeros_like(all_tokens[t + 1:])], 0)
      context_image_tokens, action_tokens = _split_image_and_action_tokens(
          all_tokens_at_t)
      # Get the output tokens with the input tokens after t zeroed out.
      output_tokens_at_t = network._transformer_call(
          context_image_tokens=context_image_tokens,
          action_tokens=action_tokens,
          attention_mask=network._default_attention_mask,
          batch_size=1,
          training=False)[0]
      # The output tokens up to t must be unchanged when future input tokens
      # are zeroed out.
      self.assertAllEqual(output_tokens[:t + 1], output_tokens_at_t[:t + 1])
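
  # The loop above relies on the defining property of causal attention:
  # outputs at positions <= t depend only on inputs at positions <= t, so
  # zeroing the suffix of the input must leave that prefix of the output
  # bit-for-bit identical, i.e. (sketch, not the network's code)
  #   f([x0, x1, x2])[:2] == f([x0, x1, 0])[:2]  for a causal model f.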

  def testLossMasks(self):
    self._define_specs()
    self._create_agent()
    image_tokens = 3
    action_tokens = 2
    self._agent._actor_network._time_sequence_length = 2
    self._agent._actor_network._tokens_per_context_image = image_tokens
    self._agent._actor_network._tokens_per_action = action_tokens
    self._agent._actor_network._generate_masks()
    self.assertAllEqual(
        self._agent._actor_network._action_tokens_mask,
        tf.constant([
            image_tokens, image_tokens + 1, 2 * image_tokens + action_tokens,
            2 * image_tokens + action_tokens + 1
        ], tf.int32))
    # Regenerating the masks should yield the same result.
    self._agent._actor_network._generate_masks()
    self.assertAllEqual(
        self._agent._actor_network._action_tokens_mask,
        tf.constant([
            image_tokens, image_tokens + 1, 2 * image_tokens + action_tokens,
            2 * image_tokens + action_tokens + 1
        ], tf.int32))
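
  # Worked example for the expected mask above: with 3 image tokens and
  # 2 action tokens, each timestep contributes 5 tokens laid out as
  # [i0 i1 i2 a0 a1]. Over 2 timesteps the action-token positions are
  # 3, 4 (step 0) and 8, 9 (step 1), matching
  # [image_tokens, image_tokens + 1,
  #  2 * image_tokens + action_tokens, 2 * image_tokens + action_tokens + 1].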


if __name__ == '__main__':
  # Useful to enable if running with ipdb.
  tf.config.run_functions_eagerly(True)
  tf.test.main()

@@ -0,0 +1,391 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared setup utilities for the transformer network tests."""

import copy
from typing import Optional, Tuple, Union

from absl.testing import parameterized
import numpy as np
from robotics_transformer import sequence_agent
from robotics_transformer import transformer_network
from tensor2robot.utils import tensorspec_utils
import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

BATCH_SIZE = 2
TIME_SEQUENCE_LENGTH = 3
HEIGHT = 256
WIDTH = 320
NUM_IMAGE_TOKENS = 2


def spec_names_list() -> list[str]:
  """Lists the different types of specs accepted by the transformer."""
  return ['default']


def state_spec_list() -> list[tensorspec_utils.TensorSpecStruct]:
  """Lists the different types of state specs accepted by the transformer."""
  state_spec = tensorspec_utils.TensorSpecStruct()
  state_spec.image = tensor_spec.BoundedTensorSpec([HEIGHT, WIDTH, 3],
                                                   dtype=tf.float32,
                                                   name='image',
                                                   minimum=0.,
                                                   maximum=1.)
  state_spec.natural_language_embedding = tensor_spec.TensorSpec(
      shape=[512], dtype=tf.float32, name='natural_language_embedding')

  state_spec_mask = copy.deepcopy(state_spec)
  state_spec_mask.initial_binary_mask = tensor_spec.BoundedTensorSpec(
      [HEIGHT, WIDTH, 1],
      dtype=tf.int32,
      name='initial_binary_mask',
      minimum=0,
      maximum=255)

  state_spec_tcl = copy.deepcopy(state_spec)
  state_spec_tcl.original_image = tensor_spec.BoundedTensorSpec(
      [HEIGHT, WIDTH, 3],
      dtype=tf.float32,
      name='original_image',
      minimum=0.,
      maximum=1.)

  return [
      state_spec,
      state_spec_mask,
      state_spec_tcl,
  ]


def observations_list(training: bool = True) -> list[dict[str, tf.Tensor]]:
  """Lists the different types of observations accepted by the transformer."""
  if training:
    image_shape = [BATCH_SIZE, TIME_SEQUENCE_LENGTH, HEIGHT, WIDTH, 3]
    emb_shape = [BATCH_SIZE, TIME_SEQUENCE_LENGTH, 512]
    mask_shape = [BATCH_SIZE, TIME_SEQUENCE_LENGTH, HEIGHT, WIDTH, 1]
  else:
    # Inference currently only supports a batch size of 1.
    image_shape = [1, HEIGHT, WIDTH, 3]
    emb_shape = [1, 512]
    mask_shape = [1, HEIGHT, WIDTH, 1]
  return [
      {
          'image': tf.constant(0.5, shape=image_shape),
          'natural_language_embedding': tf.constant(1., shape=emb_shape),
      },
      {
          'image': tf.constant(0.5, shape=image_shape),
          'natural_language_embedding': tf.constant(1., shape=emb_shape),
          'initial_binary_mask': tf.constant(192, shape=mask_shape),
      },
      {  # This is used for TCL.
          'image': tf.constant(0.5, shape=image_shape),
          'original_image': tf.constant(0.4, shape=image_shape),
          'natural_language_embedding': tf.constant(1., shape=emb_shape),
      },
  ]


NAME_TO_STATE_SPECS = dict(zip(spec_names_list(), state_spec_list()))
NAME_TO_OBSERVATIONS = dict(zip(spec_names_list(), observations_list()))
NAME_TO_INF_OBSERVATIONS = dict(
    zip(spec_names_list(), observations_list(False)))


class FakeImageTokenizer(tf.keras.layers.Layer):
  """Fake image tokenizer for testing the transformer."""

  def __init__(self,
               encoder: ...,
               position_embedding: ...,
               embedding_output_dim: int,
               patch_size: int,
               use_token_learner: bool = False,
               num_tokens: int = NUM_IMAGE_TOKENS,
               use_initial_binary_mask: bool = False,
               **kwargs):
    del encoder, position_embedding, patch_size, use_token_learner
    super().__init__(**kwargs)
    self.tokens_per_context_image = num_tokens
    if use_initial_binary_mask:
      self.tokens_per_context_image += 1
    self.embedding_output_dim = embedding_output_dim
    self.use_initial_binary_mask = use_initial_binary_mask

  def __call__(self,
               image: tf.Tensor,
               context: Optional[tf.Tensor] = None,
               initial_binary_mask: Optional[tf.Tensor] = None,
               training: bool = False) -> tf.Tensor:
    if self.use_initial_binary_mask:
      assert initial_binary_mask is not None
    image_shape = tf.shape(image)
    seq_size = image_shape[1]
    batch_size = image_shape[0]
    all_tokens = []
    num_tokens = self.tokens_per_context_image
    for t in range(seq_size):
      tokens = tf.ones([batch_size, 1, num_tokens, self.embedding_output_dim
                       ]) * image[0][t][0][0]
      all_tokens.append(tokens)
    return tf.concat(all_tokens, axis=1)
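
# A sketch of what the fake tokenizer produces (assumed call, illustration
# only): every frame maps to `num_tokens` constant embeddings whose value is
# the frame's top-left pixel, so tests can trace which timestep a token came
# from:
#   tokenizer = FakeImageTokenizer(encoder=None, position_embedding=None,
#                                  embedding_output_dim=512, patch_size=16)
#   tokens = tokenizer(tf.ones([1, 3, 256, 320, 3]))  # -> [1, 3, 2, 512]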


class TransformerNetworkTestUtils(tf.test.TestCase, parameterized.TestCase):
  """Defines specs, SequenceAgent, and various other testing utilities."""

  def _define_specs(self,
                    train_batch_size=BATCH_SIZE,
                    inference_batch_size=1,
                    time_sequence_length=TIME_SEQUENCE_LENGTH,
                    inference_sequence_length=TIME_SEQUENCE_LENGTH,
                    token_embedding_size=512,
                    image_width=WIDTH,
                    image_height=HEIGHT):
    """Defines specs and observations (both training and inference)."""
    self.train_batch_size = train_batch_size
    self.inference_batch_size = inference_batch_size
    self.time_sequence_length = time_sequence_length
    self.inference_sequence_length = inference_sequence_length
    self.token_embedding_size = token_embedding_size
    action_spec = tensorspec_utils.TensorSpecStruct()
    action_spec.world_vector = tensor_spec.BoundedTensorSpec(
        (3,), dtype=tf.float32, minimum=-1., maximum=1., name='world_vector')

    action_spec.rotation_delta = tensor_spec.BoundedTensorSpec(
        (3,),
        dtype=tf.float32,
        minimum=-np.pi / 2,
        maximum=np.pi / 2,
        name='rotation_delta')

    action_spec.gripper_closedness_action = tensor_spec.BoundedTensorSpec(
        (1,),
        dtype=tf.float32,
        minimum=-1.,
        maximum=1.,
        name='gripper_closedness_action')
    action_spec.terminate_episode = tensor_spec.BoundedTensorSpec(
        (2,), dtype=tf.int32, minimum=0, maximum=1, name='terminate_episode')

    state_spec = tensorspec_utils.TensorSpecStruct()
    state_spec.image = tensor_spec.BoundedTensorSpec(
        [image_height, image_width, 3],
        dtype=tf.float32,
        name='image',
        minimum=0.,
        maximum=1.)
    state_spec.natural_language_embedding = tensor_spec.TensorSpec(
        shape=[self.token_embedding_size],
        dtype=tf.float32,
        name='natural_language_embedding')
    self._policy_info_spec = {
        'return':
            tensor_spec.BoundedTensorSpec((),
                                          dtype=tf.float32,
                                          minimum=0.0,
                                          maximum=1.0,
                                          name='return'),
        'discounted_return':
            tensor_spec.BoundedTensorSpec((),
                                          dtype=tf.float32,
                                          minimum=0.0,
                                          maximum=1.0,
                                          name='discounted_return'),
    }

    self._state_spec = state_spec
    self._action_spec = action_spec

    self._inference_observation = {
        'image':
            tf.constant(
                1,
                shape=[self.inference_batch_size, image_height, image_width, 3],
                dtype=tf.dtypes.float32),
        'natural_language_embedding':
            tf.constant(
                1.,
                shape=[self.inference_batch_size, self.token_embedding_size],
                dtype=tf.dtypes.float32),
    }
    self._train_observation = {
        'image':
            tf.constant(
                0.5,
                shape=[
                    self.train_batch_size, self.time_sequence_length,
                    image_height, image_width, 3
                ]),
        'natural_language_embedding':
            tf.constant(
                1.,
                shape=[
                    self.train_batch_size, self.time_sequence_length,
                    self.token_embedding_size
                ]),
    }
    self._inference_action = {
        'world_vector':
            tf.constant(0.5, shape=[self.inference_batch_size, 3]),
        'rotation_delta':
            tf.constant(0.5, shape=[self.inference_batch_size, 3]),
        'terminate_episode':
            tf.constant(
                [0, 1] * self.inference_batch_size,
                shape=[self.inference_batch_size, 2]),
        'gripper_closedness_action':
            tf.constant(0.5, shape=[self.inference_batch_size, 1]),
    }
    self._train_action = {
        'world_vector':
            tf.constant(
                0.5,
                shape=[self.train_batch_size, self.time_sequence_length, 3]),
        'rotation_delta':
            tf.constant(
                0.5,
                shape=[self.train_batch_size, self.time_sequence_length, 3]),
        'terminate_episode':
            tf.constant(
                [0, 1] * self.train_batch_size * self.time_sequence_length,
                shape=[self.train_batch_size, self.time_sequence_length, 2]),
        'gripper_closedness_action':
            tf.constant(
                0.5,
                shape=[self.train_batch_size, self.time_sequence_length, 1]),
    }

  def _create_agent(self, actor_network=None):
    """Creates SequenceAgent using a custom actor_network."""
    time_step_spec = ts.time_step_spec(observation_spec=self._state_spec)
    if actor_network is None:
      actor_network = transformer_network.TransformerNetwork

    self._agent = sequence_agent.SequenceAgent(
        time_step_spec=time_step_spec,
        action_spec=self._action_spec,
        actor_network=actor_network,
        actor_optimizer=tf.keras.optimizers.Adam(),
        train_step_counter=tf.compat.v1.train.get_or_create_global_step(),
        time_sequence_length=TIME_SEQUENCE_LENGTH)
    self._num_action_tokens = (
        # pylint:disable=protected-access
        self._agent._actor_network._action_tokenizer._tokens_per_action)
    # pylint:enable=protected-access

  def setUp(self):
    self._define_specs()
    super().setUp()

  def get_image_value(self, step_idx: int) -> float:
    return float(step_idx) / self.time_sequence_length

  def get_action_logits(self, batch_size: int, value: int,
                        vocab_size: int) -> tf.Tensor:
    return tf.broadcast_to(
        tf.one_hot(value % vocab_size, vocab_size)[tf.newaxis, tf.newaxis, :],
        [batch_size, 1, vocab_size])
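
  # For example (assumed numbers): value=7 and vocab_size=5 give
  # tf.one_hot(7 % 5, 5) == [0., 0., 1., 0., 0.], broadcast to
  # [batch_size, 1, 5], so a step's logits deterministically encode its
  # action value and can be re-checked inside fake_transformer below.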

  def create_obs(self, value) -> dict[str, tf.Tensor]:
    observations = {}
    observations['image'] = value * self._inference_observation['image']
    observations[
        'natural_language_embedding'] = value * self._inference_observation[
            'natural_language_embedding']
    return observations

  def fake_action_token_emb(self, action_tokens) -> tf.Tensor:
    """Truncates the one-hot tokens down to the embedding size."""
    shape = action_tokens.shape
    assert self.vocab_size > self.token_embedding_size
    assert len(shape) == 4
    return action_tokens[:, :, :, :self.token_embedding_size]

  def fake_transformer(
      self, all_tokens, training,
      attention_mask) -> Union[tf.Tensor, Tuple[tf.Tensor, list[tf.Tensor]]]:
    """Fakes the call to TransformerNetwork._transformer."""
    del training
    del attention_mask
    # We expect ST00 ST01 A00 A01...
    # Where:
    # * ST01 is token 1 of state 0.
    # * A01 is token 1 of action 0.
    shape = all_tokens.shape.as_list()
    batch_size = shape[0]
    self.assertEqual(batch_size, 1)
    emb_size = self.token_embedding_size

    # Transform to [batch_size, num_tokens, token_size].
    all_tokens = tf.reshape(all_tokens, [batch_size, -1, emb_size])
    # Pad tokens out to vocab_size.
    self.assertGreater(self.vocab_size, self.token_embedding_size)
    all_shape = all_tokens.shape
    self.assertLen(all_shape.as_list(), 3)
    output_tokens = tf.concat([
        all_tokens,
        tf.zeros([
            all_shape[0], all_shape[1],
            self.vocab_size - self.token_embedding_size
        ])
    ],
                              axis=-1)
    num_tokens_per_step = NUM_IMAGE_TOKENS + self._num_action_tokens
    # Check state/action alignment.
    window_range = min(self._step_idx + 1, self.time_sequence_length)
    for j in range(window_range):
      # The index of the step that is stored at position j = 0.
      first_step_idx = max(0, self._step_idx + 1 - self.time_sequence_length)
      image_idx = j * num_tokens_per_step
      action_start_index = image_idx + NUM_IMAGE_TOKENS
      for t in range(NUM_IMAGE_TOKENS):
        self.assertAllEqual(
            self.get_image_value(first_step_idx + j) *
            tf.ones_like(all_tokens[0][image_idx][:self.token_embedding_size]),
            all_tokens[0][image_idx + t][:self.token_embedding_size])
      # If j is not the current step in the window, all action dimensions
      # from previous steps have already been inferred and can be checked.
      action_dims_range = (
          self.action_inf_idx
          if j == window_range - 1 else self._num_action_tokens)
      for t in range(action_dims_range):
        token_idx = action_start_index + t
        action_value = (first_step_idx + j) * self._num_action_tokens + t
        self.assertAllEqual(
            self.get_action_logits(
                batch_size=batch_size,
                value=action_value,
                vocab_size=self.vocab_size)[0][0][:self.token_embedding_size],
            all_tokens[0][token_idx][:self.token_embedding_size])
    # Output the right action dimension value.
    image_token_index = (
        min(self._step_idx, self.time_sequence_length - 1) *
        num_tokens_per_step)
    transformer_shift = -1
    action_index = (
        image_token_index + NUM_IMAGE_TOKENS + self.action_inf_idx +
        transformer_shift)
    action_value = self._step_idx * self._num_action_tokens + self.action_inf_idx
    action_logits = self.get_action_logits(
        batch_size=batch_size, value=action_value, vocab_size=self.vocab_size)
    output_tokens = tf.concat([
        output_tokens[:, :action_index, :], action_logits[:, :, :],
        output_tokens[:, action_index + 1:, :]
    ],
                              axis=1)
    self.action_inf_idx = (self.action_inf_idx + 1) % self._num_action_tokens
    attention_scores = []
    return output_tokens, attention_scores
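
# Token layout assumed by fake_transformer above, with NUM_IMAGE_TOKENS = 2
# and (say) 2 action tokens per step, illustration only:
#   position: 0    1    2    3    4    5    6    7   ...
#   token:    ST00 ST01 A00  A01  ST10 ST11 A10  A11 ...
# Each output position predicts the following token, hence the
# transformer_shift of -1 when splicing the expected action logits into
# output_tokens.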

@@ -0,0 +1,55 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for transformer."""

from absl.testing import parameterized
from robotics_transformer import transformer
import tensorflow as tf


class TransformerTest(parameterized.TestCase):

  def setUp(self):
    self._vocab_size = 10
    batch_size = 8
    sequence_len = 12
    self._tokens = tf.random.uniform(
        [batch_size, sequence_len, self._vocab_size],
        minval=0,
        maxval=1,
        dtype=tf.dtypes.float32,
    )
    super(TransformerTest, self).setUp()

  @parameterized.parameters(True, False)
  def test_transformer_forwardpass(self, return_attention_scores):
    network = transformer.Transformer(
        num_layers=2,
        layer_size=512,
        num_heads=4,
        feed_forward_size=256,
        dropout_rate=0.1,
        vocab_size=self._vocab_size,
        return_attention_scores=return_attention_scores)

    output_tokens, attention_scores = network(
        self._tokens, attention_mask=None)
    self.assertSequenceEqual(self._tokens.shape.as_list(),
                             output_tokens.shape.as_list())
    if return_attention_scores:
      self.assertNotEmpty(attention_scores)
    else:
      self.assertEmpty(attention_scores)


if __name__ == '__main__':
  tf.test.main()