forked from jiuyuan/InfiniTensor
commit
eda41b06a7
|
@ -0,0 +1,90 @@
|
||||||
|
---
|
||||||
|
Language: Cpp
|
||||||
|
# BasedOnStyle: LLVM
|
||||||
|
AccessModifierOffset: -2
|
||||||
|
AlignAfterOpenBracket: Align
|
||||||
|
AlignConsecutiveAssignments: false
|
||||||
|
AlignConsecutiveDeclarations: false
|
||||||
|
AlignEscapedNewlinesLeft: false
|
||||||
|
AlignOperands: true
|
||||||
|
AlignTrailingComments: true
|
||||||
|
AllowAllParametersOfDeclarationOnNextLine: true
|
||||||
|
AllowShortBlocksOnASingleLine: false
|
||||||
|
AllowShortCaseLabelsOnASingleLine: false
|
||||||
|
AllowShortFunctionsOnASingleLine: All
|
||||||
|
AllowShortIfStatementsOnASingleLine: false
|
||||||
|
AllowShortLoopsOnASingleLine: false
|
||||||
|
AlwaysBreakAfterDefinitionReturnType: None
|
||||||
|
AlwaysBreakAfterReturnType: None
|
||||||
|
AlwaysBreakBeforeMultilineStrings: false
|
||||||
|
AlwaysBreakTemplateDeclarations: false
|
||||||
|
BinPackArguments: true
|
||||||
|
BinPackParameters: true
|
||||||
|
BraceWrapping:
|
||||||
|
AfterClass: false
|
||||||
|
AfterControlStatement: false
|
||||||
|
AfterEnum: false
|
||||||
|
AfterFunction: false
|
||||||
|
AfterNamespace: false
|
||||||
|
AfterObjCDeclaration: false
|
||||||
|
AfterStruct: false
|
||||||
|
AfterUnion: false
|
||||||
|
BeforeCatch: false
|
||||||
|
BeforeElse: false
|
||||||
|
IndentBraces: false
|
||||||
|
BreakBeforeBinaryOperators: None
|
||||||
|
BreakBeforeBraces: Attach
|
||||||
|
BreakBeforeTernaryOperators: true
|
||||||
|
BreakConstructorInitializersBeforeComma: false
|
||||||
|
ColumnLimit: 80
|
||||||
|
CommentPragmas: '^ IWYU pragma:'
|
||||||
|
ConstructorInitializerAllOnOneLineOrOnePerLine: false
|
||||||
|
ConstructorInitializerIndentWidth: 4
|
||||||
|
ContinuationIndentWidth: 4
|
||||||
|
Cpp11BracedListStyle: true
|
||||||
|
DerivePointerAlignment: false
|
||||||
|
DisableFormat: false
|
||||||
|
ExperimentalAutoDetectBinPacking: false
|
||||||
|
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
|
||||||
|
IncludeCategories:
|
||||||
|
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
|
||||||
|
Priority: 2
|
||||||
|
- Regex: '^(<|"(gtest|isl|json)/)'
|
||||||
|
Priority: 3
|
||||||
|
- Regex: '.*'
|
||||||
|
Priority: 1
|
||||||
|
IndentCaseLabels: false
|
||||||
|
IndentWidth: 4
|
||||||
|
IndentWrappedFunctionNames: false
|
||||||
|
KeepEmptyLinesAtTheStartOfBlocks: true
|
||||||
|
MacroBlockBegin: ''
|
||||||
|
MacroBlockEnd: ''
|
||||||
|
MaxEmptyLinesToKeep: 1
|
||||||
|
NamespaceIndentation: None
|
||||||
|
ObjCBlockIndentWidth: 2
|
||||||
|
ObjCSpaceAfterProperty: false
|
||||||
|
ObjCSpaceBeforeProtocolList: true
|
||||||
|
PenaltyBreakBeforeFirstCallParameter: 19
|
||||||
|
PenaltyBreakComment: 300
|
||||||
|
PenaltyBreakFirstLessLess: 120
|
||||||
|
PenaltyBreakString: 1000
|
||||||
|
PenaltyExcessCharacter: 1000000
|
||||||
|
PenaltyReturnTypeOnItsOwnLine: 60
|
||||||
|
PointerAlignment: Right
|
||||||
|
ReflowComments: true
|
||||||
|
SortIncludes: true
|
||||||
|
SpaceAfterCStyleCast: false
|
||||||
|
SpaceBeforeAssignmentOperators: true
|
||||||
|
SpaceBeforeParens: ControlStatements
|
||||||
|
SpaceInEmptyParentheses: false
|
||||||
|
SpacesBeforeTrailingComments: 1
|
||||||
|
SpacesInAngles: false
|
||||||
|
SpacesInContainerLiterals: true
|
||||||
|
SpacesInCStyleCastParentheses: false
|
||||||
|
SpacesInParentheses: false
|
||||||
|
SpacesInSquareBrackets: false
|
||||||
|
Standard: Cpp11
|
||||||
|
TabWidth: 8
|
||||||
|
UseTab: Never
|
||||||
|
...
|
||||||
|
|
|
@ -0,0 +1,311 @@
|
||||||
|
{
|
||||||
|
"_help_parse": "Options affecting listfile parsing",
|
||||||
|
"parse": {
|
||||||
|
"_help_additional_commands": [
|
||||||
|
"Specify structure for custom cmake functions"
|
||||||
|
],
|
||||||
|
"additional_commands": {
|
||||||
|
"foo": {
|
||||||
|
"flags": [
|
||||||
|
"BAR",
|
||||||
|
"BAZ"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"HEADERS": "*",
|
||||||
|
"SOURCES": "*",
|
||||||
|
"DEPENDS": "*"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"_help_override_spec": [
|
||||||
|
"Override configurations per-command where available"
|
||||||
|
],
|
||||||
|
"override_spec": {},
|
||||||
|
"_help_vartags": [
|
||||||
|
"Specify variable tags."
|
||||||
|
],
|
||||||
|
"vartags": [],
|
||||||
|
"_help_proptags": [
|
||||||
|
"Specify property tags."
|
||||||
|
],
|
||||||
|
"proptags": []
|
||||||
|
},
|
||||||
|
"_help_format": "Options affecting formatting.",
|
||||||
|
"format": {
|
||||||
|
"_help_disable": [
|
||||||
|
"Disable formatting entirely, making cmake-format a no-op"
|
||||||
|
],
|
||||||
|
"disable": false,
|
||||||
|
"_help_line_width": [
|
||||||
|
"How wide to allow formatted cmake files"
|
||||||
|
],
|
||||||
|
"line_width": 1500,
|
||||||
|
"_help_tab_size": [
|
||||||
|
"How many spaces to tab for indent"
|
||||||
|
],
|
||||||
|
"tab_size": 2,
|
||||||
|
"_help_use_tabchars": [
|
||||||
|
"If true, lines are indented using tab characters (utf-8",
|
||||||
|
"0x09) instead of <tab_size> space characters (utf-8 0x20).",
|
||||||
|
"In cases where the layout would require a fractional tab",
|
||||||
|
"character, the behavior of the fractional indentation is",
|
||||||
|
"governed by <fractional_tab_policy>"
|
||||||
|
],
|
||||||
|
"use_tabchars": false,
|
||||||
|
"_help_fractional_tab_policy": [
|
||||||
|
"If <use_tabchars> is True, then the value of this variable",
|
||||||
|
"indicates how fractional indentions are handled during",
|
||||||
|
"whitespace replacement. If set to 'use-space', fractional",
|
||||||
|
"indentation is left as spaces (utf-8 0x20). If set to",
|
||||||
|
"`round-up` fractional indentation is replaced with a single",
|
||||||
|
"tab character (utf-8 0x09) effectively shifting the column",
|
||||||
|
"to the next tabstop"
|
||||||
|
],
|
||||||
|
"fractional_tab_policy": "use-space",
|
||||||
|
"_help_max_subgroups_hwrap": [
|
||||||
|
"If an argument group contains more than this many sub-groups",
|
||||||
|
"(parg or kwarg groups) then force it to a vertical layout."
|
||||||
|
],
|
||||||
|
"max_subgroups_hwrap": 2,
|
||||||
|
"_help_max_pargs_hwrap": [
|
||||||
|
"If a positional argument group contains more than this many",
|
||||||
|
"arguments, then force it to a vertical layout."
|
||||||
|
],
|
||||||
|
"max_pargs_hwrap": 6,
|
||||||
|
"_help_max_rows_cmdline": [
|
||||||
|
"If a cmdline positional group consumes more than this many",
|
||||||
|
"lines without nesting, then invalidate the layout (and nest)"
|
||||||
|
],
|
||||||
|
"max_rows_cmdline": 2,
|
||||||
|
"_help_separate_ctrl_name_with_space": [
|
||||||
|
"If true, separate flow control names from their parentheses",
|
||||||
|
"with a space"
|
||||||
|
],
|
||||||
|
"separate_ctrl_name_with_space": false,
|
||||||
|
"_help_separate_fn_name_with_space": [
|
||||||
|
"If true, separate function names from parentheses with a",
|
||||||
|
"space"
|
||||||
|
],
|
||||||
|
"separate_fn_name_with_space": false,
|
||||||
|
"_help_dangle_parens": [
|
||||||
|
"If a statement is wrapped to more than one line, than dangle",
|
||||||
|
"the closing parenthesis on its own line."
|
||||||
|
],
|
||||||
|
"dangle_parens": false,
|
||||||
|
"_help_dangle_align": [
|
||||||
|
"If the trailing parenthesis must be 'dangled' on its on",
|
||||||
|
"line, then align it to this reference: `prefix`: the start",
|
||||||
|
"of the statement, `prefix-indent`: the start of the",
|
||||||
|
"statement, plus one indentation level, `child`: align to",
|
||||||
|
"the column of the arguments"
|
||||||
|
],
|
||||||
|
"dangle_align": "prefix",
|
||||||
|
"_help_min_prefix_chars": [
|
||||||
|
"If the statement spelling length (including space and",
|
||||||
|
"parenthesis) is smaller than this amount, then force reject",
|
||||||
|
"nested layouts."
|
||||||
|
],
|
||||||
|
"min_prefix_chars": 4,
|
||||||
|
"_help_max_prefix_chars": [
|
||||||
|
"If the statement spelling length (including space and",
|
||||||
|
"parenthesis) is larger than the tab width by more than this",
|
||||||
|
"amount, then force reject un-nested layouts."
|
||||||
|
],
|
||||||
|
"max_prefix_chars": 10,
|
||||||
|
"_help_max_lines_hwrap": [
|
||||||
|
"If a candidate layout is wrapped horizontally but it exceeds",
|
||||||
|
"this many lines, then reject the layout."
|
||||||
|
],
|
||||||
|
"max_lines_hwrap": 2,
|
||||||
|
"_help_line_ending": [
|
||||||
|
"What style line endings to use in the output."
|
||||||
|
],
|
||||||
|
"line_ending": "unix",
|
||||||
|
"_help_command_case": [
|
||||||
|
"Format command names consistently as 'lower' or 'upper' case"
|
||||||
|
],
|
||||||
|
"command_case": "canonical",
|
||||||
|
"_help_keyword_case": [
|
||||||
|
"Format keywords consistently as 'lower' or 'upper' case"
|
||||||
|
],
|
||||||
|
"keyword_case": "unchanged",
|
||||||
|
"_help_always_wrap": [
|
||||||
|
"A list of command names which should always be wrapped"
|
||||||
|
],
|
||||||
|
"always_wrap": [],
|
||||||
|
"_help_enable_sort": [
|
||||||
|
"If true, the argument lists which are known to be sortable",
|
||||||
|
"will be sorted lexicographicall"
|
||||||
|
],
|
||||||
|
"enable_sort": true,
|
||||||
|
"_help_autosort": [
|
||||||
|
"If true, the parsers may infer whether or not an argument",
|
||||||
|
"list is sortable (without annotation)."
|
||||||
|
],
|
||||||
|
"autosort": false,
|
||||||
|
"_help_require_valid_layout": [
|
||||||
|
"By default, if cmake-format cannot successfully fit",
|
||||||
|
"everything into the desired linewidth it will apply the",
|
||||||
|
"last, most aggressive attempt that it made. If this flag is",
|
||||||
|
"True, however, cmake-format will print error, exit with non-",
|
||||||
|
"zero status code, and write-out nothing"
|
||||||
|
],
|
||||||
|
"require_valid_layout": false,
|
||||||
|
"_help_layout_passes": [
|
||||||
|
"A dictionary mapping layout nodes to a list of wrap",
|
||||||
|
"decisions. See the documentation for more information."
|
||||||
|
],
|
||||||
|
"layout_passes": {}
|
||||||
|
},
|
||||||
|
"_help_markup": "Options affecting comment reflow and formatting.",
|
||||||
|
"markup": {
|
||||||
|
"_help_bullet_char": [
|
||||||
|
"What character to use for bulleted lists"
|
||||||
|
],
|
||||||
|
"bullet_char": "*",
|
||||||
|
"_help_enum_char": [
|
||||||
|
"What character to use as punctuation after numerals in an",
|
||||||
|
"enumerated list"
|
||||||
|
],
|
||||||
|
"enum_char": ".",
|
||||||
|
"_help_first_comment_is_literal": [
|
||||||
|
"If comment markup is enabled, don't reflow the first comment",
|
||||||
|
"block in each listfile. Use this to preserve formatting of",
|
||||||
|
"your copyright/license statements."
|
||||||
|
],
|
||||||
|
"first_comment_is_literal": false,
|
||||||
|
"_help_literal_comment_pattern": [
|
||||||
|
"If comment markup is enabled, don't reflow any comment block",
|
||||||
|
"which matches this (regex) pattern. Default is `None`",
|
||||||
|
"(disabled)."
|
||||||
|
],
|
||||||
|
"literal_comment_pattern": ".*",
|
||||||
|
"_help_fence_pattern": [
|
||||||
|
"Regular expression to match preformat fences in comments",
|
||||||
|
"default= ``r'^\\s*([`~]{3}[`~]*)(.*)$'``"
|
||||||
|
],
|
||||||
|
"fence_pattern": "^\\s*([`~]{3}[`~]*)(.*)$",
|
||||||
|
"_help_ruler_pattern": [
|
||||||
|
"Regular expression to match rulers in comments default=",
|
||||||
|
"``r'^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$'``"
|
||||||
|
],
|
||||||
|
"ruler_pattern": "^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$",
|
||||||
|
"_help_explicit_trailing_pattern": [
|
||||||
|
"If a comment line matches starts with this pattern then it",
|
||||||
|
"is explicitly a trailing comment for the preceding argument.",
|
||||||
|
"Default is '#<'"
|
||||||
|
],
|
||||||
|
"explicit_trailing_pattern": "#<",
|
||||||
|
"_help_hashruler_min_length": [
|
||||||
|
"If a comment line starts with at least this many consecutive",
|
||||||
|
"hash characters, then don't lstrip() them off. This allows",
|
||||||
|
"for lazy hash rulers where the first hash char is not",
|
||||||
|
"separated by space"
|
||||||
|
],
|
||||||
|
"hashruler_min_length": 10,
|
||||||
|
"_help_canonicalize_hashrulers": [
|
||||||
|
"If true, then insert a space between the first hash char and",
|
||||||
|
"remaining hash chars in a hash ruler, and normalize its",
|
||||||
|
"length to fill the column"
|
||||||
|
],
|
||||||
|
"canonicalize_hashrulers": true,
|
||||||
|
"_help_enable_markup": [
|
||||||
|
"enable comment markup parsing and reflow"
|
||||||
|
],
|
||||||
|
"enable_markup": true
|
||||||
|
},
|
||||||
|
"_help_lint": "Options affecting the linter",
|
||||||
|
"lint": {
|
||||||
|
"_help_disabled_codes": [
|
||||||
|
"a list of lint codes to disable"
|
||||||
|
],
|
||||||
|
"disabled_codes": [],
|
||||||
|
"_help_function_pattern": [
|
||||||
|
"regular expression pattern describing valid function names"
|
||||||
|
],
|
||||||
|
"function_pattern": "[0-9a-z_]+",
|
||||||
|
"_help_macro_pattern": [
|
||||||
|
"regular expression pattern describing valid macro names"
|
||||||
|
],
|
||||||
|
"macro_pattern": "[0-9A-Z_]+",
|
||||||
|
"_help_global_var_pattern": [
|
||||||
|
"regular expression pattern describing valid names for",
|
||||||
|
"variables with global (cache) scope"
|
||||||
|
],
|
||||||
|
"global_var_pattern": "[A-Z][0-9A-Z_]+",
|
||||||
|
"_help_internal_var_pattern": [
|
||||||
|
"regular expression pattern describing valid names for",
|
||||||
|
"variables with global scope (but internal semantic)"
|
||||||
|
],
|
||||||
|
"internal_var_pattern": "_[A-Z][0-9A-Z_]+",
|
||||||
|
"_help_local_var_pattern": [
|
||||||
|
"regular expression pattern describing valid names for",
|
||||||
|
"variables with local scope"
|
||||||
|
],
|
||||||
|
"local_var_pattern": "[a-z][a-z0-9_]+",
|
||||||
|
"_help_private_var_pattern": [
|
||||||
|
"regular expression pattern describing valid names for",
|
||||||
|
"privatedirectory variables"
|
||||||
|
],
|
||||||
|
"private_var_pattern": "_[0-9a-z_]+",
|
||||||
|
"_help_public_var_pattern": [
|
||||||
|
"regular expression pattern describing valid names for public",
|
||||||
|
"directory variables"
|
||||||
|
],
|
||||||
|
"public_var_pattern": "[A-Z][0-9A-Z_]+",
|
||||||
|
"_help_argument_var_pattern": [
|
||||||
|
"regular expression pattern describing valid names for",
|
||||||
|
"function/macro arguments and loop variables."
|
||||||
|
],
|
||||||
|
"argument_var_pattern": "[a-z][a-z0-9_]+",
|
||||||
|
"_help_keyword_pattern": [
|
||||||
|
"regular expression pattern describing valid names for",
|
||||||
|
"keywords used in functions or macros"
|
||||||
|
],
|
||||||
|
"keyword_pattern": "[A-Z][0-9A-Z_]+",
|
||||||
|
"_help_max_conditionals_custom_parser": [
|
||||||
|
"In the heuristic for C0201, how many conditionals to match",
|
||||||
|
"within a loop in before considering the loop a parser."
|
||||||
|
],
|
||||||
|
"max_conditionals_custom_parser": 2,
|
||||||
|
"_help_min_statement_spacing": [
|
||||||
|
"Require at least this many newlines between statements"
|
||||||
|
],
|
||||||
|
"min_statement_spacing": 1,
|
||||||
|
"_help_max_statement_spacing": [
|
||||||
|
"Require no more than this many newlines between statements"
|
||||||
|
],
|
||||||
|
"max_statement_spacing": 2,
|
||||||
|
"max_returns": 6,
|
||||||
|
"max_branches": 12,
|
||||||
|
"max_arguments": 5,
|
||||||
|
"max_localvars": 15,
|
||||||
|
"max_statements": 50
|
||||||
|
},
|
||||||
|
"_help_encode": "Options affecting file encoding",
|
||||||
|
"encode": {
|
||||||
|
"_help_emit_byteorder_mark": [
|
||||||
|
"If true, emit the unicode byte-order mark (BOM) at the start",
|
||||||
|
"of the file"
|
||||||
|
],
|
||||||
|
"emit_byteorder_mark": false,
|
||||||
|
"_help_input_encoding": [
|
||||||
|
"Specify the encoding of the input file. Defaults to utf-8"
|
||||||
|
],
|
||||||
|
"input_encoding": "utf-8",
|
||||||
|
"_help_output_encoding": [
|
||||||
|
"Specify the encoding of the output file. Defaults to utf-8.",
|
||||||
|
"Note that cmake only claims to support utf-8 so be careful",
|
||||||
|
"when using anything else"
|
||||||
|
],
|
||||||
|
"output_encoding": "utf-8"
|
||||||
|
},
|
||||||
|
"_help_misc": "Miscellaneous configurations options.",
|
||||||
|
"misc": {
|
||||||
|
"_help_per_command": [
|
||||||
|
"A dictionary containing any per-command configuration",
|
||||||
|
"overrides. Currently only `command_case` is supported."
|
||||||
|
],
|
||||||
|
"per_command": {}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,19 @@
|
||||||
|
name: clang-format Check
|
||||||
|
on: [pull_request]
|
||||||
|
jobs:
|
||||||
|
formatting-check:
|
||||||
|
name: Formatting Check
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
path:
|
||||||
|
- 'include'
|
||||||
|
- 'src'
|
||||||
|
- 'test'
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Run clang-format style check for C/C++/Protobuf programs.
|
||||||
|
uses: jidicula/clang-format-action@v4.8.0
|
||||||
|
with:
|
||||||
|
clang-format-version: '14'
|
||||||
|
check-path: ${{ matrix.path }}
|
|
@ -30,3 +30,8 @@
|
||||||
*.exe
|
*.exe
|
||||||
*.out
|
*.out
|
||||||
*.app
|
*.app
|
||||||
|
|
||||||
|
build/
|
||||||
|
build_debug/
|
||||||
|
|
||||||
|
.vscode/
|
|
@ -0,0 +1,9 @@
|
||||||
|
[submodule "3rd-party/pybind11"]
|
||||||
|
path = 3rd-party/pybind11
|
||||||
|
url = git@github.com:pybind/pybind11.git
|
||||||
|
[submodule "3rd-party/nlohmann_json_cmake_fetchcontent"]
|
||||||
|
path = 3rd-party/nlohmann_json_cmake_fetchcontent
|
||||||
|
url = git@github.com:ArthurSonzogni/nlohmann_json_cmake_fetchcontent.git
|
||||||
|
[submodule "3rd-party/googletest"]
|
||||||
|
path = 3rd-party/googletest
|
||||||
|
url = git@github.com:google/googletest.git
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit e2239ee6043f73722e7aa812a459f54a28552929
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 6aebf09233951e4ce30a63919186a70b2b195756
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 1e3400b6742288429f2069aaf5febf92d0662dae
|
|
@ -0,0 +1,91 @@
|
||||||
|
# TODO: check the minimum cmake version
|
||||||
|
cmake_minimum_required(VERSION 3.9) # Required by find_package(OpenMP)
|
||||||
|
include(CMakeDependentOption)
|
||||||
|
project(InfiniTensor C CXX)
|
||||||
|
|
||||||
|
# Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them.
|
||||||
|
option(BUILD_TEST "Build tests" ON)
|
||||||
|
cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
|
||||||
|
cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF)
|
||||||
|
cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST OFF)
|
||||||
|
|
||||||
|
set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
|
||||||
|
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wno-error=deprecated-declarations")
|
||||||
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
|
||||||
|
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
|
||||||
|
|
||||||
|
find_package(
|
||||||
|
Python
|
||||||
|
COMPONENTS Interpreter Development
|
||||||
|
REQUIRED)
|
||||||
|
find_package(CUDA REQUIRED)
|
||||||
|
# OpenMP
|
||||||
|
find_package(OpenMP)
|
||||||
|
if(OpenMP_C_FOUND)
|
||||||
|
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
||||||
|
endif()
|
||||||
|
if(OpenMP_CXX_FOUND)
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
include_directories(include)
|
||||||
|
|
||||||
|
# # Pybind11
|
||||||
|
# add_subdirectory(3rd-party/pybind11)
|
||||||
|
# include_directories(3rd-party/pybind11/include)
|
||||||
|
|
||||||
|
# nlohmann_json
|
||||||
|
add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent)
|
||||||
|
include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include)
|
||||||
|
|
||||||
|
if(BUILD_TEST)
|
||||||
|
set(BUILD_GMOCK
|
||||||
|
OFF
|
||||||
|
CACHE BOOL "Do not build gmock" FORCE)
|
||||||
|
set(INSTALL_GTEST
|
||||||
|
OFF
|
||||||
|
CACHE BOOL "Do not install gtest" FORCE)
|
||||||
|
add_subdirectory(3rd-party/googletest)
|
||||||
|
include_directories(3rd-party/googletest/googletest/include)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
file(GLOB_RECURSE SRC src/*.cc src/*.cu)
|
||||||
|
# file(GLOB_RECURSE FFI src/ffi/ffi_pet.cc)
|
||||||
|
# list(REMOVE_ITEM SRC ${TEST} ${FFI})
|
||||||
|
|
||||||
|
add_library(InfiniTensor SHARED ${SRC})
|
||||||
|
# Target
|
||||||
|
# cuda_add_library(it SHARED ${SRC})
|
||||||
|
# cuda_add_cublas_to_target(it) # cublas
|
||||||
|
# # target_link_libraries(infini_cpp cudnn curand nlohmann_json::nlohmann_json pybind11::embed)
|
||||||
|
|
||||||
|
# # Python bindings
|
||||||
|
# pybind11_add_module(infini MODULE ${FFI})
|
||||||
|
# target_link_libraries(infini PRIVATE infini_cpp)
|
||||||
|
|
||||||
|
function(build_test files)
|
||||||
|
# Non-recursive glob for skip failed tests
|
||||||
|
file(GLOB TEST_SOURCES ${files})
|
||||||
|
foreach(testsourcefile ${TEST_SOURCES})
|
||||||
|
get_filename_component(testname ${testsourcefile} NAME_WE)
|
||||||
|
add_executable(${testname} ${testsourcefile})
|
||||||
|
target_link_libraries(${testname} InfiniTensor GTest::gtest_main)
|
||||||
|
add_test(NAME ${testname} COMMAND ${testname})
|
||||||
|
endforeach(testsourcefile ${TEST_SOURCES})
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
if(BUILD_TEST)
|
||||||
|
enable_testing()
|
||||||
|
if(BUILD_TEST_CORE)
|
||||||
|
build_test(test/core/*.cc)
|
||||||
|
endif()
|
||||||
|
if(BUILD_TEST_PET)
|
||||||
|
build_test(test/pet/*.cc)
|
||||||
|
endif()
|
||||||
|
if(BUILD_TEST_EINNET)
|
||||||
|
build_test(test/nnet/*.cc)
|
||||||
|
endif()
|
||||||
|
endif()
|
|
@ -0,0 +1,61 @@
|
||||||
|
#pragma once
|
||||||
|
#include <cassert>
|
||||||
|
#include <functional>
|
||||||
|
#include <iostream>
|
||||||
|
#include <list>
|
||||||
|
#include <map>
|
||||||
|
#include <optional>
|
||||||
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
using std::list;
|
||||||
|
using std::map;
|
||||||
|
using std::pair;
|
||||||
|
using std::set;
|
||||||
|
using std::string;
|
||||||
|
using std::tie;
|
||||||
|
using std::to_string;
|
||||||
|
using std::tuple;
|
||||||
|
using std::unordered_map;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
// Aliases
|
||||||
|
using dtype = float;
|
||||||
|
using HashType = size_t; // compatible with std::hash
|
||||||
|
|
||||||
|
// Metaprogramming utilities
|
||||||
|
#define _CAT(A, B) A##B
|
||||||
|
#define _SELECT(NAME, NUM) _CAT(NAME##_, NUM)
|
||||||
|
#define _GET_COUNT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, COUNT, ...) COUNT
|
||||||
|
#define _VA_SIZE(...) _GET_COUNT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1)
|
||||||
|
#define _VA_SELECT(NAME, ...) _SELECT(NAME, _VA_SIZE(__VA_ARGS__))(__VA_ARGS__)
|
||||||
|
|
||||||
|
// Assert: conditions should have no side effect
|
||||||
|
#define _IT_ASSERT_2(name, info) \
|
||||||
|
(static_cast<bool>(name) \
|
||||||
|
? void(0) \
|
||||||
|
: throw std::runtime_error( \
|
||||||
|
std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \
|
||||||
|
"] Assertion failed (" + #name + "): " + #info))
|
||||||
|
#define _IT_ASSERT_1(name) _IT_ASSERT_2(name, "");
|
||||||
|
|
||||||
|
#define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__)
|
||||||
|
#define IT_TODO_HALT() IT_ASSERT(false, "Unimplemented")
|
||||||
|
#define IT_TODO_SKIP() puts("Unimplemented " __FILE__ ":" __LINE__)
|
||||||
|
|
||||||
|
// Other utilities
|
||||||
|
|
||||||
|
// std::to_underlying is avaiable since C++23
|
||||||
|
template <typename T> auto enum_to_underlying(T e) {
|
||||||
|
return static_cast<std::underlying_type_t<T>>(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
double timeit(const std::function<void()> &func);
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,47 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
#include "core/tensor.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
// TODO: graph should be attached to a context
|
||||||
|
class GraphNode : public Object {
|
||||||
|
protected:
|
||||||
|
TensorVec tensors;
|
||||||
|
TensorVec inputs;
|
||||||
|
TensorVec outputs;
|
||||||
|
OpVec ops;
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Graph(OpVec oplist);
|
||||||
|
string toString() const override;
|
||||||
|
|
||||||
|
void addOp(Operator op) { ops.push_back(op); };
|
||||||
|
const TensorVec &getTensors() const { return tensors; }
|
||||||
|
const TensorVec &getInputs() const { return inputs; }
|
||||||
|
const TensorVec &getOutputs() const { return outputs; }
|
||||||
|
const OpVec &getOperators() const { return ops; }
|
||||||
|
// TensorVec &getInputs();
|
||||||
|
// TensorVec &getOutputs();
|
||||||
|
|
||||||
|
Tensor addTensor(Shape dim, DataType dtype = DataType::Int32) {
|
||||||
|
Tensor tensor = make_ref<TensorNode>(dim, dtype);
|
||||||
|
tensors.emplace_back(tensor);
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
void dataMalloc();
|
||||||
|
|
||||||
|
private:
|
||||||
|
// TODO: updateConnection
|
||||||
|
/**
|
||||||
|
* @brief Add reverse connections and Op relationship in ctor.
|
||||||
|
*/
|
||||||
|
void updateConnection();
|
||||||
|
|
||||||
|
// TODO: move to another class
|
||||||
|
// bool exportOnnx(const char *path);
|
||||||
|
// bool importOnnx(const char *net);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,76 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/common.h"
|
||||||
|
#include "core/operator.h"
|
||||||
|
#include "core/tensor.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
struct PerfRecord {
|
||||||
|
double time; // in milliseconds
|
||||||
|
};
|
||||||
|
|
||||||
|
class Kernel {
|
||||||
|
public:
|
||||||
|
Kernel() {}
|
||||||
|
virtual ~Kernel() {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param op The operator to be executed.
|
||||||
|
* @param record The parameters for kernel execution. If extra parameters
|
||||||
|
* are required, inherit from PerfRecord and add extra parameters.
|
||||||
|
* Otherwire, use PerfRecord directly.
|
||||||
|
*/
|
||||||
|
virtual void compute(const Operator &op,
|
||||||
|
const PerfRecord &record) const = 0;
|
||||||
|
/**
|
||||||
|
* @brief Executes an op with a default parameter.
|
||||||
|
*/
|
||||||
|
virtual void compute(const Operator &op) const = 0;
|
||||||
|
// Premise: op is idempotent since it is called multiple times.
|
||||||
|
virtual PerfRecord tune(const Operator &op) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class KernelRegistry {
|
||||||
|
public:
|
||||||
|
using KernelRecord =
|
||||||
|
tuple<Kernel *const, const string, const int>; // Kernel, name, ID
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::map<KernelAttrs, KernelRecord> kernels;
|
||||||
|
int nKernels = 0;
|
||||||
|
|
||||||
|
public:
|
||||||
|
~KernelRegistry() {
|
||||||
|
for (auto &[k, v] : kernels)
|
||||||
|
delete std::get<0>(v);
|
||||||
|
}
|
||||||
|
static KernelRegistry &getInstance() {
|
||||||
|
static KernelRegistry instance;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
bool registerKernel(const KernelAttrs &key, Kernel *kernel, string name) {
|
||||||
|
// TODO: mutliple kernels support: priority and check name
|
||||||
|
IT_ASSERT(kernels.find(key) == kernels.end(),
|
||||||
|
"Kernel already registered");
|
||||||
|
kernels.emplace(key, KernelRecord{kernel, name, ++nKernels});
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
Kernel *getKernel(const KernelAttrs &kernelAttrs) const {
|
||||||
|
return std::get<0>(kernels.at(kernelAttrs));
|
||||||
|
}
|
||||||
|
const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {
|
||||||
|
return kernels.at(kernelAttrs);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
||||||
|
|
||||||
|
#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, cnt) \
|
||||||
|
namespace infini { \
|
||||||
|
static const bool _CAT(_register_kernel_, cnt) = \
|
||||||
|
KernelRegistry::getInstance().registerKernel( \
|
||||||
|
KernelAttrs{device, opType, dataType}, new kernel(), name); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define REGISTER_KERNEL(device, opType, dataType, kernel, name) \
|
||||||
|
_REGISTER_KERNEL_1(device, opType, dataType, kernel, name, __COUNTER__)
|
|
@ -0,0 +1,19 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/graph.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class Mutator {
|
||||||
|
private:
|
||||||
|
int candidatesLimit;
|
||||||
|
// // Statistical data
|
||||||
|
// int numTotalCandidates;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Mutator(int candidatesLimit) : candidatesLimit(candidatesLimit){};
|
||||||
|
virtual ~Mutator(){};
|
||||||
|
|
||||||
|
virtual vector<Graph> run(const Graph &in_graph) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,54 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/common.h"
|
||||||
|
#include "ref.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
using GuidBaseType = int;
|
||||||
|
|
||||||
|
class Guid {
|
||||||
|
private:
|
||||||
|
GuidBaseType guid;
|
||||||
|
|
||||||
|
private:
|
||||||
|
GuidBaseType generateGuid() {
|
||||||
|
static GuidBaseType guidCnt = 0;
|
||||||
|
return ++guidCnt;
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
Guid() { guid = generateGuid(); }
|
||||||
|
Guid(const Guid &rhs) { guid = generateGuid(); }
|
||||||
|
Guid &operator=(const Guid &rhs) {
|
||||||
|
guid = generateGuid();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
operator GuidBaseType() const { return guid; }
|
||||||
|
};
|
||||||
|
|
||||||
|
class Object {
|
||||||
|
protected:
|
||||||
|
Guid guid;
|
||||||
|
|
||||||
|
public:
|
||||||
|
virtual ~Object(){};
|
||||||
|
virtual string toString() const = 0;
|
||||||
|
void print() { std::cout << toString() << std::endl; }
|
||||||
|
Guid getGuid() const { return guid; }
|
||||||
|
};
|
||||||
|
|
||||||
|
inline std::ostream &operator<<(std::ostream &os, const Object &obj) {
|
||||||
|
os << obj.toString();
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overload for Ref-wrapped Object
|
||||||
|
template <typename T,
|
||||||
|
typename std::enable_if_t<std::is_base_of_v<Object, T>> * = nullptr>
|
||||||
|
inline std::ostream &operator<<(std::ostream &os, const Ref<T> &obj) {
|
||||||
|
os << obj->toString();
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,180 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/tensor.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
enum class OpType {
|
||||||
|
Unknown = 0,
|
||||||
|
// linear
|
||||||
|
Conv = 100,
|
||||||
|
Matmul,
|
||||||
|
ConvTrans,
|
||||||
|
G2BMM,
|
||||||
|
GBMML,
|
||||||
|
Pad,
|
||||||
|
Slice,
|
||||||
|
Concat,
|
||||||
|
Split,
|
||||||
|
Transpose,
|
||||||
|
Extend,
|
||||||
|
MaxPool,
|
||||||
|
AvgPool,
|
||||||
|
Add,
|
||||||
|
Sub,
|
||||||
|
Mul,
|
||||||
|
Div,
|
||||||
|
Pow,
|
||||||
|
Gather,
|
||||||
|
ReduceMean,
|
||||||
|
Reshape,
|
||||||
|
Identity,
|
||||||
|
// element wise
|
||||||
|
BatchNorm = 200,
|
||||||
|
Softmax,
|
||||||
|
Activation,
|
||||||
|
Resize,
|
||||||
|
//
|
||||||
|
MemBound = 300,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class Device { CPU = 1, CUDA };
|
||||||
|
|
||||||
|
using KernelAttrs = std::tuple<Device, OpType, DataType>;
|
||||||
|
|
||||||
|
class OpRegistry {
|
||||||
|
public:
|
||||||
|
static std::string getOpName(OpType opType) {
|
||||||
|
#define FOP(op) \
|
||||||
|
case OpType::op: \
|
||||||
|
return #op
|
||||||
|
|
||||||
|
switch (opType) {
|
||||||
|
FOP(Unknown);
|
||||||
|
// linear
|
||||||
|
FOP(Conv);
|
||||||
|
FOP(Matmul);
|
||||||
|
FOP(ConvTrans);
|
||||||
|
FOP(G2BMM);
|
||||||
|
FOP(GBMML);
|
||||||
|
FOP(Pad);
|
||||||
|
FOP(Slice);
|
||||||
|
FOP(Concat);
|
||||||
|
FOP(Split);
|
||||||
|
FOP(Transpose);
|
||||||
|
FOP(Extend);
|
||||||
|
FOP(MaxPool);
|
||||||
|
FOP(AvgPool);
|
||||||
|
FOP(Add);
|
||||||
|
FOP(Sub);
|
||||||
|
FOP(Mul);
|
||||||
|
FOP(Div);
|
||||||
|
FOP(Pow);
|
||||||
|
FOP(Gather);
|
||||||
|
FOP(ReduceMean);
|
||||||
|
FOP(Reshape);
|
||||||
|
FOP(Identity);
|
||||||
|
// element wise
|
||||||
|
FOP(BatchNorm);
|
||||||
|
FOP(Softmax);
|
||||||
|
FOP(Activation);
|
||||||
|
//
|
||||||
|
FOP(MemBound);
|
||||||
|
default:
|
||||||
|
IT_ASSERT(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
#undef FOP
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class ActType {
|
||||||
|
None,
|
||||||
|
Relu,
|
||||||
|
Sigmoid,
|
||||||
|
Tanh,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct OpPerfKey {
|
||||||
|
HashType hash;
|
||||||
|
OpType opType;
|
||||||
|
vector<int> attrs;
|
||||||
|
|
||||||
|
public:
|
||||||
|
OpPerfKey(HashType hash, OpType opType, vector<int> attrs = {})
|
||||||
|
: hash(hash), opType(opType), attrs(attrs) {}
|
||||||
|
bool operator==(const OpPerfKey &rhs) const {
|
||||||
|
if (hash != rhs.hash)
|
||||||
|
return false;
|
||||||
|
if (opType != rhs.opType)
|
||||||
|
return false;
|
||||||
|
if (attrs != rhs.attrs)
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: remove this function after we use unordered_map in PerfEngine
|
||||||
|
bool operator<(const OpPerfKey &rhs) const {
|
||||||
|
if (hash != rhs.hash)
|
||||||
|
return hash < rhs.hash;
|
||||||
|
if (opType != rhs.opType)
|
||||||
|
return opType < rhs.opType;
|
||||||
|
if (attrs.size() != rhs.attrs.size())
|
||||||
|
return attrs.size() < rhs.attrs.size();
|
||||||
|
for (size_t i = 0; i < attrs.size(); ++i)
|
||||||
|
if (attrs[i] != rhs.attrs[i])
|
||||||
|
return attrs[i] < rhs.attrs[i];
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class OperatorNode : public Object {
|
||||||
|
friend class Kernel;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
OpType type;
|
||||||
|
TensorVec inputs;
|
||||||
|
TensorVec outputs;
|
||||||
|
// vector<WRef<Operator>> predecessors;
|
||||||
|
// vector<WRef<Operator>> successors;
|
||||||
|
|
||||||
|
public:
|
||||||
|
OperatorNode(OpType opType, TensorVec inputs, TensorVec outputs)
|
||||||
|
: type(opType), inputs(inputs), outputs(outputs) {}
|
||||||
|
virtual vector<Shape> computeShape() const = 0;
|
||||||
|
virtual OpPerfKey getOpPerfKey() const = 0;
|
||||||
|
|
||||||
|
public: // check Op type
|
||||||
|
bool isLinearOp() const;
|
||||||
|
bool isElementWiseOp() const;
|
||||||
|
bool isSplitOp() const;
|
||||||
|
bool isConcatOp() const;
|
||||||
|
bool isComputeOp() const;
|
||||||
|
bool isTransposeOp() const;
|
||||||
|
bool isReshapeOp() const;
|
||||||
|
bool isMemBoundOp() const;
|
||||||
|
|
||||||
|
public: // getter and setter
|
||||||
|
// TensorVec getInputs() { return inputs; }
|
||||||
|
const TensorVec &getInputs() const { return inputs; }
|
||||||
|
// TensorVec getOutputs() { return outputs; }
|
||||||
|
const TensorVec &getOutputs() const { return outputs; }
|
||||||
|
Tensor getInputs(size_t i) { return inputs.at(i); }
|
||||||
|
Tensor getOutput() const {
|
||||||
|
IT_ASSERT(outputs.size() == 1, "Unimplemented");
|
||||||
|
return outputs[0];
|
||||||
|
}
|
||||||
|
OpType getOpType() const { return type; }
|
||||||
|
|
||||||
|
virtual int numInputs() const = 0;
|
||||||
|
virtual int numOutputs() const = 0;
|
||||||
|
virtual HashType hash() const { IT_TODO_HALT(); }
|
||||||
|
virtual HashType hashWithShape() const { IT_TODO_HALT(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
||||||
|
|
||||||
|
namespace std {
|
||||||
|
template <> struct hash<infini::OpPerfKey> {
|
||||||
|
size_t operator()(const infini::OpPerfKey &key) const { return key.hash; }
|
||||||
|
};
|
||||||
|
} // namespace std
|
|
@ -0,0 +1,36 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class PerfEngine {
|
||||||
|
public:
|
||||||
|
// TODO: Key should be OpPerfKey + Context(maybe implicat) to support
|
||||||
|
// multiple candiate kernels.
|
||||||
|
using Key = std::pair<KernelAttrs, OpPerfKey>;
|
||||||
|
|
||||||
|
private:
|
||||||
|
map<Key, PerfRecord> data;
|
||||||
|
|
||||||
|
public:
|
||||||
|
static PerfEngine &getInstance() {
|
||||||
|
static PerfEngine instance;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<PerfRecord> getPerfData(const Key &key) {
|
||||||
|
auto it = data.find(key);
|
||||||
|
if (it != data.end()) // find previous evaluating results
|
||||||
|
return data.at(key);
|
||||||
|
else
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
void setPerfData(const Key &key, const PerfRecord &record) {
|
||||||
|
IT_ASSERT(data.find(key) == data.end(), "Perf data already exist");
|
||||||
|
data.emplace(key, record);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,35 @@
|
||||||
|
#pragma once
|
||||||
|
#include <functional> // hash
|
||||||
|
#include <memory>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
template <typename T> using Ref = std::shared_ptr<T>;
|
||||||
|
template <typename T> using WRef = std::weak_ptr<T>;
|
||||||
|
|
||||||
|
template <typename T> struct is_ref : std::false_type {};
|
||||||
|
template <typename T> struct is_ref<Ref<T>> : std::true_type {};
|
||||||
|
template <typename T> struct is_ref<WRef<T>> : std::true_type {};
|
||||||
|
|
||||||
|
template <typename T, typename... Params> Ref<T> make_ref(Params &&...params) {
|
||||||
|
static_assert(is_ref<T>::value == false, "Ref should not be nested");
|
||||||
|
return std::make_shared<T>(std::forward<Params>(params)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T, class U,
|
||||||
|
typename std::enable_if_t<std::is_base_of_v<U, T>> * = nullptr>
|
||||||
|
Ref<T> as(const Ref<U> &ref) {
|
||||||
|
return std::dynamic_pointer_cast<T>(ref);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::vector<WRef<T>> get_wref_vec(const std::vector<Ref<T>> &vec) {
|
||||||
|
std::vector<WRef<T>> wref_vec;
|
||||||
|
wref_vec.reserve(vec.size());
|
||||||
|
for (const auto &ref : vec)
|
||||||
|
wref_vec.emplace_back(ref);
|
||||||
|
return wref_vec;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,26 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "core/perf_engine.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class RunEngine {
|
||||||
|
private:
|
||||||
|
Device device;
|
||||||
|
|
||||||
|
public:
|
||||||
|
RunEngine(Device device) : device(device) {}
|
||||||
|
~RunEngine() {}
|
||||||
|
|
||||||
|
void run(const Graph &graph, bool tune = false,
|
||||||
|
bool profiling = false) const;
|
||||||
|
double getPerfTime(const Graph &graph, bool profiling = false) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void printProfilingData(double totTime,
|
||||||
|
const std::map<OpType, double> &opTime,
|
||||||
|
const std::map<OpType, int> &opCnt) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,180 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/tensor_base.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
// TODO: how to deal with this
|
||||||
|
using ShapeElem = int;
|
||||||
|
using Shape = vector<ShapeElem>;
|
||||||
|
class TensorNode : public TensorBaseNode {
|
||||||
|
private:
|
||||||
|
Shape shape;
|
||||||
|
|
||||||
|
public:
|
||||||
|
TensorNode(const Shape &shape, DataType dtype);
|
||||||
|
virtual ~TensorNode() {}
|
||||||
|
string toString() const override;
|
||||||
|
|
||||||
|
size_t size() const;
|
||||||
|
void dataMalloc();
|
||||||
|
|
||||||
|
Shape getDims() const { return shape; }
|
||||||
|
|
||||||
|
size_t getOffset(const Shape &ds) const;
|
||||||
|
using TensorBaseNode::getData;
|
||||||
|
VType getData(const Shape &pos) const;
|
||||||
|
void copyData(VType *dptr);
|
||||||
|
void printData() const;
|
||||||
|
bool equalData(const Tensor &rhs) const;
|
||||||
|
// void setDims(const Dim &dms) { dims = dms; }
|
||||||
|
|
||||||
|
// bool dataRand(int seed = 0) {
|
||||||
|
// if (data == nullptr)
|
||||||
|
// data = new VType[size()];
|
||||||
|
// if (!random_inited)
|
||||||
|
// initFastrand();
|
||||||
|
// // srand(seed);
|
||||||
|
// // faster rand generator; parallel
|
||||||
|
// size_t iEnd = size();
|
||||||
|
// // std::cerr << "Init beginned " << std::endl;
|
||||||
|
// #pragma omp parallel for
|
||||||
|
// for (size_t i = 0; i < iEnd; ++i)
|
||||||
|
// data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) %
|
||||||
|
// 10000;
|
||||||
|
// // std::cerr << "Init finished" << std::endl;
|
||||||
|
// computed = ComputedFull;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool setScalar(VType val) {
|
||||||
|
// if (data == nullptr || !dims.empty())
|
||||||
|
// return false;
|
||||||
|
// data[0] = val;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool setData(const Dim &ds, VType val) {
|
||||||
|
// if (data == nullptr || ds.size() != dims.size())
|
||||||
|
// return false;
|
||||||
|
// data[getOffset(ds)] = val;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool setData(size_t pos, VType val) {
|
||||||
|
// if (data == nullptr || pos >= size())
|
||||||
|
// return false;
|
||||||
|
// data[pos] = val;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// VType getScalar() { return data == nullptr ? 0 : data[0]; }
|
||||||
|
|
||||||
|
// VType getBroadcastData(const Dim &ds) {
|
||||||
|
// assert(data != nullptr);
|
||||||
|
// auto offset = getBroadcastOffset(ds);
|
||||||
|
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
|
||||||
|
// }
|
||||||
|
|
||||||
|
// VType getBroadcastData(size_t pos) {
|
||||||
|
// assert(data != nullptr);
|
||||||
|
// return data[pos % size()];
|
||||||
|
// }
|
||||||
|
|
||||||
|
// size_t getBroadcastOffset(const Dim &ds) {
|
||||||
|
// assert(ds.size() >= dims.size());
|
||||||
|
// auto nDim = dims.size();
|
||||||
|
// auto nBroadcastDim = ds.size() - nDim;
|
||||||
|
// for (size_t i = 0; i < nDim; ++i)
|
||||||
|
// if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >=
|
||||||
|
// dims[i])
|
||||||
|
// return (size_t)-1;
|
||||||
|
// size_t idx = 0;
|
||||||
|
// for (size_t i = 0; i < nDim; ++i)
|
||||||
|
// idx = idx * dims[i] + ds[nBroadcastDim + i];
|
||||||
|
// return idx;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// void itInit() { it = Dim(dims.size(), 0); }
|
||||||
|
|
||||||
|
// void itReset() {
|
||||||
|
// itInit();
|
||||||
|
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
|
||||||
|
// it[i] = 0;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool itValid() {
|
||||||
|
// if (it.size() != dims.size())
|
||||||
|
// return false;
|
||||||
|
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
|
||||||
|
// if (it[i] >= dims[i])
|
||||||
|
// return false;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const Dim &itGet() { return it; }
|
||||||
|
|
||||||
|
// void itNext() {
|
||||||
|
// auto p = it.size() - 1;
|
||||||
|
// it[p] += 1;
|
||||||
|
// while (p >= 1) {
|
||||||
|
// if (it[p] == dims[p]) {
|
||||||
|
// it[p] = 0;
|
||||||
|
// it[--p] += 1;
|
||||||
|
// } else
|
||||||
|
// break;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// TensorType getType() const { return type; }
|
||||||
|
// void setType(TensorType ty) { type = ty; }
|
||||||
|
|
||||||
|
// static inline void initFastrand() {
|
||||||
|
// assert(omp_get_max_threads() <= 256);
|
||||||
|
// // srand(0); // constant seed for test
|
||||||
|
// // align random_seed to avoid false sharing
|
||||||
|
// for (int i = 0; i < 256 * 16; ++i) {
|
||||||
|
// // random_seed[i] = rand();
|
||||||
|
// // constant random seed for test
|
||||||
|
// random_seed[i] = i;
|
||||||
|
// }
|
||||||
|
// random_inited = true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// static inline int fastrand(int &g_seed) {
|
||||||
|
// g_seed = (214013 * g_seed + 2531011);
|
||||||
|
// return (g_seed >> 16) & 0x7FFF;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// std::vector<std::vector<int>> const *getSplittingPoints() const {
|
||||||
|
// assert(!splittingPoints.empty());
|
||||||
|
// return &splittingPoints;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool setSplittingPoints(std::vector<std::vector<int>> value) {
|
||||||
|
// assert(!value.empty());
|
||||||
|
// splittingPoints = value;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// void printSplittingPoints() {
|
||||||
|
// if (splittingPoints.empty())
|
||||||
|
// printf("Empty SplittingPoints");
|
||||||
|
// else {
|
||||||
|
// printf("[");
|
||||||
|
// for (auto &vs : splittingPoints) {
|
||||||
|
// printf("[");
|
||||||
|
// for (auto v : vs)
|
||||||
|
// printf("%2d,", v);
|
||||||
|
// printf("],");
|
||||||
|
// }
|
||||||
|
// printf("]");
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// void initSplittingPoints() {
|
||||||
|
// splittingPoints.resize(getDims().size()); }
|
||||||
|
|
||||||
|
// void printShape();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,261 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/object.h"
|
||||||
|
#include "core/ref.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
// class Tensor;
|
||||||
|
class TensorBaseNode;
|
||||||
|
class TensorNode;
|
||||||
|
class OperatorNode;
|
||||||
|
class GraphNode;
|
||||||
|
|
||||||
|
using TensorBase = Ref<TensorBaseNode>;
|
||||||
|
using Tensor = Ref<TensorNode>;
|
||||||
|
using Operator = Ref<OperatorNode>;
|
||||||
|
using Graph = Ref<GraphNode>;
|
||||||
|
|
||||||
|
using TensorVec = vector<Tensor>;
|
||||||
|
using OpVec = vector<Operator>;
|
||||||
|
|
||||||
|
using VType = uint32_t;
|
||||||
|
|
||||||
|
enum class DataType {
|
||||||
|
Float32,
|
||||||
|
Int32,
|
||||||
|
};
|
||||||
|
|
||||||
|
class TensorBaseNode : public Object {
|
||||||
|
public:
|
||||||
|
// enum TensorType {
|
||||||
|
// Input,
|
||||||
|
// Weight,
|
||||||
|
// Invalid,
|
||||||
|
// NotCounted,
|
||||||
|
// };
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int dim;
|
||||||
|
|
||||||
|
DataType dtype;
|
||||||
|
vector<WRef<TensorBaseNode>> inputOf;
|
||||||
|
WRef<TensorBaseNode> outputOf;
|
||||||
|
// TODO: Ref<void> -> Ref<Blob>
|
||||||
|
Ref<VType[]> data;
|
||||||
|
// ComputeState computed;
|
||||||
|
// static int random_seed[256 * 16];
|
||||||
|
// static bool random_inited;
|
||||||
|
|
||||||
|
public:
|
||||||
|
TensorBaseNode(int dim, DataType dtype);
|
||||||
|
virtual ~TensorBaseNode() {}
|
||||||
|
|
||||||
|
Ref<VType[]> getDataPtr() const { return data; }
|
||||||
|
VType getData(size_t offset) const;
|
||||||
|
|
||||||
|
DataType getDType() const { return dtype; }
|
||||||
|
|
||||||
|
// uint64_t getHash() const { return hash; }
|
||||||
|
|
||||||
|
// void setInputOf(const OpVec &ops) {
|
||||||
|
// inputOf.clear();
|
||||||
|
// for (const auto &op : ops)
|
||||||
|
// inputOf.emplace_back(op);
|
||||||
|
// }
|
||||||
|
// void addInputOf(Operator op) { inputOf.emplace_back(op); }
|
||||||
|
// void setOutputOf(Operator op) { outputOf = op; }
|
||||||
|
|
||||||
|
// const OpVec &getInputOf() { return inputOf; }
|
||||||
|
// Operator *getOutputOf() { return outputOf; }
|
||||||
|
// std::pair<Operator *, int> getOutputOfWithIndex();
|
||||||
|
|
||||||
|
// const Dim &getDims() const { return dims; }
|
||||||
|
// void setDims(const Dim &dms) { dims = dms; }
|
||||||
|
|
||||||
|
// bool dataRand(int seed = 0) {
|
||||||
|
// if (data == nullptr)
|
||||||
|
// data = new VType[size()];
|
||||||
|
// if (!random_inited)
|
||||||
|
// initFastrand();
|
||||||
|
// // srand(seed);
|
||||||
|
// // faster rand generator; parallel
|
||||||
|
// size_t iEnd = size();
|
||||||
|
// // std::cerr << "Init beginned " << std::endl;
|
||||||
|
// #pragma omp parallel for
|
||||||
|
// for (size_t i = 0; i < iEnd; ++i)
|
||||||
|
// data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) %
|
||||||
|
// 10000;
|
||||||
|
// // std::cerr << "Init finished" << std::endl;
|
||||||
|
// computed = ComputedFull;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool setScalar(VType val) {
|
||||||
|
// if (data == nullptr || !dims.empty())
|
||||||
|
// return false;
|
||||||
|
// data[0] = val;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool setData(const Dim &ds, VType val) {
|
||||||
|
// if (data == nullptr || ds.size() != dims.size())
|
||||||
|
// return false;
|
||||||
|
// data[getOffset(ds)] = val;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool setData(size_t pos, VType val) {
|
||||||
|
// if (data == nullptr || pos >= size())
|
||||||
|
// return false;
|
||||||
|
// data[pos] = val;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// VType getScalar() { return data == nullptr ? 0 : data[0]; }
|
||||||
|
|
||||||
|
// VType getData(const Dim &ds) {
|
||||||
|
// assert(data != nullptr);
|
||||||
|
// auto offset = getOffset(ds);
|
||||||
|
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
|
||||||
|
// }
|
||||||
|
|
||||||
|
// VType getData(size_t pos) {
|
||||||
|
// assert(data != nullptr);
|
||||||
|
// assert(pos < size());
|
||||||
|
// return data[pos];
|
||||||
|
// }
|
||||||
|
|
||||||
|
// VType *getDataPtr() const { return data; }
|
||||||
|
|
||||||
|
// size_t getOffset(const Dim &ds) {
|
||||||
|
// auto nDim = ds.size();
|
||||||
|
// assert(dims.size() == nDim);
|
||||||
|
// if (ds.empty())
|
||||||
|
// return 0;
|
||||||
|
// for (size_t i = 0; i < nDim; ++i)
|
||||||
|
// if (ds[i] < 0 || ds[i] >= dims[i])
|
||||||
|
// return (size_t)-1;
|
||||||
|
// size_t idx = ds[0];
|
||||||
|
// size_t dm = 0;
|
||||||
|
// while (++dm < nDim)
|
||||||
|
// idx = idx * dims[dm] + ds[dm];
|
||||||
|
// return idx;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// VType getBroadcastData(const Dim &ds) {
|
||||||
|
// assert(data != nullptr);
|
||||||
|
// auto offset = getBroadcastOffset(ds);
|
||||||
|
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
|
||||||
|
// }
|
||||||
|
|
||||||
|
// VType getBroadcastData(size_t pos) {
|
||||||
|
// assert(data != nullptr);
|
||||||
|
// return data[pos % size()];
|
||||||
|
// }
|
||||||
|
|
||||||
|
// size_t getBroadcastOffset(const Dim &ds) {
|
||||||
|
// assert(ds.size() >= dims.size());
|
||||||
|
// auto nDim = dims.size();
|
||||||
|
// auto nBroadcastDim = ds.size() - nDim;
|
||||||
|
// for (size_t i = 0; i < nDim; ++i)
|
||||||
|
// if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >=
|
||||||
|
// dims[i])
|
||||||
|
// return (size_t)-1;
|
||||||
|
// size_t idx = 0;
|
||||||
|
// for (size_t i = 0; i < nDim; ++i)
|
||||||
|
// idx = idx * dims[i] + ds[nBroadcastDim + i];
|
||||||
|
// return idx;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// void itInit() { it = Dim(dims.size(), 0); }
|
||||||
|
|
||||||
|
// void itReset() {
|
||||||
|
// itInit();
|
||||||
|
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
|
||||||
|
// it[i] = 0;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool itValid() {
|
||||||
|
// if (it.size() != dims.size())
|
||||||
|
// return false;
|
||||||
|
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
|
||||||
|
// if (it[i] >= dims[i])
|
||||||
|
// return false;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const Dim &itGet() { return it; }
|
||||||
|
|
||||||
|
// void itNext() {
|
||||||
|
// auto p = it.size() - 1;
|
||||||
|
// it[p] += 1;
|
||||||
|
// while (p >= 1) {
|
||||||
|
// if (it[p] == dims[p]) {
|
||||||
|
// it[p] = 0;
|
||||||
|
// it[--p] += 1;
|
||||||
|
// } else
|
||||||
|
// break;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// size_t size() const {
|
||||||
|
// size_t sz = 1;
|
||||||
|
// auto dm = dims.size();
|
||||||
|
// while (dm > 0)
|
||||||
|
// sz *= dims[--dm];
|
||||||
|
// return sz;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// TensorType getType() const { return type; }
|
||||||
|
// void setType(TensorType ty) { type = ty; }
|
||||||
|
|
||||||
|
// static inline void initFastrand() {
|
||||||
|
// assert(omp_get_max_threads() <= 256);
|
||||||
|
// // srand(0); // constant seed for test
|
||||||
|
// // align random_seed to avoid false sharing
|
||||||
|
// for (int i = 0; i < 256 * 16; ++i) {
|
||||||
|
// // random_seed[i] = rand();
|
||||||
|
// // constant random seed for test
|
||||||
|
// random_seed[i] = i;
|
||||||
|
// }
|
||||||
|
// random_inited = true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// static inline int fastrand(int &g_seed) {
|
||||||
|
// g_seed = (214013 * g_seed + 2531011);
|
||||||
|
// return (g_seed >> 16) & 0x7FFF;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// std::vector<std::vector<int>> const *getSplittingPoints() const {
|
||||||
|
// assert(!splittingPoints.empty());
|
||||||
|
// return &splittingPoints;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool setSplittingPoints(std::vector<std::vector<int>> value) {
|
||||||
|
// assert(!value.empty());
|
||||||
|
// splittingPoints = value;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// void printSplittingPoints() {
|
||||||
|
// if (splittingPoints.empty())
|
||||||
|
// printf("Empty SplittingPoints");
|
||||||
|
// else {
|
||||||
|
// printf("[");
|
||||||
|
// for (auto &vs : splittingPoints) {
|
||||||
|
// printf("[");
|
||||||
|
// for (auto v : vs)
|
||||||
|
// printf("%2d,", v);
|
||||||
|
// printf("],");
|
||||||
|
// }
|
||||||
|
// printf("]");
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// void initSplittingPoints() {
|
||||||
|
// splittingPoints.resize(getDims().size()); }
|
||||||
|
|
||||||
|
// void printShape();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,15 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class MatchComputationKernel : public Pass {
|
||||||
|
public:
|
||||||
|
MatchComputationKernel(Derivator &derivator)
|
||||||
|
: Pass(derivator, "MatchComputationKernel") {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,15 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class MatchMemBoundKernel : public Pass {
|
||||||
|
public:
|
||||||
|
MatchMemBoundKernel(Derivator &derivator)
|
||||||
|
: Pass(derivator, "MatchMemBoundKernel") {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,41 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/derivator.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Pass {
|
||||||
|
private:
|
||||||
|
VecExpr transformations;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Derivator &derivator;
|
||||||
|
string passName;
|
||||||
|
/**
|
||||||
|
* @brief // False if does not add log in Derivator. It should be false for
|
||||||
|
* single Pass test to avoid mismatch of passInfos and passMsgs due to
|
||||||
|
* different number of "run" and "nextStep".
|
||||||
|
*/
|
||||||
|
bool enableLogging, enableDebug;
|
||||||
|
|
||||||
|
virtual void transform(Formula &origin, int depth, Expr &rCur) = 0;
|
||||||
|
void nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur,
|
||||||
|
const string &ruleInfo = "");
|
||||||
|
|
||||||
|
Var getNewVar();
|
||||||
|
string newTensorName();
|
||||||
|
|
||||||
|
private:
|
||||||
|
void initialize(Formula &origin, const Expr &rCur);
|
||||||
|
void finalize();
|
||||||
|
|
||||||
|
public:
|
||||||
|
Pass(Derivator &derivator, const string &passName);
|
||||||
|
virtual ~Pass();
|
||||||
|
|
||||||
|
void run(Formula &origin, int dfsDepth, Expr &rCur);
|
||||||
|
void setEnableLogging(bool value);
|
||||||
|
void setEnableDebug(bool value);
|
||||||
|
const VecExpr &getTransformations();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,18 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Pass/Pass.h"
|
||||||
|
#include "nnet/ReplaceKit.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Rule1VariableSplit : public Pass {
|
||||||
|
public:
|
||||||
|
Rule1VariableSplit(Derivator &derivator)
|
||||||
|
: Pass(derivator, "Rule1VariableSplit") {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
|
||||||
|
vector<Replace> getSplitableVar(const RangeOp &rangeOp);
|
||||||
|
Expr replaceIters(Expr cur, const Replace &replace);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,29 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Pass/Pass.h"
|
||||||
|
#include "nnet/ReplaceKit.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Rule2VariableMerging : public Pass {
|
||||||
|
private:
|
||||||
|
map<int, vector<Var>> substituteRules;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Rule2VariableMerging(Derivator &derivator)
|
||||||
|
: Pass(derivator, "Rule2VariableMerging") {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
|
||||||
|
|
||||||
|
vector<Replace> getMergableReplaces(RangeOp rangeOp, int depth);
|
||||||
|
optional<Replace> getReplaceMergingTwoLoopIters(const RangeOp &rangeOp,
|
||||||
|
pair<Iterator, int> pairA,
|
||||||
|
pair<Iterator, int> pairB,
|
||||||
|
const IteratorTable &exprIT,
|
||||||
|
int tensorID);
|
||||||
|
optional<Replace> getReplaceMappingTwoLoopIters(const RangeOp &rangeOp,
|
||||||
|
pair<Iterator, int> pa,
|
||||||
|
pair<Iterator, int> pb);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,19 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Rule3StageSplit : public Pass {
|
||||||
|
private:
|
||||||
|
map<int, vector<Var>> substituteRules;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Rule3StageSplit(Derivator &derivator)
|
||||||
|
: Pass(derivator, "Rule3StageSplit") {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
|
||||||
|
vector<vector<Var>> getSplitSummationIters(RangeOp rangeOp);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,22 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Rule4StageMerging : public Pass {
|
||||||
|
bool success, mergeStageWithCalc;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Rule4StageMerging(Derivator &derivator)
|
||||||
|
: Pass(derivator, "Rule4StageMerging"), success(false),
|
||||||
|
mergeStageWithCalc(false) {}
|
||||||
|
bool rule4StageMerging(Formula &origin, int depth, Expr &rCur,
|
||||||
|
bool mergeStageWithCalc = false);
|
||||||
|
bool isSuccessful();
|
||||||
|
void setMergeStageWithCalc(bool value);
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual void transform(Formula &origin, int depth, Expr &rCur) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,16 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Rule5RangeRelaxation : public Pass {
|
||||||
|
public:
|
||||||
|
Rule5RangeRelaxation(Derivator &derivator)
|
||||||
|
: Pass(derivator, "Rule5RangeRelaxation") {}
|
||||||
|
Expr rule5RangeRelaxation(Formula &origin, int depth, Expr &rCur);
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual void transform(Formula &origin, int depth, Expr &rCur) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,17 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Rule6KenerlMatching : public Pass {
|
||||||
|
public:
|
||||||
|
Rule6KenerlMatching(Derivator &derivator)
|
||||||
|
: Pass(derivator, "Rule6KenerlMatching") {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
|
||||||
|
// RE: seperating this func is a choice.
|
||||||
|
VecExpr matchElementWise(const RangeOp &rangeOp);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,16 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Rule7DLT : public Pass {
|
||||||
|
public:
|
||||||
|
Rule7DLT(Derivator &derivator) : Pass(derivator, "Rule7DLT") {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
|
||||||
|
Expr buildDLTSingleRangeOp(const RangeOp &original, const Expr &newSummand);
|
||||||
|
vector<int> getFactors();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,48 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Pass/Pass.h"
|
||||||
|
#include "nnet/ReplaceKit.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Rule8GuidedDLT : public Pass {
|
||||||
|
public:
|
||||||
|
Rule8GuidedDLT(Derivator &derivator) : Pass(derivator, "Rule8GuidedDLT") {}
|
||||||
|
VecExpr guidedDLT(Formula &origin, int depth, Expr &rCur,
|
||||||
|
bool debug = false);
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
|
||||||
|
/**
|
||||||
|
* @brief If only one row miss match (more iterators mismatch), directly do
|
||||||
|
* data layout construction according to the IT.
|
||||||
|
*
|
||||||
|
* @return Expr Return nullptr if failed.
|
||||||
|
*/
|
||||||
|
Expr guidedDLTMoreVar2(const RangeOp &cur, const Mismatch &mismatch,
|
||||||
|
const IteratorTable &exprIT, const Pattern &pattern);
|
||||||
|
/**
|
||||||
|
* @brief Check whether two iterators overlap each other. If overlapping, we
|
||||||
|
* cannot simply reconstruct the tensor into a new one by seperate all
|
||||||
|
* iterators into different dimensions.
|
||||||
|
*/
|
||||||
|
bool checkElementsHaveOnlyOneAccessIteratorSet(const IteratorTable &exprIT,
|
||||||
|
int tensorID);
|
||||||
|
/**
|
||||||
|
* @brief Only product of two tensors can be guided DLTed.
|
||||||
|
*
|
||||||
|
* @param cur
|
||||||
|
* @return true
|
||||||
|
* @return false
|
||||||
|
*/
|
||||||
|
bool statisfyGuidedDLT(RangeOp cur) const;
|
||||||
|
/**
|
||||||
|
* @brief Deal with output DLT mismatch only.
|
||||||
|
*/
|
||||||
|
Expr guidedDLTDLMismatch(const RangeOp &cur, const Mismatch &mismatch,
|
||||||
|
const IteratorTable &exprIT,
|
||||||
|
const Pattern &pattern);
|
||||||
|
Expr buildGuidedDLTSource(const Subscript &originalSub, Replace replace,
|
||||||
|
vector<Var> tensorDimAxes, vector<int> newShape);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,16 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Rule90TwoStageElementWise : public Pass {
|
||||||
|
public:
|
||||||
|
Rule90TwoStageElementWise(Derivator &derivator)
|
||||||
|
: Pass(derivator, "Rule90TwoStageElementWise") {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
|
||||||
|
VecExpr matchTwoStageElementWise(const RangeOp &rangeOp);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,15 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Rule91MergeStagesWithSum : public Pass {
|
||||||
|
public:
|
||||||
|
Rule91MergeStagesWithSum(Derivator &derivator)
|
||||||
|
: Pass(derivator, "Rule91MergeStagesWithSum") {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,15 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Rule9RangeMagnify : public Pass {
|
||||||
|
public:
|
||||||
|
Rule9RangeMagnify(Derivator &derivator)
|
||||||
|
: Pass(derivator, "Rule9RangeMagnify") {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,46 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/expr.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
struct Replace {
|
||||||
|
int iteratorType;
|
||||||
|
vector<Var> oldIters; // i_1, ...
|
||||||
|
vector<Var> newIters; // j_1, ...
|
||||||
|
VecExpr phis; // j_1=\phi_1(i_1, ...), not necessary for Sum iter
|
||||||
|
VecExpr psis; // i_1=\psi_1(j_1, ...)
|
||||||
|
vector<VarRangePair> newVarRanges;
|
||||||
|
|
||||||
|
bool isReplaced(Var var) const {
|
||||||
|
for (const auto &iter : oldIters)
|
||||||
|
if (iter->equal(var))
|
||||||
|
return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
string toReadable() const {
|
||||||
|
string ret = "Old iters: " + serializeVec(oldIters) +
|
||||||
|
", new iters: " + serializeVec(newIters);
|
||||||
|
ret += " phis: " + serializeVec(phis) + " psis: " + serializeVec(psis);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ReplaceKit {
|
||||||
|
public:
|
||||||
|
static RangeOp replaceRangeOpIterator(const RangeOp &rangeOp,
|
||||||
|
const Replace &replace,
|
||||||
|
const Expr &replacedSummand);
|
||||||
|
static Subscript buildSubscirptForLoopVarReplace(const RangeOp &inner,
|
||||||
|
const Replace &replace);
|
||||||
|
static RangeOp buildDLTOuterRangeOp(const RangeOp &original,
|
||||||
|
const Subscript &subscriptedNewRangeOp);
|
||||||
|
static Expr replaceMultipleExprs(const Expr &cur,
|
||||||
|
const vector<Var> &patterns,
|
||||||
|
const VecExpr &replacements,
|
||||||
|
bool simplify = true);
|
||||||
|
static Expr replaceExpr(const Expr &cur, const Expr &pattern,
|
||||||
|
const Expr &replacement);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,38 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Visitor/StrideVisitor.h"
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class AsTVMVisitor : public Functor<std::string(void)> {
|
||||||
|
private:
|
||||||
|
int nStage = 0, curStage = -1;
|
||||||
|
std::unordered_map<std::string, int> offset;
|
||||||
|
std::vector<std::string> inputs;
|
||||||
|
std::string output;
|
||||||
|
std::vector<std::string> pythonVars;
|
||||||
|
std::vector<std::vector<int>> inputShapes;
|
||||||
|
std::vector<int> outputShape;
|
||||||
|
std::string stmts;
|
||||||
|
|
||||||
|
public:
|
||||||
|
std::string getStmts() const;
|
||||||
|
|
||||||
|
const std::vector<std::string> &getInputs() const { return inputs; }
|
||||||
|
const std::string &getOutput() const { return output; }
|
||||||
|
|
||||||
|
const std::vector<std::vector<int>> &getInputShapes() const {
|
||||||
|
return inputShapes;
|
||||||
|
}
|
||||||
|
const std::vector<int> &getOutputShape() const { return outputShape; }
|
||||||
|
|
||||||
|
std::string visit_(const Constant &c) override;
|
||||||
|
std::string visit_(const BinaryOp &c) override;
|
||||||
|
std::string visit_(const Func &c) override;
|
||||||
|
std::string visit_(const RangeOp &c) override;
|
||||||
|
std::string visit_(const Subscript &c) override;
|
||||||
|
std::string visit_(const Var &c) override;
|
||||||
|
std::string visit_(const Tensor &c) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,22 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class CheckOOBVisitor : public ExprTreeVisitor {
|
||||||
|
RangeOp rangeOp;
|
||||||
|
bool detect = false;
|
||||||
|
|
||||||
|
public:
|
||||||
|
CheckOOBVisitor(int _verobse = 0) : ExprTreeVisitor(1, 1, 0, 0, _verobse) {}
|
||||||
|
void visit_(const Subscript &c) override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief
|
||||||
|
* @return true If there is OOB
|
||||||
|
* @return false If there is no OOB
|
||||||
|
*/
|
||||||
|
bool checkRangeOp(const RangeOp &_rangeOp);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,16 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
// Clone ExprNodes in a stage except Tensor, Var, and Constant nodes.
|
||||||
|
class CloneMutator : public Mutator {
|
||||||
|
public:
|
||||||
|
CloneMutator() : Mutator(false) {}
|
||||||
|
Expr visit_(const Constant &c) override;
|
||||||
|
Expr visit_(const Var &c) override;
|
||||||
|
Expr visit_(const Tensor &c) override;
|
||||||
|
Expr clone(const Expr &c) { return dispatch(c); }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,15 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class CompareMultiFormulasVisitor : public ExprTreeVisitor {
|
||||||
|
vector<VarRangePair> newSumVarRanges;
|
||||||
|
RangeOp newRangeOp;
|
||||||
|
|
||||||
|
public:
|
||||||
|
CompareMultiFormulasVisitor() : ExprTreeVisitor() {}
|
||||||
|
bool compare(const VecExpr &roots);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,18 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class CountRoutineVisitor : public ExprTreeVisitor {
|
||||||
|
private:
|
||||||
|
vector<int> cnts;
|
||||||
|
|
||||||
|
public:
|
||||||
|
CountRoutineVisitor(int _verobse = 0)
|
||||||
|
: ExprTreeVisitor(1, 1, 1, 1, _verobse) {}
|
||||||
|
void visit_(const Tensor &c) override;
|
||||||
|
vector<int> count(const Expr &root);
|
||||||
|
bool match(const Expr &root, int nMatmul = 0, int nConv = 0,
|
||||||
|
int nElement = 0, int nSg2bmm = 0, int nLongformerGBMM = 0);
|
||||||
|
};
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,25 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class FullPrinterVisitor : public ExprTreeVisitor {
|
||||||
|
private:
|
||||||
|
vector<tuple<string, Routine, Tensor>> q;
|
||||||
|
|
||||||
|
public:
|
||||||
|
FullPrinterVisitor(int _verobse = 0)
|
||||||
|
: ExprTreeVisitor(1, 1, 1, 0, _verobse) {}
|
||||||
|
void visit_(const Tensor &c) override;
|
||||||
|
|
||||||
|
string print(const Expr &root);
|
||||||
|
/**
|
||||||
|
* @brief Get all tensors & OPs in a reversed order
|
||||||
|
*
|
||||||
|
* @param root
|
||||||
|
* @return vector<<Output TensorName, RoutineNode, output tensor in NNet>>
|
||||||
|
*/
|
||||||
|
const vector<tuple<string, Routine, Tensor>> &traverse(const Expr &root);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,22 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
// Get all tensors in the stage
|
||||||
|
class GetTensorsVisitor : public ExprTreeVisitor {
|
||||||
|
private:
|
||||||
|
unordered_map<string, Tensor> tensors;
|
||||||
|
|
||||||
|
void visit_(const Tensor &c) override;
|
||||||
|
|
||||||
|
public:
|
||||||
|
GetTensorsVisitor(int _verobse = 0)
|
||||||
|
: ExprTreeVisitor(1, 1, 1, 0, _verobse) {}
|
||||||
|
auto get(const Expr &c) {
|
||||||
|
dispatch(c);
|
||||||
|
return tensors;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,31 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
// Calculate hash for a normal form, starting at a RangeOp
|
||||||
|
class HashVisitor : public Functor<HashType(void)> {
|
||||||
|
inline const static HashType BKDR_SEED[] = {131, 313, 10007, 65599};
|
||||||
|
|
||||||
|
PtrUmap<Iterator, int> varHash;
|
||||||
|
int nLoopVars = 0;
|
||||||
|
PtrUmap<Iterator, int> name2id;
|
||||||
|
vector<int> rootId;
|
||||||
|
vector<bool> haveAlias;
|
||||||
|
int nVars = 0;
|
||||||
|
vector<HashType> power;
|
||||||
|
|
||||||
|
private:
|
||||||
|
HashType visit_(const Constant &c) override;
|
||||||
|
HashType visit_(const BinaryOp &c) override;
|
||||||
|
HashType visit_(const RangeOp &c) override;
|
||||||
|
HashType visit_(const Subscript &c) override;
|
||||||
|
HashType visit_(const Tensor &c) override;
|
||||||
|
HashType visit_(const Var &c) override;
|
||||||
|
|
||||||
|
public:
|
||||||
|
HashVisitor(int _verobse = 0) : Functor(_verobse) {}
|
||||||
|
HashType getHash(const Expr &c);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,23 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class InputVisitor : public ExprTreeVisitor {
|
||||||
|
vector<Tensor> inputs;
|
||||||
|
|
||||||
|
public:
|
||||||
|
int nInputs = 0;
|
||||||
|
InputVisitor(int _verobse = 0) : ExprTreeVisitor(1, 1, 1, 0, _verobse) {}
|
||||||
|
void visit_(const Tensor &c) override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get the all inputs in the netsed stages
|
||||||
|
*/
|
||||||
|
vector<Tensor> getInputs(const RangeOp &_rangeOp) {
|
||||||
|
dispatch(_rangeOp);
|
||||||
|
return inputs;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,55 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Interpreter : public Functor<int()> {
|
||||||
|
public:
|
||||||
|
using ttype = int; // Test data type
|
||||||
|
using rtype = int; // Return data type
|
||||||
|
using Position = vector<int>;
|
||||||
|
using Inputs = unordered_map<string, Ref<vector<ttype>>>;
|
||||||
|
using Iteration = PtrUmap<Var, int>;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// cache the input value
|
||||||
|
Inputs inputs;
|
||||||
|
vector<Iteration> iterations;
|
||||||
|
vector<Position> positions;
|
||||||
|
|
||||||
|
rtype visit_(const Constant &c) override;
|
||||||
|
rtype visit_(const BinaryOp &c) override;
|
||||||
|
rtype visit_(const RangeOp &c) override;
|
||||||
|
rtype visit_(const Subscript &c) override;
|
||||||
|
rtype visit_(const Var &c) override;
|
||||||
|
rtype visit_(const Tensor &c) override;
|
||||||
|
// int visit_(const Func &c); // Future work
|
||||||
|
|
||||||
|
static Inputs genInputStartingFromZero(const RangeOp &range);
|
||||||
|
|
||||||
|
public:
|
||||||
|
Interpreter(Inputs _inputs, int _verbose = 0)
|
||||||
|
: Functor(_verbose), inputs(_inputs) {}
|
||||||
|
Interpreter(RangeOp range, int _verbose = 0);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculate the output at specified poistions
|
||||||
|
*
|
||||||
|
* @param expr The expression to be calculated.
|
||||||
|
* @param poses Positions of output.
|
||||||
|
* @return vector<int> Value of output.
|
||||||
|
*/
|
||||||
|
vector<rtype> interpret(const Expr &expr, const vector<Position> &poses);
|
||||||
|
/**
|
||||||
|
* @brief Calculate the output at equally spaced positions
|
||||||
|
*
|
||||||
|
* @param expr The expression to be calculated.
|
||||||
|
* @param nPoses The number of calculated output positions.
|
||||||
|
* @return vector<int> Value of output.
|
||||||
|
*/
|
||||||
|
vector<rtype> interpretUniformSample(const RangeOp &range,
|
||||||
|
int nPoses = 100);
|
||||||
|
vector<rtype> interpretAllOutput(const RangeOp &range);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,14 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class MatchReshapeVisitor : public Functor<bool(void)> {
|
||||||
|
private:
|
||||||
|
PtrMap<Iterator, int> _coefficient;
|
||||||
|
|
||||||
|
public:
|
||||||
|
bool visit_(const RangeOp &c) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,60 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Visitor/StrideVisitor.h"
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class MatchTableVisitor
|
||||||
|
: public Functor<void(const Tensor &, int dim, optional<int> stride)> {
|
||||||
|
private:
|
||||||
|
// Var -> {(tensor, dim)}
|
||||||
|
Appearance appearance;
|
||||||
|
vector<Tensor> tensors;
|
||||||
|
vector<Subscript> subscripts;
|
||||||
|
StrideTable strideTable;
|
||||||
|
PtrMap<Iterator, vector<vector<int>>>
|
||||||
|
strideInDim; // [Iterator][tensorID][dim]=stride
|
||||||
|
|
||||||
|
// Intermediate variable
|
||||||
|
// product of a sub-exprtree: Stride has to be done in two DFS
|
||||||
|
SubexprSride subexprStride;
|
||||||
|
bool hasUnsupportedOp = false;
|
||||||
|
|
||||||
|
public:
|
||||||
|
MatchTableVisitor(int _verobse = 0) : Functor(_verobse) {}
|
||||||
|
void visit_(const BinaryOp &c, const Tensor &tensor, int dim,
|
||||||
|
optional<int> stride) override;
|
||||||
|
void visit_(const Subscript &c, const Tensor &tensor, int dim,
|
||||||
|
optional<int> stride) override;
|
||||||
|
void visit_(const Var &c, const Tensor &tensor, int dim,
|
||||||
|
optional<int> stride) override;
|
||||||
|
void visit_(const Constant &c, const Tensor &tensor, int dim,
|
||||||
|
optional<int> stride) override;
|
||||||
|
// void visit_(const Tensor &c, const Tensor &tensor) override;
|
||||||
|
|
||||||
|
[[nodiscard]] bool operator()(const RangeOp &e) {
|
||||||
|
hasUnsupportedOp = false;
|
||||||
|
// get the location and stride of each iterator
|
||||||
|
auto mulOp = as<BinaryOpNode>(e->getSummand());
|
||||||
|
// TODO [feature]: support complex index exprs
|
||||||
|
if (!mulOp || mulOp->getOpType() != OpType::Mul) {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
StrideVisitor strideVisitor(0);
|
||||||
|
subexprStride = strideVisitor.getFormulaStride(e);
|
||||||
|
dispatch(mulOp->getLhs(), nullptr, 0, 0);
|
||||||
|
dispatch(mulOp->getRhs(), nullptr, 0, 0);
|
||||||
|
subscripts.emplace_back(as<SubscriptNode>(mulOp->getLhs()));
|
||||||
|
subscripts.emplace_back(as<SubscriptNode>(mulOp->getRhs()));
|
||||||
|
assert(tensors.size() == subscripts.size());
|
||||||
|
assert(tensors.size() < 5);
|
||||||
|
return !hasUnsupportedOp;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto getResult() const {
|
||||||
|
return tuple(appearance, tensors, strideTable, subscripts);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,18 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class MatmulTransposeMutator : public Mutator {
|
||||||
|
Derivator &derivator;
|
||||||
|
|
||||||
|
public:
|
||||||
|
MatmulTransposeMutator(Derivator &derivator)
|
||||||
|
: Mutator(1), derivator(derivator) {}
|
||||||
|
VecExpr transpose(const Tensor &tensor);
|
||||||
|
|
||||||
|
private:
|
||||||
|
Tensor transposeInput(const Tensor &tensor);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,20 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class MergeMemboundMutator : public Mutator {
|
||||||
|
VecExpr kernels;
|
||||||
|
int curDepth; // from the last one to the first one
|
||||||
|
Expr visit_(const Tensor &c) override;
|
||||||
|
// FIXME: duplicate code
|
||||||
|
Expr rule4StageMerging(Expr &rCur, bool mergeStageWithCalc);
|
||||||
|
bool checkEmpty();
|
||||||
|
|
||||||
|
public:
|
||||||
|
MergeMemboundMutator(const VecExpr &kernels)
|
||||||
|
: Mutator(), kernels(kernels), curDepth(kernels.size() - 1) {}
|
||||||
|
Expr merge(bool allowEmptyMembound = false);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,43 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Since the output positions of operators always start from 0, we have
|
||||||
|
* to offset them if the the boundary expression of is not 0.
|
||||||
|
*/
|
||||||
|
class PatternMatcher : public Functor<void(void)> {
|
||||||
|
private:
|
||||||
|
Derivator &derivator;
|
||||||
|
bool hasNonZeroRange;
|
||||||
|
const RangeOp originalCur;
|
||||||
|
|
||||||
|
public:
|
||||||
|
PatternMatcher(Derivator &derivator, const RangeOp &cur);
|
||||||
|
/**
|
||||||
|
* @brief Get the Cur whose loop vars are all offset to [0, x). Since
|
||||||
|
* operator outputs start from 0, RangeOp has to be aligned.
|
||||||
|
*/
|
||||||
|
RangeOp getOffsetCur();
|
||||||
|
/**
|
||||||
|
* @brief Add outer RangeOp to map the original positions to the new
|
||||||
|
* positions staring from 0.
|
||||||
|
*
|
||||||
|
* @param exprs Tensors from matched exprs
|
||||||
|
*/
|
||||||
|
VecExpr applyWrapper(const VecExpr &exprs);
|
||||||
|
|
||||||
|
VecExpr matchWithPattern(const RangeOp &rangeOp, const Pattern &pattern);
|
||||||
|
|
||||||
|
private:
|
||||||
|
VecExpr matchKernel(const Pattern &pattern, const RangeOp &rangeOp,
|
||||||
|
IteratorTable &exprIT);
|
||||||
|
// get reverse tensor and iterator map ([pattern tensor/iter ID] ->
|
||||||
|
// real)
|
||||||
|
Expr matchKernelWithTensorMap(const Pattern &pattern,
|
||||||
|
const RangeOp &rangeOp,
|
||||||
|
IteratorTable &exprIT);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,25 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class RangeMagnifyVisitor : public Mutator {
|
||||||
|
vector<VarRangePair> newSumVarRanges;
|
||||||
|
RangeOp newRangeOp;
|
||||||
|
|
||||||
|
public:
|
||||||
|
RangeMagnifyVisitor() : Mutator(0) {}
|
||||||
|
Expr visit_(const RangeOp &c) override;
|
||||||
|
Expr visit_(const Subscript &c) override;
|
||||||
|
/**
|
||||||
|
* @brief
|
||||||
|
*
|
||||||
|
* @param root
|
||||||
|
* @param _newSumVarRanges
|
||||||
|
* @return RangeOp nullptr if failed to magnify
|
||||||
|
*/
|
||||||
|
RangeOp magnify(const RangeOp &root,
|
||||||
|
const vector<VarRangePair> &_newSumVarRanges);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,18 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
using RangeMap = PtrMap<Iterator, Range>;
|
||||||
|
class RangeRelaxFunctor : public Functor<RangeMap()> {
|
||||||
|
RangeOp rangeOp;
|
||||||
|
|
||||||
|
public:
|
||||||
|
RangeRelaxFunctor(RangeOp _rangeOp) : Functor(false), rangeOp(_rangeOp) {}
|
||||||
|
RangeMap visit_(const BinaryOp &c) override;
|
||||||
|
RangeMap visit_(const RangeOp &c) override;
|
||||||
|
RangeMap visit_(const Subscript &c) override;
|
||||||
|
RangeMap intersectRangeMaps(const RangeMap &a, const RangeMap &b);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,20 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
// Replace node according to its address the summand
|
||||||
|
// Only subscript and tensor are supported now.
|
||||||
|
class ReplaceNodeMutator : public Mutator {
|
||||||
|
int nSubscripts = 0;
|
||||||
|
ExprNode *target;
|
||||||
|
Expr replacement;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ReplaceNodeMutator() : Mutator(0) {}
|
||||||
|
Expr visit_(const Subscript &c) override;
|
||||||
|
Expr visit_(const Tensor &c) override;
|
||||||
|
Expr replace(const Expr &root, ExprNode *_target, const Expr &_replace);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,33 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class ReplaceVariable : public Mutator {
|
||||||
|
VecExpr patterns, replacements;
|
||||||
|
map<HashType, int> patternHash;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ReplaceVariable(Expr _pattern, Expr _replacement) : Mutator(false) {
|
||||||
|
set({_pattern}, {_replacement});
|
||||||
|
}
|
||||||
|
ReplaceVariable(const map<string, pair<Expr, Expr>> &mapping)
|
||||||
|
: Mutator(false) {
|
||||||
|
VecExpr _patterns, _replacements;
|
||||||
|
for (const auto &[_, v] : mapping) {
|
||||||
|
_patterns.emplace_back(v.first);
|
||||||
|
_replacements.emplace_back(v.second);
|
||||||
|
}
|
||||||
|
set(_patterns, _replacements);
|
||||||
|
}
|
||||||
|
Expr visit_(const BinaryOp &c) override;
|
||||||
|
// NOT recur to the next stage
|
||||||
|
Expr visit_(const RangeOp &c) override;
|
||||||
|
Expr visit_(const Var &c) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void set(VecExpr _pattern, VecExpr _replacement);
|
||||||
|
Expr match(const Expr &c);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,52 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nlohmann/json_fwd.hpp"
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Serializer : public Functor<string()> {
|
||||||
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
private:
|
||||||
|
static constexpr int VERSION{1};
|
||||||
|
std::unique_ptr<json> jPtr;
|
||||||
|
json &j;
|
||||||
|
static int id;
|
||||||
|
|
||||||
|
string visit_(const Constant &c) override;
|
||||||
|
string visit_(const BinaryOp &c) override;
|
||||||
|
string visit_(const RangeOp &c) override;
|
||||||
|
string visit_(const Subscript &c) override;
|
||||||
|
string visit_(const Var &c) override;
|
||||||
|
string visit_(const Tensor &c) override;
|
||||||
|
string dispatchRoutine(const Routine &c);
|
||||||
|
|
||||||
|
Expr buildExprTree(string key);
|
||||||
|
Routine buildRoutine(string key);
|
||||||
|
|
||||||
|
public:
|
||||||
|
Serializer(int _verobse = 0);
|
||||||
|
virtual ~Serializer();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Serialize the given expression to json file
|
||||||
|
*
|
||||||
|
* @param expr The expression to be serialized
|
||||||
|
* @param filePath The path of json file to be output
|
||||||
|
* @param msg Message of derivation
|
||||||
|
* @return bool Whether the serialization succeed
|
||||||
|
*/
|
||||||
|
bool serialize(const Expr &expr, const string &filePath,
|
||||||
|
const string &msg = "");
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Deserialize the given json file to expression
|
||||||
|
*
|
||||||
|
* @param filePath The path to file to be deserialized
|
||||||
|
* @return Expression deserialized from the given json file
|
||||||
|
*/
|
||||||
|
Expr deserialize(const string &filePath);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,38 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Visitor/StrideVisitor.h"
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
// Simplify a index expression tree
|
||||||
|
class SimplifyExprVisitor : public Functor<void(optional<int> stride)> {
|
||||||
|
private:
|
||||||
|
SubexprSride subexprStride;
|
||||||
|
int constant;
|
||||||
|
PtrMap<Iterator, int> strides; // [var]=strides
|
||||||
|
|
||||||
|
map<pair<Iterator, int>, int, RefValueLess<pair<Iterator, int>>> divStrides,
|
||||||
|
modStrides; // 3*(i%8): [<i,8>]=3
|
||||||
|
|
||||||
|
// For divde and modulo with expr as dividend: 3*((i+1)%8): [<i+1,8>]=3
|
||||||
|
map<pair<Expr, int>, int, RefAddrLess<pair<Expr, int>>> divExprStrides,
|
||||||
|
modExprStrides;
|
||||||
|
|
||||||
|
public:
|
||||||
|
SimplifyExprVisitor() : Functor(0) {}
|
||||||
|
void visit_(const BinaryOp &c, optional<int> stride) override;
|
||||||
|
void visit_(const Var &c, optional<int> stride) override;
|
||||||
|
void visit_(const Constant &c, optional<int> stride) override;
|
||||||
|
PtrMap<Iterator, int> getStrides(const Expr &expr);
|
||||||
|
// TODO [refactor]: move this to SimplifyFormulaMutator as a member func
|
||||||
|
// this class should be get coefficients in a expr
|
||||||
|
Expr simplify(const Expr &expr);
|
||||||
|
int getConstant(const Expr &expr);
|
||||||
|
pair<PtrMap<Iterator, int>, int> getStridesConstant(const Expr &expr);
|
||||||
|
optional<Range> getExprRange(const Expr &expr, const RangeOp &rangeOp);
|
||||||
|
PtrMap<Iterator, int> getStrides() { return strides; }
|
||||||
|
const auto &getDivStrides() { return divStrides; }
|
||||||
|
const auto &getModStrides() { return modStrides; }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,18 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/Visitor/StrideVisitor.h"
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
// Simplify all indexes in subscripts in an expression tree
|
||||||
|
class SimplifyFormulaMutator : public Mutator {
|
||||||
|
int nSubscripts = 0;
|
||||||
|
|
||||||
|
public:
|
||||||
|
SimplifyFormulaMutator() : Mutator(0) {}
|
||||||
|
Expr visit_(const Subscript &c) override;
|
||||||
|
// Expr visit_(const BinaryOp &c) override;
|
||||||
|
Expr simplify(const Expr &expr);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,38 @@
|
||||||
|
#pragma once
|
||||||
|
#include "nnet/visitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
using SubexprSride = map<const ExprNode *, optional<int>>;
|
||||||
|
class StrideVisitor : public Functor<optional<int>(void)> {
|
||||||
|
private:
|
||||||
|
SubexprSride subexprStride;
|
||||||
|
|
||||||
|
public:
|
||||||
|
StrideVisitor(int _verobse = 0) : Functor(_verobse) {}
|
||||||
|
optional<int> visit_(const BinaryOp &c) override;
|
||||||
|
optional<int> visit_(const Subscript &c) override;
|
||||||
|
optional<int> visit_(const Var &c) override;
|
||||||
|
optional<int> visit_(const Constant &c) override;
|
||||||
|
// void visit_(const Tensor &c, const Tensor &tensor) override;
|
||||||
|
|
||||||
|
auto getFormulaStride(const RangeOp &e) {
|
||||||
|
subexprStride.clear();
|
||||||
|
// get the location and stride of each iterator
|
||||||
|
auto mulOp = as<BinaryOpNode>(e->getSummand());
|
||||||
|
// TODO [feature]: support complex index exprs
|
||||||
|
if (!mulOp || mulOp->getOpType() != OpType::Mul)
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
dispatch(mulOp->getLhs());
|
||||||
|
dispatch(mulOp->getRhs());
|
||||||
|
return subexprStride;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] auto getExprStride(const Expr &e) {
|
||||||
|
subexprStride.clear();
|
||||||
|
dispatch(e);
|
||||||
|
return subexprStride;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,77 @@
|
||||||
|
#pragma once
|
||||||
|
#include "dbg.h"
|
||||||
|
#include <cassert>
|
||||||
|
#include <list>
|
||||||
|
#include <map>
|
||||||
|
#include <optional>
|
||||||
|
#include <set>
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
using std::dynamic_pointer_cast;
|
||||||
|
using std::endl;
|
||||||
|
using std::list;
|
||||||
|
using std::make_pair;
|
||||||
|
using std::make_shared;
|
||||||
|
using std::make_tuple;
|
||||||
|
using std::map;
|
||||||
|
using std::max;
|
||||||
|
using std::min;
|
||||||
|
using std::nullopt;
|
||||||
|
using std::optional;
|
||||||
|
using std::pair;
|
||||||
|
using std::set;
|
||||||
|
using std::shared_ptr;
|
||||||
|
using std::string;
|
||||||
|
using std::tie;
|
||||||
|
using std::to_string;
|
||||||
|
using std::tuple;
|
||||||
|
using std::unique_ptr;
|
||||||
|
using std::unordered_map;
|
||||||
|
template <typename T> using uset = std::unordered_set<T>;
|
||||||
|
using std::vector;
|
||||||
|
using std::weak_ptr;
|
||||||
|
|
||||||
|
// Aliases
|
||||||
|
using dtype = float;
|
||||||
|
using HashType = int;
|
||||||
|
|
||||||
|
template <typename T> struct ptr_less {
|
||||||
|
bool operator()(const T &lhs, const T &rhs) const { return *lhs < *rhs; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T> struct ptr_hash {
|
||||||
|
size_t operator()(const T &lhs) const {
|
||||||
|
return std::hash<decltype(*lhs)>()(*lhs);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T> struct ptr_equal {
|
||||||
|
bool operator()(const T &lhs, const T &rhs) const { return *lhs == *rhs; }
|
||||||
|
};
|
||||||
|
|
||||||
|
static inline HashType genhash(HashType a, HashType b) {
|
||||||
|
return (a * 10007 + b + 12345) % 1000000007;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HashType genhash(string s) {
|
||||||
|
HashType ret = 0;
|
||||||
|
for (auto c : s)
|
||||||
|
ret = genhash(ret, c);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define nnet_unimplemented_halt() \
|
||||||
|
{ assert(!"Unimplemented"); }
|
||||||
|
|
||||||
|
#define nnet_unimplemented_continue() \
|
||||||
|
{ dbg("Unimplemented"); }
|
||||||
|
|
||||||
|
#define nnet_assert(expr, msg) assert(((void)(msg), (expr)))
|
||||||
|
|
||||||
|
std::string pointer_to_hex(void *i);
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,858 @@
|
||||||
|
/*****************************************************************************
|
||||||
|
|
||||||
|
dbg(...) macro
|
||||||
|
|
||||||
|
License (MIT):
|
||||||
|
|
||||||
|
Copyright (c) 2019 David Peter <mail@david-peter.de>
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to
|
||||||
|
deal in the Software without restriction, including without limitation the
|
||||||
|
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||||
|
sell copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in
|
||||||
|
all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
|
||||||
|
*****************************************************************************/
|
||||||
|
|
||||||
|
#ifndef DBG_MACRO_DBG_H
|
||||||
|
#define DBG_MACRO_DBG_H
|
||||||
|
|
||||||
|
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
|
||||||
|
#define DBG_MACRO_UNIX
|
||||||
|
#elif defined(_MSC_VER)
|
||||||
|
#define DBG_MACRO_WINDOWS
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// #ifndef DBG_MACRO_NO_WARNING
|
||||||
|
// #pragma message("WARNING: the 'dbg.h' header is included in your code base")
|
||||||
|
// #endif // DBG_MACRO_NO_WARNING
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <chrono>
|
||||||
|
#include <ctime>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <ios>
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#ifdef DBG_MACRO_UNIX
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if __cplusplus >= 201703L
|
||||||
|
#define DBG_MACRO_CXX_STANDARD 17
|
||||||
|
#elif __cplusplus >= 201402L
|
||||||
|
#define DBG_MACRO_CXX_STANDARD 14
|
||||||
|
#else
|
||||||
|
#define DBG_MACRO_CXX_STANDARD 11
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if DBG_MACRO_CXX_STANDARD >= 17
|
||||||
|
#include <optional>
|
||||||
|
#include <variant>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace dbg {
|
||||||
|
|
||||||
|
#ifdef DBG_MACRO_UNIX
|
||||||
|
inline bool isColorizedOutputEnabled() { return isatty(fileno(stderr)); }
|
||||||
|
#else
|
||||||
|
inline bool isColorizedOutputEnabled() { return true; }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
struct time {};
|
||||||
|
|
||||||
|
namespace pretty_function {
|
||||||
|
|
||||||
|
// Compiler-agnostic version of __PRETTY_FUNCTION__ and constants to
|
||||||
|
// extract the template argument in `type_name_impl`
|
||||||
|
|
||||||
|
#if defined(__clang__)
|
||||||
|
#define DBG_MACRO_PRETTY_FUNCTION __PRETTY_FUNCTION__
|
||||||
|
static constexpr size_t PREFIX_LENGTH =
|
||||||
|
sizeof("const char *dbg::type_name_impl() [T = ") - 1;
|
||||||
|
static constexpr size_t SUFFIX_LENGTH = sizeof("]") - 1;
|
||||||
|
#elif defined(__GNUC__) && !defined(__clang__)
|
||||||
|
#define DBG_MACRO_PRETTY_FUNCTION __PRETTY_FUNCTION__
|
||||||
|
static constexpr size_t PREFIX_LENGTH =
|
||||||
|
sizeof("const char* dbg::type_name_impl() [with T = ") - 1;
|
||||||
|
static constexpr size_t SUFFIX_LENGTH = sizeof("]") - 1;
|
||||||
|
#elif defined(_MSC_VER)
|
||||||
|
#define DBG_MACRO_PRETTY_FUNCTION __FUNCSIG__
|
||||||
|
static constexpr size_t PREFIX_LENGTH =
|
||||||
|
sizeof("const char *__cdecl dbg::type_name_impl<") - 1;
|
||||||
|
static constexpr size_t SUFFIX_LENGTH = sizeof(">(void)") - 1;
|
||||||
|
#else
|
||||||
|
#error "This compiler is currently not supported by dbg_macro."
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace pretty_function
|
||||||
|
|
||||||
|
// Formatting helpers
|
||||||
|
|
||||||
|
template <typename T> struct print_formatted {
|
||||||
|
static_assert(std::is_integral<T>::value,
|
||||||
|
"Only integral types are supported.");
|
||||||
|
|
||||||
|
print_formatted(T value, int numeric_base)
|
||||||
|
: inner(value), base(numeric_base) {}
|
||||||
|
|
||||||
|
operator T() const { return inner; }
|
||||||
|
|
||||||
|
const char *prefix() const {
|
||||||
|
switch (base) {
|
||||||
|
case 8:
|
||||||
|
return "0o";
|
||||||
|
case 16:
|
||||||
|
return "0x";
|
||||||
|
case 2:
|
||||||
|
return "0b";
|
||||||
|
default:
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
T inner;
|
||||||
|
int base;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T> print_formatted<T> hex(T value) {
|
||||||
|
return print_formatted<T>{value, 16};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T> print_formatted<T> oct(T value) {
|
||||||
|
return print_formatted<T>{value, 8};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T> print_formatted<T> bin(T value) {
|
||||||
|
return print_formatted<T>{value, 2};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementation of 'type_name<T>()'
|
||||||
|
|
||||||
|
template <typename T> const char *type_name_impl() {
|
||||||
|
return DBG_MACRO_PRETTY_FUNCTION;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T> struct type_tag {};
|
||||||
|
|
||||||
|
template <int &...ExplicitArgumentBarrier, typename T>
|
||||||
|
std::string get_type_name(type_tag<T>) {
|
||||||
|
namespace pf = pretty_function;
|
||||||
|
|
||||||
|
std::string type = type_name_impl<T>();
|
||||||
|
return type.substr(pf::PREFIX_LENGTH,
|
||||||
|
type.size() - pf::PREFIX_LENGTH - pf::SUFFIX_LENGTH);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T> std::string type_name() {
|
||||||
|
if (std::is_volatile<T>::value) {
|
||||||
|
if (std::is_pointer<T>::value) {
|
||||||
|
return type_name<typename std::remove_volatile<T>::type>() +
|
||||||
|
" volatile";
|
||||||
|
} else {
|
||||||
|
return "volatile " +
|
||||||
|
type_name<typename std::remove_volatile<T>::type>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (std::is_const<T>::value) {
|
||||||
|
if (std::is_pointer<T>::value) {
|
||||||
|
return type_name<typename std::remove_const<T>::type>() + " const";
|
||||||
|
} else {
|
||||||
|
return "const " + type_name<typename std::remove_const<T>::type>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (std::is_pointer<T>::value) {
|
||||||
|
return type_name<typename std::remove_pointer<T>::type>() + "*";
|
||||||
|
}
|
||||||
|
if (std::is_lvalue_reference<T>::value) {
|
||||||
|
return type_name<typename std::remove_reference<T>::type>() + "&";
|
||||||
|
}
|
||||||
|
if (std::is_rvalue_reference<T>::value) {
|
||||||
|
return type_name<typename std::remove_reference<T>::type>() + "&&";
|
||||||
|
}
|
||||||
|
return get_type_name(type_tag<T>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string get_type_name(type_tag<short>) { return "short"; }
|
||||||
|
|
||||||
|
inline std::string get_type_name(type_tag<unsigned short>) {
|
||||||
|
return "unsigned short";
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string get_type_name(type_tag<long>) { return "long"; }
|
||||||
|
|
||||||
|
inline std::string get_type_name(type_tag<unsigned long>) {
|
||||||
|
return "unsigned long";
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string get_type_name(type_tag<std::string>) {
|
||||||
|
return "std::string";
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::string get_type_name(type_tag<std::vector<T, std::allocator<T>>>) {
|
||||||
|
return "std::vector<" + type_name<T>() + ">";
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T1, typename T2>
|
||||||
|
std::string get_type_name(type_tag<std::pair<T1, T2>>) {
|
||||||
|
return "std::pair<" + type_name<T1>() + ", " + type_name<T2>() + ">";
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... T> std::string type_list_to_string() {
|
||||||
|
std::string result;
|
||||||
|
auto unused = {(result += type_name<T>() + ", ", 0)..., 0};
|
||||||
|
static_cast<void>(unused);
|
||||||
|
|
||||||
|
#if DBG_MACRO_CXX_STANDARD >= 17
|
||||||
|
if constexpr (sizeof...(T) > 0) {
|
||||||
|
#else
|
||||||
|
if (sizeof...(T) > 0) {
|
||||||
|
#endif
|
||||||
|
result.pop_back();
|
||||||
|
result.pop_back();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... T> std::string get_type_name(type_tag<std::tuple<T...>>) {
|
||||||
|
return "std::tuple<" + type_list_to_string<T...>() + ">";
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline std::string get_type_name(type_tag<print_formatted<T>>) {
|
||||||
|
return type_name<T>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementation of 'is_detected' to specialize for container-like types
|
||||||
|
|
||||||
|
namespace detail_detector {
|
||||||
|
|
||||||
|
struct nonesuch {
|
||||||
|
nonesuch() = delete;
|
||||||
|
~nonesuch() = delete;
|
||||||
|
nonesuch(nonesuch const &) = delete;
|
||||||
|
void operator=(nonesuch const &) = delete;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename...> using void_t = void;
|
||||||
|
|
||||||
|
template <class Default, class AlwaysVoid, template <class...> class Op,
|
||||||
|
class... Args>
|
||||||
|
struct detector {
|
||||||
|
using value_t = std::false_type;
|
||||||
|
using type = Default;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class Default, template <class...> class Op, class... Args>
|
||||||
|
struct detector<Default, void_t<Op<Args...>>, Op, Args...> {
|
||||||
|
using value_t = std::true_type;
|
||||||
|
using type = Op<Args...>;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace detail_detector
|
||||||
|
|
||||||
|
template <template <class...> class Op, class... Args>
|
||||||
|
using is_detected =
|
||||||
|
typename detail_detector::detector<detail_detector::nonesuch, void, Op,
|
||||||
|
Args...>::value_t;
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
using std::begin;
|
||||||
|
using std::end;
|
||||||
|
#if DBG_MACRO_CXX_STANDARD < 17
|
||||||
|
template <typename T> constexpr auto size(const T &c) -> decltype(c.size()) {
|
||||||
|
return c.size();
|
||||||
|
}
|
||||||
|
template <typename T, std::size_t N>
|
||||||
|
constexpr std::size_t size(const T (&)[N]) {
|
||||||
|
return N;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
using std::size;
|
||||||
|
#endif
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using detect_begin_t = decltype(detail::begin(std::declval<T>()));
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using detect_end_t = decltype(detail::end(std::declval<T>()));
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using detect_size_t = decltype(detail::size(std::declval<T>()));
|
||||||
|
|
||||||
|
template <typename T> struct is_container {
|
||||||
|
static constexpr bool value =
|
||||||
|
is_detected<detect_begin_t, T>::value &&
|
||||||
|
is_detected<detect_end_t, T>::value &&
|
||||||
|
is_detected<detect_size_t, T>::value &&
|
||||||
|
!std::is_same<std::string,
|
||||||
|
typename std::remove_cv<typename std::remove_reference<
|
||||||
|
T>::type>::type>::value;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using ostream_operator_t =
|
||||||
|
decltype(std::declval<std::ostream &>() << std::declval<T>());
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct has_ostream_operator : is_detected<ostream_operator_t, T> {};
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
// Helper to dbg(…)-print types
|
||||||
|
template <typename T> struct print_type {};
|
||||||
|
|
||||||
|
template <typename T> print_type<T> type() { return print_type<T>{}; }
|
||||||
|
|
||||||
|
// Forward declarations of "pretty_print"
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline void pretty_print(std::ostream &stream, const T &value, std::true_type);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline void pretty_print(std::ostream &, const T &, std::false_type);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline typename std::enable_if<!detail::is_container<const T &>::value &&
|
||||||
|
!std::is_enum<T>::value,
|
||||||
|
bool>::type
|
||||||
|
pretty_print(std::ostream &stream, const T &value);
|
||||||
|
|
||||||
|
inline bool pretty_print(std::ostream &stream, const bool &value);
|
||||||
|
|
||||||
|
inline bool pretty_print(std::ostream &stream, const char &value);
|
||||||
|
|
||||||
|
template <typename P>
|
||||||
|
inline bool pretty_print(std::ostream &stream, P *const &value);
|
||||||
|
|
||||||
|
template <typename T, typename Deleter>
|
||||||
|
inline bool pretty_print(std::ostream &stream,
|
||||||
|
std::unique_ptr<T, Deleter> &value);
|
||||||
|
|
||||||
|
// template <typename T>
|
||||||
|
// inline bool pretty_print(std::ostream& stream, std::shared_ptr<T>& value);
|
||||||
|
|
||||||
|
template <size_t N>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const char (&value)[N]);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const char *const &value);
|
||||||
|
|
||||||
|
template <typename... Ts>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const std::tuple<Ts...> &value);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const std::tuple<> &);
|
||||||
|
|
||||||
|
template <> inline bool pretty_print(std::ostream &stream, const time &);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const print_formatted<T> &value);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const print_type<T> &);
|
||||||
|
|
||||||
|
template <typename Enum>
|
||||||
|
inline typename std::enable_if<std::is_enum<Enum>::value, bool>::type
|
||||||
|
pretty_print(std::ostream &stream, Enum const &value);
|
||||||
|
|
||||||
|
inline bool pretty_print(std::ostream &stream, const std::string &value);
|
||||||
|
|
||||||
|
#if DBG_MACRO_CXX_STANDARD >= 17
|
||||||
|
|
||||||
|
inline bool pretty_print(std::ostream &stream, const std::string_view &value);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <typename T1, typename T2>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const std::pair<T1, T2> &value);
|
||||||
|
|
||||||
|
#if DBG_MACRO_CXX_STANDARD >= 17
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const std::optional<T> &value);
|
||||||
|
|
||||||
|
template <typename... Ts>
|
||||||
|
inline bool pretty_print(std::ostream &stream,
|
||||||
|
const std::variant<Ts...> &value);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <typename Container>
|
||||||
|
inline typename std::enable_if<detail::is_container<const Container &>::value,
|
||||||
|
bool>::type
|
||||||
|
pretty_print(std::ostream &stream, const Container &value);
|
||||||
|
|
||||||
|
// Specializations of "pretty_print"
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline void pretty_print(std::ostream &stream, const T &value, std::true_type) {
|
||||||
|
stream << value;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline void pretty_print(std::ostream &, const T &, std::false_type) {
|
||||||
|
static_assert(detail::has_ostream_operator<const T &>::value,
|
||||||
|
"Type does not support the << ostream operator");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline typename std::enable_if<!detail::is_container<const T &>::value &&
|
||||||
|
!std::is_enum<T>::value,
|
||||||
|
bool>::type
|
||||||
|
pretty_print(std::ostream &stream, const T &value) {
|
||||||
|
pretty_print(stream, value,
|
||||||
|
typename detail::has_ostream_operator<const T &>::type{});
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool pretty_print(std::ostream &stream, const bool &value) {
|
||||||
|
stream << std::boolalpha << value;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool pretty_print(std::ostream &stream, const char &value) {
|
||||||
|
const bool printable = value >= 0x20 && value <= 0x7E;
|
||||||
|
|
||||||
|
if (printable) {
|
||||||
|
stream << "'" << value << "'";
|
||||||
|
} else {
|
||||||
|
stream << "'\\x" << std::setw(2) << std::setfill('0') << std::hex
|
||||||
|
<< std::uppercase << (0xFF & value) << "'";
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename P>
|
||||||
|
inline bool pretty_print(std::ostream &stream, P *const &value) {
|
||||||
|
if (value == nullptr) {
|
||||||
|
stream << "nullptr";
|
||||||
|
} else {
|
||||||
|
stream << value;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename Deleter>
|
||||||
|
inline bool pretty_print(std::ostream &stream,
|
||||||
|
std::unique_ptr<T, Deleter> &value) {
|
||||||
|
pretty_print(stream, value.get());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// template <typename T>
|
||||||
|
// inline bool pretty_print(std::ostream& stream, std::shared_ptr<T>& value) {
|
||||||
|
// pretty_print(stream, value.get());
|
||||||
|
// stream << " (use_count = " << value.use_count() << ")";
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
template <size_t N>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const char (&value)[N]) {
|
||||||
|
stream << value;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const char *const &value) {
|
||||||
|
stream << '"' << value << '"';
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <size_t Idx> struct pretty_print_tuple {
|
||||||
|
template <typename... Ts>
|
||||||
|
static void print(std::ostream &stream, const std::tuple<Ts...> &tuple) {
|
||||||
|
pretty_print_tuple<Idx - 1>::print(stream, tuple);
|
||||||
|
stream << ", ";
|
||||||
|
pretty_print(stream, std::get<Idx>(tuple));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <> struct pretty_print_tuple<0> {
|
||||||
|
template <typename... Ts>
|
||||||
|
static void print(std::ostream &stream, const std::tuple<Ts...> &tuple) {
|
||||||
|
pretty_print(stream, std::get<0>(tuple));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename... Ts>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const std::tuple<Ts...> &value) {
|
||||||
|
stream << "{";
|
||||||
|
pretty_print_tuple<sizeof...(Ts) - 1>::print(stream, value);
|
||||||
|
stream << "}";
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const std::tuple<> &) {
|
||||||
|
stream << "{}";
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <> inline bool pretty_print(std::ostream &stream, const time &) {
|
||||||
|
using namespace std::chrono;
|
||||||
|
|
||||||
|
const auto now = system_clock::now();
|
||||||
|
const auto us =
|
||||||
|
duration_cast<microseconds>(now.time_since_epoch()).count() % 1000000;
|
||||||
|
const auto hms = system_clock::to_time_t(now);
|
||||||
|
const std::tm *tm = std::localtime(&hms);
|
||||||
|
stream << "current time = " << std::put_time(tm, "%H:%M:%S") << '.'
|
||||||
|
<< std::setw(6) << std::setfill('0') << us;
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts decimal integer to binary string
|
||||||
|
template <typename T> std::string decimalToBinary(T n) {
|
||||||
|
const size_t length = 8 * sizeof(T);
|
||||||
|
std::string toRet;
|
||||||
|
toRet.resize(length);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < length; ++i) {
|
||||||
|
const auto bit_at_index_i = static_cast<char>((n >> i) & 1);
|
||||||
|
toRet[length - 1 - i] = bit_at_index_i + '0';
|
||||||
|
}
|
||||||
|
|
||||||
|
return toRet;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline bool pretty_print(std::ostream &stream,
|
||||||
|
const print_formatted<T> &value) {
|
||||||
|
if (value.inner < 0) {
|
||||||
|
stream << "-";
|
||||||
|
}
|
||||||
|
stream << value.prefix();
|
||||||
|
|
||||||
|
// Print using setbase
|
||||||
|
if (value.base != 2) {
|
||||||
|
stream << std::setw(sizeof(T)) << std::setfill('0')
|
||||||
|
<< std::setbase(value.base) << std::uppercase;
|
||||||
|
|
||||||
|
if (value.inner >= 0) {
|
||||||
|
// The '+' sign makes sure that a uint_8 is printed as a number
|
||||||
|
stream << +value.inner;
|
||||||
|
} else {
|
||||||
|
using unsigned_type = typename std::make_unsigned<T>::type;
|
||||||
|
stream << +(static_cast<unsigned_type>(-(value.inner + 1)) + 1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Print for binary
|
||||||
|
if (value.inner >= 0) {
|
||||||
|
stream << decimalToBinary(value.inner);
|
||||||
|
} else {
|
||||||
|
using unsigned_type = typename std::make_unsigned<T>::type;
|
||||||
|
stream << decimalToBinary<unsigned_type>(
|
||||||
|
static_cast<unsigned_type>(-(value.inner + 1)) + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const print_type<T> &) {
|
||||||
|
stream << type_name<T>();
|
||||||
|
|
||||||
|
stream << " [sizeof: " << sizeof(T) << " byte, ";
|
||||||
|
|
||||||
|
stream << "trivial: ";
|
||||||
|
if (std::is_trivial<T>::value) {
|
||||||
|
stream << "yes";
|
||||||
|
} else {
|
||||||
|
stream << "no";
|
||||||
|
}
|
||||||
|
|
||||||
|
stream << ", standard layout: ";
|
||||||
|
if (std::is_standard_layout<T>::value) {
|
||||||
|
stream << "yes";
|
||||||
|
} else {
|
||||||
|
stream << "no";
|
||||||
|
}
|
||||||
|
stream << "]";
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Enum>
|
||||||
|
inline typename std::enable_if<std::is_enum<Enum>::value, bool>::type
|
||||||
|
pretty_print(std::ostream &stream, Enum const &value) {
|
||||||
|
using UnderlyingType = typename std::underlying_type<Enum>::type;
|
||||||
|
stream << static_cast<UnderlyingType>(value);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool pretty_print(std::ostream &stream, const std::string &value) {
|
||||||
|
stream << '"' << value << '"';
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
#if DBG_MACRO_CXX_STANDARD >= 17
|
||||||
|
|
||||||
|
inline bool pretty_print(std::ostream &stream, const std::string_view &value) {
|
||||||
|
stream << '"' << std::string(value) << '"';
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <typename T1, typename T2>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const std::pair<T1, T2> &value) {
|
||||||
|
stream << "{";
|
||||||
|
pretty_print(stream, value.first);
|
||||||
|
stream << ", ";
|
||||||
|
pretty_print(stream, value.second);
|
||||||
|
stream << "}";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
#if DBG_MACRO_CXX_STANDARD >= 17
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline bool pretty_print(std::ostream &stream, const std::optional<T> &value) {
|
||||||
|
if (value) {
|
||||||
|
stream << '{';
|
||||||
|
pretty_print(stream, *value);
|
||||||
|
stream << '}';
|
||||||
|
} else {
|
||||||
|
stream << "nullopt";
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Ts>
|
||||||
|
inline bool pretty_print(std::ostream &stream,
|
||||||
|
const std::variant<Ts...> &value) {
|
||||||
|
stream << "{";
|
||||||
|
std::visit([&stream](auto &&arg) { pretty_print(stream, arg); }, value);
|
||||||
|
stream << "}";
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <typename Container>
|
||||||
|
inline typename std::enable_if<detail::is_container<const Container &>::value,
|
||||||
|
bool>::type
|
||||||
|
pretty_print(std::ostream &stream, const Container &value) {
|
||||||
|
stream << "{";
|
||||||
|
const size_t size = detail::size(value);
|
||||||
|
const size_t n = std::min(size_t{10}, size);
|
||||||
|
size_t i = 0;
|
||||||
|
using std::begin;
|
||||||
|
using std::end;
|
||||||
|
for (auto it = begin(value); it != end(value) && i < n; ++it, ++i) {
|
||||||
|
pretty_print(stream, *it);
|
||||||
|
if (i != n - 1) {
|
||||||
|
stream << ", ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (size > n) {
|
||||||
|
stream << ", ...";
|
||||||
|
stream << " size:" << size;
|
||||||
|
}
|
||||||
|
|
||||||
|
stream << "}";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename... U> struct last {
|
||||||
|
using type = typename last<U...>::type;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T> struct last<T> { using type = T; };
|
||||||
|
|
||||||
|
template <typename... T> using last_t = typename last<T...>::type;
|
||||||
|
|
||||||
|
class DebugOutput {
|
||||||
|
public:
|
||||||
|
// Helper alias to avoid obscure type `const char* const*` in signature.
|
||||||
|
using expr_t = const char *;
|
||||||
|
|
||||||
|
DebugOutput(const char *filepath, int line, const char *function_name)
|
||||||
|
: m_use_colorized_output(isColorizedOutputEnabled()) {
|
||||||
|
std::string path = filepath;
|
||||||
|
const std::size_t path_length = path.length();
|
||||||
|
if (path_length > MAX_PATH_LENGTH) {
|
||||||
|
path = ".." +
|
||||||
|
path.substr(path_length - MAX_PATH_LENGTH, MAX_PATH_LENGTH);
|
||||||
|
}
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << ansi(ANSI_DEBUG) << "[" << path << ":" << line << " ("
|
||||||
|
<< function_name << ")] " << ansi(ANSI_RESET);
|
||||||
|
m_location = ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... T>
|
||||||
|
auto print(std::initializer_list<expr_t> exprs,
|
||||||
|
std::initializer_list<std::string> types, T &&...values)
|
||||||
|
-> last_t<T...> {
|
||||||
|
if (exprs.size() != sizeof...(values)) {
|
||||||
|
std::cerr << m_location << ansi(ANSI_WARN)
|
||||||
|
<< "The number of arguments mismatch, please check "
|
||||||
|
"unprotected comma"
|
||||||
|
<< ansi(ANSI_RESET) << std::endl;
|
||||||
|
}
|
||||||
|
return print_impl(exprs.begin(), types.begin(),
|
||||||
|
std::forward<T>(values)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename T>
|
||||||
|
T &&print_impl(const expr_t *expr, const std::string *type, T &&value) {
|
||||||
|
const T &ref = value;
|
||||||
|
std::stringstream stream_value;
|
||||||
|
const bool print_expr_and_type = pretty_print(stream_value, ref);
|
||||||
|
|
||||||
|
std::stringstream output;
|
||||||
|
output << m_location;
|
||||||
|
if (print_expr_and_type) {
|
||||||
|
output << ansi(ANSI_EXPRESSION) << *expr << ansi(ANSI_RESET)
|
||||||
|
<< " = ";
|
||||||
|
}
|
||||||
|
output << ansi(ANSI_VALUE) << stream_value.str() << ansi(ANSI_RESET);
|
||||||
|
if (print_expr_and_type) {
|
||||||
|
output << " (" << ansi(ANSI_TYPE) << *type << ansi(ANSI_RESET)
|
||||||
|
<< ")";
|
||||||
|
}
|
||||||
|
output << std::endl;
|
||||||
|
std::cerr << output.str();
|
||||||
|
|
||||||
|
return std::forward<T>(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename... U>
|
||||||
|
auto print_impl(const expr_t *exprs, const std::string *types, T &&value,
|
||||||
|
U &&...rest) -> last_t<T, U...> {
|
||||||
|
print_impl(exprs, types, std::forward<T>(value));
|
||||||
|
return print_impl(exprs + 1, types + 1, std::forward<U>(rest)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *ansi(const char *code) const {
|
||||||
|
if (m_use_colorized_output) {
|
||||||
|
return code;
|
||||||
|
} else {
|
||||||
|
return ANSI_EMPTY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool m_use_colorized_output;
|
||||||
|
|
||||||
|
std::string m_location;
|
||||||
|
|
||||||
|
static constexpr std::size_t MAX_PATH_LENGTH = 20;
|
||||||
|
|
||||||
|
static constexpr const char *const ANSI_EMPTY = "";
|
||||||
|
static constexpr const char *const ANSI_DEBUG = "\x1b[02m";
|
||||||
|
static constexpr const char *const ANSI_WARN = "\x1b[33m";
|
||||||
|
static constexpr const char *const ANSI_EXPRESSION = "\x1b[36m";
|
||||||
|
static constexpr const char *const ANSI_VALUE = "\x1b[01m";
|
||||||
|
static constexpr const char *const ANSI_TYPE = "\x1b[32m";
|
||||||
|
static constexpr const char *const ANSI_RESET = "\x1b[0m";
|
||||||
|
};
|
||||||
|
|
||||||
|
// Identity function to suppress "-Wunused-value" warnings in DBG_MACRO_DISABLE
|
||||||
|
// mode
|
||||||
|
template <typename T> T &&identity(T &&t) { return std::forward<T>(t); }
|
||||||
|
|
||||||
|
template <typename T, typename... U>
|
||||||
|
auto identity(T &&, U &&...u) -> last_t<U...> {
|
||||||
|
return identity(std::forward<U>(u)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dbg
|
||||||
|
|
||||||
|
#ifndef DBG_MACRO_DISABLE
|
||||||
|
|
||||||
|
// Force expanding argument with commas for MSVC, ref:
|
||||||
|
// https://stackoverflow.com/questions/35210637/macro-expansion-argument-with-commas
|
||||||
|
// Note that "args" should be a tuple with parentheses, such as "(e1, e2, ...)".
|
||||||
|
#define DBG_IDENTITY(x) x
|
||||||
|
#define DBG_CALL(fn, args) DBG_IDENTITY(fn args)
|
||||||
|
|
||||||
|
#define DBG_CAT_IMPL(_1, _2) _1##_2
|
||||||
|
#define DBG_CAT(_1, _2) DBG_CAT_IMPL(_1, _2)
|
||||||
|
|
||||||
|
#define DBG_16TH_IMPL(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \
|
||||||
|
_14, _15, _16, ...) \
|
||||||
|
_16
|
||||||
|
#define DBG_16TH(args) DBG_CALL(DBG_16TH_IMPL, args)
|
||||||
|
#define DBG_NARG(...) \
|
||||||
|
DBG_16TH( \
|
||||||
|
(__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
|
||||||
|
|
||||||
|
// DBG_VARIADIC_CALL(fn, data, e1, e2, ...) => fn_N(data, (e1, e2, ...))
|
||||||
|
#define DBG_VARIADIC_CALL(fn, data, ...) \
|
||||||
|
DBG_CAT(fn##_, DBG_NARG(__VA_ARGS__))(data, (__VA_ARGS__))
|
||||||
|
|
||||||
|
// (e1, e2, e3, ...) => e1
|
||||||
|
#define DBG_HEAD_IMPL(_1, ...) _1
|
||||||
|
#define DBG_HEAD(args) DBG_CALL(DBG_HEAD_IMPL, args)
|
||||||
|
|
||||||
|
// (e1, e2, e3, ...) => (e2, e3, ...)
|
||||||
|
#define DBG_TAIL_IMPL(_1, ...) (__VA_ARGS__)
|
||||||
|
#define DBG_TAIL(args) DBG_CALL(DBG_TAIL_IMPL, args)
|
||||||
|
|
||||||
|
#define DBG_MAP_1(fn, args) DBG_CALL(fn, args)
|
||||||
|
#define DBG_MAP_2(fn, args) fn(DBG_HEAD(args)), DBG_MAP_1(fn, DBG_TAIL(args))
|
||||||
|
#define DBG_MAP_3(fn, args) fn(DBG_HEAD(args)), DBG_MAP_2(fn, DBG_TAIL(args))
|
||||||
|
#define DBG_MAP_4(fn, args) fn(DBG_HEAD(args)), DBG_MAP_3(fn, DBG_TAIL(args))
|
||||||
|
#define DBG_MAP_5(fn, args) fn(DBG_HEAD(args)), DBG_MAP_4(fn, DBG_TAIL(args))
|
||||||
|
#define DBG_MAP_6(fn, args) fn(DBG_HEAD(args)), DBG_MAP_5(fn, DBG_TAIL(args))
|
||||||
|
#define DBG_MAP_7(fn, args) fn(DBG_HEAD(args)), DBG_MAP_6(fn, DBG_TAIL(args))
|
||||||
|
#define DBG_MAP_8(fn, args) fn(DBG_HEAD(args)), DBG_MAP_7(fn, DBG_TAIL(args))
|
||||||
|
#define DBG_MAP_9(fn, args) fn(DBG_HEAD(args)), DBG_MAP_8(fn, DBG_TAIL(args))
|
||||||
|
#define DBG_MAP_10(fn, args) fn(DBG_HEAD(args)), DBG_MAP_9(fn, DBG_TAIL(args))
|
||||||
|
#define DBG_MAP_11(fn, args) fn(DBG_HEAD(args)), DBG_MAP_10(fn, DBG_TAIL(args))
|
||||||
|
#define DBG_MAP_12(fn, args) fn(DBG_HEAD(args)), DBG_MAP_11(fn, DBG_TAIL(args))
|
||||||
|
#define DBG_MAP_13(fn, args) fn(DBG_HEAD(args)), DBG_MAP_12(fn, DBG_TAIL(args))
|
||||||
|
#define DBG_MAP_14(fn, args) fn(DBG_HEAD(args)), DBG_MAP_13(fn, DBG_TAIL(args))
|
||||||
|
#define DBG_MAP_15(fn, args) fn(DBG_HEAD(args)), DBG_MAP_14(fn, DBG_TAIL(args))
|
||||||
|
#define DBG_MAP_16(fn, args) fn(DBG_HEAD(args)), DBG_MAP_15(fn, DBG_TAIL(args))
|
||||||
|
|
||||||
|
// DBG_MAP(fn, e1, e2, e3, ...) => fn(e1), fn(e2), fn(e3), ...
|
||||||
|
#define DBG_MAP(fn, ...) DBG_VARIADIC_CALL(DBG_MAP, fn, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define DBG_STRINGIFY_IMPL(x) #x
|
||||||
|
#define DBG_STRINGIFY(x) DBG_STRINGIFY_IMPL(x)
|
||||||
|
|
||||||
|
#define DBG_TYPE_NAME(x) dbg::type_name<decltype(x)>()
|
||||||
|
|
||||||
|
#define dbg(...) \
|
||||||
|
dbg::DebugOutput(__FILE__, __LINE__, __func__) \
|
||||||
|
.print({DBG_MAP(DBG_STRINGIFY, __VA_ARGS__)}, \
|
||||||
|
{DBG_MAP(DBG_TYPE_NAME, __VA_ARGS__)}, __VA_ARGS__)
|
||||||
|
#else
|
||||||
|
#define dbg(...) dbg::identity(__VA_ARGS__)
|
||||||
|
#endif // DBG_MACRO_DISABLE
|
||||||
|
|
||||||
|
#endif // DBG_MACRO_DBG_H
|
|
@ -0,0 +1,156 @@
|
||||||
|
#pragma once
|
||||||
|
#include "common.h"
|
||||||
|
#include "expr.h"
|
||||||
|
#include "iterator_table.h"
|
||||||
|
#include "routine.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class Formula {
|
||||||
|
public:
|
||||||
|
Expr root;
|
||||||
|
const int bfsDepth;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Formula(Expr _root, int _bfsDepth) : root(_root), bfsDepth(_bfsDepth) {}
|
||||||
|
string toReadable() const;
|
||||||
|
friend std::ostream &operator<<(std::ostream &ios, const Formula &expr) {
|
||||||
|
ios << expr.toReadable();
|
||||||
|
return ios;
|
||||||
|
}
|
||||||
|
bool isVariable() const { return as<VarNode>(root) != nullptr; }
|
||||||
|
};
|
||||||
|
|
||||||
|
class MultiFormulas {
|
||||||
|
public:
|
||||||
|
VecExpr roots;
|
||||||
|
const int bfsDepth;
|
||||||
|
|
||||||
|
public:
|
||||||
|
MultiFormulas(VecExpr roots, int _bfsDepth)
|
||||||
|
: roots(roots), bfsDepth(_bfsDepth) {}
|
||||||
|
// string toReadable() const;
|
||||||
|
// friend std::ostream &operator<<(std::ostream &ios, const Formula &expr) {
|
||||||
|
// ios << expr.toReadable();
|
||||||
|
// return ios;
|
||||||
|
// }
|
||||||
|
};
|
||||||
|
|
||||||
|
class Derivator {
|
||||||
|
public:
|
||||||
|
enum class LogMode { Normal, DumpFristCandiate, NoLog };
|
||||||
|
enum class PassMode { Debug, Full };
|
||||||
|
|
||||||
|
private:
|
||||||
|
list<Formula> candidates;
|
||||||
|
const int maxDepth;
|
||||||
|
int nIteratorNames = 0;
|
||||||
|
int nTensorNames = 0;
|
||||||
|
vector<vector<int>> rulesOverall;
|
||||||
|
enum class Strategy { DFS, Rule, RuleAndDFS } searchStrategy;
|
||||||
|
LogMode logMode;
|
||||||
|
PassMode passMode;
|
||||||
|
bool enableEquivalenceCheck = false;
|
||||||
|
string logFnPrefix;
|
||||||
|
const bool enableHashPruning;
|
||||||
|
int searchedMaxDepth = 0;
|
||||||
|
RoutineType targetOp = RoutineType::NoneType;
|
||||||
|
map<int, vector<Var>> substituteRules;
|
||||||
|
|
||||||
|
vector<int> cntAppliedRules;
|
||||||
|
int cntRule3 = 0;
|
||||||
|
std::unordered_set<HashType> visited;
|
||||||
|
VecExpr intermediateStates;
|
||||||
|
vector<string> ruleStates, ruleMsgs;
|
||||||
|
int cntStates = 0; // the number of intermediate states
|
||||||
|
int searchState = 0; // search state in guided search
|
||||||
|
|
||||||
|
public:
|
||||||
|
Derivator(int maxDepth = 8, bool enableHashPruning = true,
|
||||||
|
LogMode mode = LogMode::NoLog,
|
||||||
|
PassMode passMode = PassMode::Debug);
|
||||||
|
void search(Formula &origin, int depth);
|
||||||
|
void ruleBasedDFS(Formula &origin, int depth, vector<int> _rules,
|
||||||
|
map<int, vector<Var>> _substituteRules = {},
|
||||||
|
bool searchAfterRules = false);
|
||||||
|
void guidedSearch(Formula &origin, int depth);
|
||||||
|
void print();
|
||||||
|
int getNumCandidates() const { return candidates.size(); }
|
||||||
|
const auto &getCandidates() const { return candidates; }
|
||||||
|
void appendCanddiate(const Tensor &tensor, int depth);
|
||||||
|
int getSearchedMaxDepth() const { return searchedMaxDepth; };
|
||||||
|
bool stageCombination(MultiFormulas &origin, int depth);
|
||||||
|
bool checkOOB(const RangeOp &rangeOp, bool halt = true);
|
||||||
|
|
||||||
|
string newTensorName();
|
||||||
|
Var getNewVar();
|
||||||
|
|
||||||
|
Expr mergeMemboundStages(VecExpr stages);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void dfs(Formula &origin, int depth);
|
||||||
|
void ruleBasedDerivate(Formula &origin, int depth);
|
||||||
|
|
||||||
|
void rule1VariableSplit(Formula &origin, int depth, Expr &rCur);
|
||||||
|
void rule2VariableMerging(Formula &origin, int depth, Expr &rCur);
|
||||||
|
void rule3StageSplit(Formula &origin, int dfsDepth, Expr &rCur);
|
||||||
|
void rule5RangeRelaxation(Formula &origin, int depth, Expr &rCur);
|
||||||
|
bool rule4StageMerging(Formula &origin, int depth, Expr &rCur,
|
||||||
|
bool mergeStageWithCalc = false);
|
||||||
|
void rule6KenerlMatching(Formula &origin, int depth, Expr &rCur);
|
||||||
|
void rule7DLT(Formula &origin, int depth, Expr &rCur);
|
||||||
|
// Rule 8: guidedDLT
|
||||||
|
void rule8GuidedDLT(Formula &origin, int depth, Expr &rCur);
|
||||||
|
void rule9RangeMagnify(Formula &origin, int depth, Expr &rCur);
|
||||||
|
void rule90TwoStageElementWise(Formula &origin, int depth, Expr &rCur);
|
||||||
|
void rule91MergeStagesWithSum(Formula &origin, int depth, Expr &rCur);
|
||||||
|
/**
|
||||||
|
* @brief For searchState=2, wrap the RangeOp to add offset, if the boundary
|
||||||
|
* does not start from 0. Then match the inner offset RangeOp.
|
||||||
|
*/
|
||||||
|
void matchComputationKernel(Formula &origin, int depth, Expr &rcur);
|
||||||
|
/**
|
||||||
|
* @brief For searchState=3, the Formula must be a MemBound kernel?
|
||||||
|
*/
|
||||||
|
void matchMemBoundKernel(Formula &origin, int depth, Expr &rcur);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Check the equivalence for exprs in intermediateStates.
|
||||||
|
*/
|
||||||
|
void checkDerivationEquivalence();
|
||||||
|
|
||||||
|
public:
|
||||||
|
void pushIntermediateState(const Expr &expr);
|
||||||
|
void pushRuleState(const string &state);
|
||||||
|
void pushRuleMsg(const string &state);
|
||||||
|
void popIntermediateState();
|
||||||
|
void popRuleState();
|
||||||
|
void popRuleMsg();
|
||||||
|
// void pushTransformInfo(const Expr &expr, const string &state,
|
||||||
|
// const string &msg);
|
||||||
|
void nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur);
|
||||||
|
|
||||||
|
RoutineType getTargetOp();
|
||||||
|
void setTargetOp(RoutineType _targetOp);
|
||||||
|
|
||||||
|
int getSearchState();
|
||||||
|
void setSearchState(int _searchState);
|
||||||
|
int getNumIntermediateStates();
|
||||||
|
void printStatistics();
|
||||||
|
void printIntermediateStates();
|
||||||
|
/**
|
||||||
|
* @brief Enable dumping the first results. Verification is enabled so it is
|
||||||
|
* slow.
|
||||||
|
*
|
||||||
|
* @param _logFnPrefix Prefix of output filename
|
||||||
|
*/
|
||||||
|
void setDumpFirstSuccess(const string &_logFnPrefix);
|
||||||
|
void setEquivalenceCheck();
|
||||||
|
PassMode getPassMode();
|
||||||
|
LogMode getLogMode();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,53 @@
|
||||||
|
#pragma once
|
||||||
|
#include "common.h"
|
||||||
|
#include "expr.h"
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
// enum class DLTType { Split, Merge, Reorder };
|
||||||
|
|
||||||
|
struct DLTOperation {
|
||||||
|
// DLTType type;
|
||||||
|
virtual ~DLTOperation() {}
|
||||||
|
};
|
||||||
|
struct DLTSplit : DLTOperation {
|
||||||
|
int dim, factor;
|
||||||
|
DLTSplit(int _dim, int _factor) : dim(_dim), factor(_factor) {}
|
||||||
|
};
|
||||||
|
struct DLTMerge : DLTOperation {
|
||||||
|
int dim0, dim1;
|
||||||
|
DLTMerge(int _dim0, int _dim1) : dim0(_dim0), dim1(_dim1) {}
|
||||||
|
};
|
||||||
|
struct DLTReorder : DLTOperation {
|
||||||
|
vector<int> dims;
|
||||||
|
DLTReorder(vector<int> _dims) : dims(_dims) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
class DLT {
|
||||||
|
vector<Ref<DLTOperation>> ops;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief dim -> (dim/factor, factor)
|
||||||
|
*/
|
||||||
|
void split(int dim, int factor);
|
||||||
|
/**
|
||||||
|
* @brief Merge dim1 into dim0 -> (dim0, dim1)
|
||||||
|
*/
|
||||||
|
void merge(int dim0, int dim1);
|
||||||
|
/**
|
||||||
|
* @brief
|
||||||
|
*
|
||||||
|
* @param dims dims[new_dim]=old_dim
|
||||||
|
*/
|
||||||
|
void reorder(vector<int> dims);
|
||||||
|
optional<Expr> apply(const RangeOp &rangeOp, const Subscript &subscript,
|
||||||
|
string newTensorName);
|
||||||
|
|
||||||
|
private:
|
||||||
|
optional<pair<Expr, Expr>> splitIndex(Expr expr, int factor,
|
||||||
|
RangeOp rangeOp);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,416 @@
|
||||||
|
#pragma once
|
||||||
|
#include "common.h"
|
||||||
|
#include "ref.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include <numeric>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class ExprNode;
|
||||||
|
class VarNode;
|
||||||
|
class TensorNode;
|
||||||
|
class OperatorNode;
|
||||||
|
class RangeOpNode;
|
||||||
|
class SubscriptNode;
|
||||||
|
class BinaryOpNode;
|
||||||
|
class ConstantNode;
|
||||||
|
class FuncNode;
|
||||||
|
using Expr = Ref<ExprNode>;
|
||||||
|
using Var = Ref<VarNode>;
|
||||||
|
using Tensor = Ref<TensorNode>;
|
||||||
|
using Operator = Ref<OperatorNode>;
|
||||||
|
using RangeOp = Ref<RangeOpNode>;
|
||||||
|
using Subscript = Ref<SubscriptNode>;
|
||||||
|
using BinaryOp = Ref<BinaryOpNode>;
|
||||||
|
using Constant = Ref<ConstantNode>;
|
||||||
|
using Func = Ref<FuncNode>;
|
||||||
|
|
||||||
|
class RoutineNode;
|
||||||
|
using Routine = Ref<RoutineNode>;
|
||||||
|
enum class RoutineType {
|
||||||
|
NoneType = 100,
|
||||||
|
MatmulNodeType,
|
||||||
|
ConvNodeType,
|
||||||
|
G2bmmNodeType,
|
||||||
|
GbmmNodeType,
|
||||||
|
ElementWiseNodeType // unmatchable
|
||||||
|
};
|
||||||
|
constexpr inline int MatchableRoutineTypeCnt = 4;
|
||||||
|
constexpr inline int RoutineTypeCnt = MatchableRoutineTypeCnt + 1;
|
||||||
|
inline RoutineType idToRoutineType(int i) {
|
||||||
|
return static_cast<RoutineType>(i + 1 +
|
||||||
|
static_cast<int>(RoutineType::NoneType));
|
||||||
|
}
|
||||||
|
inline int routineTypeToId(const RoutineType &routineType) {
|
||||||
|
return static_cast<int>(routineType) -
|
||||||
|
static_cast<int>(RoutineType::NoneType) - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
using VecExpr = vector<Expr>;
|
||||||
|
|
||||||
|
// common data structure
|
||||||
|
using Iterator = Var; // RE: remove this alias
|
||||||
|
template <typename T, typename U> using PtrMap = std::map<T, U, ptr_less<T>>;
|
||||||
|
template <typename T, typename U>
|
||||||
|
// When keys are pointers, compare keys according to its value instead of
|
||||||
|
// address Specially, the name of Var are compared due to the overload of op=
|
||||||
|
// and hash.
|
||||||
|
using PtrUmap = std::unordered_map<T, U, ptr_hash<T>, ptr_equal<T>>;
|
||||||
|
template <typename T>
|
||||||
|
using PtrUset = std::unordered_set<T, ptr_hash<T>, ptr_equal<T>>;
|
||||||
|
using Appearance = PtrMap<Var, vector<pair<Tensor, int>>>;
|
||||||
|
using StrideTable =
|
||||||
|
PtrMap<Var, vector<tuple<TensorNode *, int, int>>>; // Tensor, dim, stride
|
||||||
|
|
||||||
|
// AST node opeartor
|
||||||
|
bool operator==(const Var &lhs, const string &rhs);
|
||||||
|
bool operator==(const string &lhs, const Var &rhs);
|
||||||
|
Expr operator+(const Expr &lhs, const Expr &rhs);
|
||||||
|
BinaryOp operator-(const Expr &lhs, const Expr &rhs);
|
||||||
|
BinaryOp operator*(const Expr &lhs, const Expr &rhs);
|
||||||
|
BinaryOp operator/(const Expr &lhs, const Expr &rhs);
|
||||||
|
BinaryOp operator%(const Expr &lhs, const Expr &rhs);
|
||||||
|
|
||||||
|
Expr operator+(const Expr &lhs, const int &rhs);
|
||||||
|
Expr operator+(const int &lhs, const Expr &rhs);
|
||||||
|
Expr operator-(const Expr &lhs, const int &rhs);
|
||||||
|
Expr operator-(const int &lhs, const Expr &rhs);
|
||||||
|
Expr operator*(const Expr &lhs, const int &rhs);
|
||||||
|
Expr operator*(const int &lhs, const Expr &rhs);
|
||||||
|
Expr operator%(const Expr &lhs, const int rhs);
|
||||||
|
Expr operator/(const Expr &lhs, const int rhs);
|
||||||
|
|
||||||
|
string serializeVec(vector<Expr> v);
|
||||||
|
string serializeVec(vector<Var> v);
|
||||||
|
template <typename T> inline string serializeVec(vector<T> v) {
|
||||||
|
if (v.empty())
|
||||||
|
return "[]";
|
||||||
|
return "[" +
|
||||||
|
std::accumulate(
|
||||||
|
v.begin() + 1, v.end(), to_string(v[0]),
|
||||||
|
[](const string &a, int b) { return a + ',' + to_string(b); }) +
|
||||||
|
"]";
|
||||||
|
}
|
||||||
|
|
||||||
|
// For RTTI and visitor pattern
|
||||||
|
enum class NodeType {
|
||||||
|
ConstantNodeType,
|
||||||
|
BinaryOpNodeType,
|
||||||
|
RangeOpNodeType,
|
||||||
|
SubscriptNodeType,
|
||||||
|
TensorNodeType,
|
||||||
|
VarNodeType,
|
||||||
|
FuncNodeType
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class FuncType { Relu, Tanh };
|
||||||
|
|
||||||
|
#define DEFINE_GETTYPE(CLASS) \
|
||||||
|
NodeType getType() const override { return NodeType::CLASS##Type; }
|
||||||
|
|
||||||
|
class ExprNode {
|
||||||
|
public:
|
||||||
|
virtual ~ExprNode() {}
|
||||||
|
ExprNode &operator=(const ExprNode &rhs) = delete;
|
||||||
|
|
||||||
|
virtual HashType hash() const = 0; // RE: remove?
|
||||||
|
virtual string toReadable() const = 0;
|
||||||
|
friend std::ostream &operator<<(std::ostream &ios, const ExprNode &expr);
|
||||||
|
|
||||||
|
virtual NodeType getType() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class VarNode : public ExprNode {
|
||||||
|
std::string name;
|
||||||
|
|
||||||
|
public:
|
||||||
|
VarNode(std::string _name) : name(_name){};
|
||||||
|
virtual ~VarNode() {}
|
||||||
|
DEFINE_GETTYPE(VarNode);
|
||||||
|
|
||||||
|
const std::string &getName() const { return name; }
|
||||||
|
HashType hash() const override { return genhash(name); };
|
||||||
|
string toReadable() const override { return name; };
|
||||||
|
bool equal(const Var &rhs) const { return name == rhs->getName(); }
|
||||||
|
bool neq(const Var &rhs) const { return !equal(rhs); }
|
||||||
|
bool less(const Var &rhs) const { return name < rhs->getName(); }
|
||||||
|
bool equal(const string &rhs) const { return name == rhs; }
|
||||||
|
bool operator==(const VarNode &rhs) const { return name == rhs.getName(); }
|
||||||
|
bool operator<(const VarNode &rhs) const { return name < rhs.getName(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class TensorType { Input, Weight, Intermediate };
|
||||||
|
|
||||||
|
class TensorNode : public ExprNode {
|
||||||
|
string name;
|
||||||
|
vector<int> shape, paddings;
|
||||||
|
TensorType type;
|
||||||
|
Routine source; // if NO source, then this is a input/weight tensor
|
||||||
|
|
||||||
|
public:
|
||||||
|
TensorNode(string _name, vector<int> _shape, vector<int> _paddings = {},
|
||||||
|
Routine _source = nullptr);
|
||||||
|
virtual ~TensorNode() {}
|
||||||
|
DEFINE_GETTYPE(TensorNode);
|
||||||
|
|
||||||
|
bool operator==(const string &rhs) { return name == rhs; }
|
||||||
|
friend bool operator==(const string &lhs, const TensorNode &rhs) {
|
||||||
|
return lhs == rhs.name;
|
||||||
|
}
|
||||||
|
|
||||||
|
HashType hash() const override { return genhash(name); }
|
||||||
|
string toReadable() const override;
|
||||||
|
string toOutputShape() const;
|
||||||
|
const std::string &getName() const { return name; }
|
||||||
|
std::vector<int> &getPadding() { return paddings; }
|
||||||
|
int getPadding(int i) const { return paddings[i]; }
|
||||||
|
const vector<int> &getPaddings() const { return paddings; }
|
||||||
|
void setPadding(int i, int p) { paddings[i] = p; }
|
||||||
|
const vector<int> &getShape() const { return shape; }
|
||||||
|
int getShape(int i) const { return shape[i]; }
|
||||||
|
int64_t getSize() const;
|
||||||
|
int getDims() const { return shape.size(); }
|
||||||
|
const Routine &getSource() const { return source; }
|
||||||
|
int getData(const Ref<vector<int>> &data, const vector<int> &idx);
|
||||||
|
size_t getOffset(const vector<int> &idx);
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class OpType { Range, Add, Mul, Div, Mod, Sub };
|
||||||
|
const char opSymbols[] = "#+*/%-";
|
||||||
|
|
||||||
|
class OperatorNode : public ExprNode {
|
||||||
|
protected:
|
||||||
|
const OpType opType;
|
||||||
|
VecExpr subExprs;
|
||||||
|
|
||||||
|
public:
|
||||||
|
OperatorNode(OpType _opType) : opType(_opType){};
|
||||||
|
OperatorNode(OpType _opType, VecExpr _subExprs)
|
||||||
|
: opType(_opType), subExprs(_subExprs){};
|
||||||
|
|
||||||
|
int getSubExprsNum() { return subExprs.size(); };
|
||||||
|
const VecExpr &getSubExprs() { return subExprs; }
|
||||||
|
const Expr &getSubExprs(int i) const { return subExprs[i]; }
|
||||||
|
OpType getOpType() const { return opType; };
|
||||||
|
void setOperands(int i, Expr e) { subExprs[i] = e; }
|
||||||
|
};
|
||||||
|
|
||||||
|
using Range = pair<int, int>;
|
||||||
|
using VarRangePair = pair<Var, Range>;
|
||||||
|
inline int getLength(const Range &range) { return range.second - range.first; }
|
||||||
|
struct IterationType {
|
||||||
|
enum { Loop, Sum };
|
||||||
|
constexpr static int NumIterationType = 2;
|
||||||
|
};
|
||||||
|
class RangeOpNode : public OperatorNode {
|
||||||
|
public:
|
||||||
|
enum { Summand, END_POS };
|
||||||
|
constexpr static int Loop = IterationType::Loop;
|
||||||
|
constexpr static int Sum = IterationType::Sum;
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<VarRangePair> vars[IterationType::NumIterationType];
|
||||||
|
vector<int> paddings;
|
||||||
|
|
||||||
|
public:
|
||||||
|
RangeOpNode(Expr _summand) : OperatorNode(OpType::Range, {_summand}){};
|
||||||
|
RangeOpNode(const vector<VarRangePair> &_loopIters,
|
||||||
|
const vector<VarRangePair> &_sumIters, Expr _summand,
|
||||||
|
const vector<int> &paddings)
|
||||||
|
: OperatorNode(OpType::Range, {_summand}), vars{_loopIters, _sumIters},
|
||||||
|
paddings(paddings){};
|
||||||
|
DEFINE_GETTYPE(RangeOpNode);
|
||||||
|
|
||||||
|
virtual HashType hash() const override {
|
||||||
|
nnet_unimplemented_halt();
|
||||||
|
return 0;
|
||||||
|
};
|
||||||
|
string toReadable() const override;
|
||||||
|
const Expr &getSummand() const { return subExprs[Summand]; }
|
||||||
|
const vector<VarRangePair> &getVarRanges(int _index) const {
|
||||||
|
return vars[_index];
|
||||||
|
}
|
||||||
|
const vector<VarRangePair> &getLoopVarRanges() const {
|
||||||
|
return vars[IterationType::Loop];
|
||||||
|
}
|
||||||
|
const vector<VarRangePair> &getSumVarRanges() const {
|
||||||
|
return vars[IterationType::Sum];
|
||||||
|
}
|
||||||
|
int getNumOutputDims() const;
|
||||||
|
bool hasVar(int index, Var name) const;
|
||||||
|
bool hasLoopVar(Var name) const { return hasVar(Loop, name); }
|
||||||
|
bool hasSumVar(Var name) const { return hasVar(Sum, name); }
|
||||||
|
bool hasLoopVar(string name) const {
|
||||||
|
return hasVar(Loop, make_ref<VarNode>(name));
|
||||||
|
}
|
||||||
|
bool hasSumVar(string name) const {
|
||||||
|
return hasVar(Sum, make_ref<VarNode>(name));
|
||||||
|
}
|
||||||
|
int getVarIndex(int type, string name);
|
||||||
|
void setSummand(Expr e) { subExprs[Summand] = e; }
|
||||||
|
void setLoopIterator(const vector<VarRangePair> &vecExpr) {
|
||||||
|
vars[Loop] = vecExpr;
|
||||||
|
}
|
||||||
|
void setSumIterator(const vector<VarRangePair> &vecExpr) {
|
||||||
|
vars[Sum] = vecExpr;
|
||||||
|
}
|
||||||
|
void setIterator(const vector<VarRangePair> &loop,
|
||||||
|
const vector<VarRangePair> &sum) {
|
||||||
|
setLoopIterator(loop);
|
||||||
|
setSumIterator(sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
const VarRangePair &getVarRange(int _index, int i) const {
|
||||||
|
return vars[_index][i];
|
||||||
|
}
|
||||||
|
const Var &getLoopVar(int i) const { return vars[Loop][i].first; }
|
||||||
|
Range getRange(const Var &var) const;
|
||||||
|
VarRangePair getVarRange(const Var &var) const;
|
||||||
|
bool hasPaddings() const;
|
||||||
|
int getPaddings(int dim) const;
|
||||||
|
vector<int> getPaddings() const;
|
||||||
|
void setPaddings(vector<int> _paddings);
|
||||||
|
void setVarRange(int _index, int i, VarRangePair pair) {
|
||||||
|
vars[_index][i] = pair;
|
||||||
|
}
|
||||||
|
int64_t getFlops() const;
|
||||||
|
int64_t getInputSize(const RangeOp &self) const;
|
||||||
|
int64_t getOutputSize() const;
|
||||||
|
vector<int> getOutputShape() const;
|
||||||
|
// Including paddings
|
||||||
|
vector<Range> getOutputRanges() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class BinaryOpNode : public OperatorNode {
|
||||||
|
enum { LHS, RHS, END_POS };
|
||||||
|
|
||||||
|
public:
|
||||||
|
BinaryOpNode(OpType _opType, Expr _lhs, Expr _rhs)
|
||||||
|
: OperatorNode(_opType, {_lhs, _rhs}){};
|
||||||
|
virtual ~BinaryOpNode() {}
|
||||||
|
DEFINE_GETTYPE(BinaryOpNode);
|
||||||
|
|
||||||
|
virtual HashType hash() const override {
|
||||||
|
return genhash((HashType)opType,
|
||||||
|
genhash(subExprs[LHS]->hash(), subExprs[RHS]->hash()));
|
||||||
|
};
|
||||||
|
virtual string toReadable() const override;
|
||||||
|
const Expr &getLhs() const { return getSubExprs(LHS); };
|
||||||
|
const Expr &getRhs() const { return getSubExprs(RHS); };
|
||||||
|
void setLhs(Expr e) { setOperands(LHS, e); };
|
||||||
|
void setRhs(Expr e) { setOperands(RHS, e); };
|
||||||
|
// If Var/constant, use this one
|
||||||
|
optional<pair<Var, int>> getModDivParameter() const;
|
||||||
|
// If (Var+constant)/constant, use this one
|
||||||
|
pair<Expr, int> getModDivExpr() const;
|
||||||
|
bool isSwapable() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ConstantNode : public ExprNode {
|
||||||
|
int val;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ConstantNode(int _val) : val(_val){};
|
||||||
|
ConstantNode(const ConstantNode &rhs) : ExprNode(rhs), val(rhs.val){};
|
||||||
|
virtual ~ConstantNode() {}
|
||||||
|
DEFINE_GETTYPE(ConstantNode);
|
||||||
|
|
||||||
|
int getValue() const { return val; }
|
||||||
|
virtual HashType hash() const override { return genhash(val, 6214587); };
|
||||||
|
virtual string toReadable() const override {
|
||||||
|
string ret;
|
||||||
|
ret += std::to_string(val);
|
||||||
|
return ret;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
class SubscriptNode : public ExprNode {
|
||||||
|
protected:
|
||||||
|
Expr indexed;
|
||||||
|
VecExpr subExprs;
|
||||||
|
|
||||||
|
public:
|
||||||
|
SubscriptNode(Expr _indexed, vector<Expr> _subExprs) : subExprs(_subExprs) {
|
||||||
|
setObject(_indexed);
|
||||||
|
};
|
||||||
|
DEFINE_GETTYPE(SubscriptNode);
|
||||||
|
|
||||||
|
virtual HashType hash() const override {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
return -1;
|
||||||
|
};
|
||||||
|
virtual string toReadable() const override;
|
||||||
|
|
||||||
|
size_t getDims() const { return subExprs.size(); }
|
||||||
|
const VecExpr &getIndex() const { return subExprs; }
|
||||||
|
const Expr &getIndex(size_t i) const { return subExprs[i]; }
|
||||||
|
void setIndex(size_t i, Expr e) { subExprs[i] = e; }
|
||||||
|
Expr *getObjectPtr() { return &indexed; }
|
||||||
|
Expr getObject() const { return indexed; }
|
||||||
|
void setObject(Expr e);
|
||||||
|
bool isRangeOpSubscripted() const;
|
||||||
|
bool isTensorSubscripted() const { return !isRangeOpSubscripted(); }
|
||||||
|
// Get the ranges of objects including paddings
|
||||||
|
vector<Range> getObjectRangesWithPaddings() const;
|
||||||
|
vector<Range> getObjectRangesWithoutPaddings() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class FuncNode : public ExprNode {
|
||||||
|
protected:
|
||||||
|
Subscript object;
|
||||||
|
FuncType funcType;
|
||||||
|
|
||||||
|
public:
|
||||||
|
FuncNode(Expr object, FuncType funcType) : funcType(funcType) {
|
||||||
|
setObject(object);
|
||||||
|
}
|
||||||
|
DEFINE_GETTYPE(FuncNode);
|
||||||
|
|
||||||
|
virtual HashType hash() const override {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
return -1;
|
||||||
|
};
|
||||||
|
virtual string toReadable() const override;
|
||||||
|
|
||||||
|
const Subscript &getObject() const { return object; }
|
||||||
|
void setObject(Expr e);
|
||||||
|
|
||||||
|
FuncType getFuncType() const { return funcType; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Wrappers for type deduction
|
||||||
|
Subscript makeSubscript(const Expr &tensor, const VecExpr &subscripts);
|
||||||
|
RangeOp makeRangeOperator(const vector<VarRangePair> &_loopIters,
|
||||||
|
const vector<VarRangePair> &_sumIters, Expr _summand,
|
||||||
|
const vector<int> &paddings = {});
|
||||||
|
Tensor makeTensor(const string &name, const vector<int> &shape,
|
||||||
|
const vector<int> &paddings = {},
|
||||||
|
const Routine &source = nullptr);
|
||||||
|
|
||||||
|
// Pretty output for dbg with shared_ptr
|
||||||
|
template <typename T, typename std::enable_if_t<std::is_base_of_v<ExprNode, T>>
|
||||||
|
*_ = nullptr>
|
||||||
|
std::ostream &operator<<(std::ostream &os, const shared_ptr<T> &a) {
|
||||||
|
os << ((!a) ? string("nullptr") : a->toReadable());
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pretty output for dbg with shared_ptr
|
||||||
|
template <typename T, typename std::enable_if_t<std::is_base_of_v<ExprNode, T>>
|
||||||
|
*_ = nullptr>
|
||||||
|
std::ostream &operator<<(std::ostream &os, const Ref<T> &a) {
|
||||||
|
os << ((!a) ? string("nullptr") : a->toReadable());
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
#undef DEFINE_GETTYPE
|
||||||
|
|
||||||
|
} // namespace nnet
|
||||||
|
|
||||||
|
namespace std {
|
||||||
|
template <> struct hash<nnet::VarNode &> {
|
||||||
|
size_t operator()(const nnet::VarNode &t) const {
|
||||||
|
return std::hash<string>()(t.getName());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace std
|
|
@ -0,0 +1,234 @@
|
||||||
|
#pragma once
|
||||||
|
#include "common.h"
|
||||||
|
#include "expr.h"
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
using PatternTensorMap = vector<Tensor>;
|
||||||
|
using PatternIterRangeMap = PtrMap<Iterator, VarRangePair>;
|
||||||
|
|
||||||
|
enum class MismatchType {
|
||||||
|
// Search required (undetermined)
|
||||||
|
MoreVar,
|
||||||
|
LessVar,
|
||||||
|
StrideMismatch,
|
||||||
|
// guided DLT (determined)
|
||||||
|
DLMismatch,
|
||||||
|
OutputDLMismatch,
|
||||||
|
OutputDimismatch
|
||||||
|
};
|
||||||
|
struct Mismatch {
|
||||||
|
MismatchType type;
|
||||||
|
int bitmap; // Row ID of IT
|
||||||
|
PtrMap<Iterator, Iterator>
|
||||||
|
mappingIter_r; // For DLT mismatch, iters are mapped
|
||||||
|
Mismatch(MismatchType _type, int _bitmap) : type(_type), bitmap(_bitmap) {}
|
||||||
|
Mismatch(MismatchType _type, int _bitmap,
|
||||||
|
PtrMap<Iterator, Iterator> _mappingIter_r)
|
||||||
|
: type(_type), bitmap(_bitmap), mappingIter_r(_mappingIter_r) {}
|
||||||
|
};
|
||||||
|
class Pattern;
|
||||||
|
class IteratorTable {
|
||||||
|
protected:
|
||||||
|
// using Appearance = map<string, vector<pair<Tensor, int>>>;
|
||||||
|
// using StrideTable = map<TensorNode *, vector<tuple<string, int, int>>>;
|
||||||
|
// // Var, dim, stride
|
||||||
|
RangeOp rangeOp;
|
||||||
|
// To real tensor
|
||||||
|
// FIXME: redundent
|
||||||
|
Appearance appearance;
|
||||||
|
vector<Tensor> tensors; // original tensor sequence
|
||||||
|
vector<Subscript> subscripts; // original subscripts sequence
|
||||||
|
StrideTable strideTable; // TODO [Refactor]: rename strideTable
|
||||||
|
PatternIterRangeMap iterToRange;
|
||||||
|
|
||||||
|
// mapping
|
||||||
|
vector<int> tensorMap; // [index for tensors] -> tensorID in pattern
|
||||||
|
PtrMap<Iterator, Iterator> iterMap; // [expr iter] -> pattern iter
|
||||||
|
|
||||||
|
// final data
|
||||||
|
vector<vector<Iterator>> posTable; // [Tensor bitmap]=[Iterator]
|
||||||
|
vector<vector<vector<Iterator>>>
|
||||||
|
iterInTensorDim; // [tensorID][dimOfTensor]=[Iterator],
|
||||||
|
// stride in each dim may be add
|
||||||
|
|
||||||
|
vector<vector<PtrMap<Iterator, int>>>
|
||||||
|
strideInDim; // [tensorID][dimOfTensor][Iterator]=stride,
|
||||||
|
// stride in each dim may be add
|
||||||
|
|
||||||
|
PtrMap<Iterator, vector<int>> strideInTensor; // [Iterator][tensorID]=stride
|
||||||
|
|
||||||
|
// final data: auxiliary data
|
||||||
|
vector<int> tensorIDMap_r;
|
||||||
|
PatternTensorMap tensorMap_r;
|
||||||
|
PatternIterRangeMap iterToRange_r;
|
||||||
|
|
||||||
|
public:
|
||||||
|
virtual ~IteratorTable() {}
|
||||||
|
IteratorTable() {}
|
||||||
|
IteratorTable(const IteratorTable &) = delete;
|
||||||
|
[[nodiscard]] bool analyzeExpr(const RangeOp &rangeOp);
|
||||||
|
// mapTensors
|
||||||
|
void buildTable(const vector<int> &_tensorMap);
|
||||||
|
void buildTableWithDefaultMap();
|
||||||
|
/**
|
||||||
|
* @brief Check whether the expression match a pattern. If not, return the
|
||||||
|
* detailed reason for guided search.
|
||||||
|
*
|
||||||
|
* @param patternIT
|
||||||
|
* @return vector<int> mismatched IT rows/tensors for guided DLT.
|
||||||
|
*/
|
||||||
|
vector<Mismatch> matchPatternIT(const Pattern &patternIT);
|
||||||
|
void matchIterators();
|
||||||
|
int getNumInputs() const { return tensors.size(); }
|
||||||
|
int getNumTensors() const { return tensors.size() + 1; }
|
||||||
|
int getNumRows() const { return 1 << getNumTensors(); }
|
||||||
|
int getNumIterators() const { return strideTable.size(); }
|
||||||
|
// vector<Tensor> tensorMap_r(
|
||||||
|
// pattern.nInputs); // [pattern tensor ID] -> real tensor
|
||||||
|
// map<string, VarRangePair> iterToRange_r; // [pattern iter] -> iter &
|
||||||
|
// range
|
||||||
|
auto getTables() const {
|
||||||
|
return tuple(posTable, iterInTensorDim, strideInTensor);
|
||||||
|
}
|
||||||
|
const auto &getStrideInDim() const { return strideInDim; }
|
||||||
|
vector<vector<Iterator>> getIterInTensorDim(int tensorID) const {
|
||||||
|
return iterInTensorDim[tensorID];
|
||||||
|
}
|
||||||
|
const vector<Iterator> &getPosTable(int bitmap) const {
|
||||||
|
return posTable[bitmap];
|
||||||
|
}
|
||||||
|
pair<PatternTensorMap, PatternIterRangeMap> getReverseMap() const;
|
||||||
|
|
||||||
|
int getStridesInTensor(Iterator iter, int tensorID) const;
|
||||||
|
vector<int> getIterDimInTensor(int tensorID, const Iterator &iter) const;
|
||||||
|
Tensor getTensor(int tensorID) const { return tensorMap_r[tensorID]; }
|
||||||
|
Subscript getSubscript(int tensorID) const {
|
||||||
|
return subscripts[tensorIDMap_r[tensorID]];
|
||||||
|
}
|
||||||
|
Range getIterRange(const Iterator &iter) const {
|
||||||
|
return rangeOp->getRange(iter);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Check strides of each iterators and there position in tensors.
|
||||||
|
* Since many-to-many iterators matching exist, we take this procudure as a
|
||||||
|
* seperate function to deal with different iterator mapping solution.
|
||||||
|
*
|
||||||
|
* @param patternIT
|
||||||
|
* @param mappingIter_r
|
||||||
|
* @return vector<Mismatch>
|
||||||
|
*/
|
||||||
|
vector<Mismatch>
|
||||||
|
matchPatternITCheckStrides(const Pattern &patternIT,
|
||||||
|
PtrMap<Iterator, Iterator> mappingIter_r);
|
||||||
|
RangeOp getRangeOp() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct StrideConstraint {
|
||||||
|
int tensorID;
|
||||||
|
Var v0, v1;
|
||||||
|
enum class Constraint { SAME, PROPOTIONAL } type;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Pattern : public IteratorTable {
|
||||||
|
vector<StrideConstraint> strideConstraints;
|
||||||
|
|
||||||
|
public:
|
||||||
|
virtual Expr
|
||||||
|
buildExpr(const Expr &expr, const vector<Tensor> &tensors,
|
||||||
|
[[maybe_unused]] const PatternIterRangeMap &varRanges,
|
||||||
|
string outputName,
|
||||||
|
[[maybe_unused]] const IteratorTable &exprIT) const = 0;
|
||||||
|
/**
|
||||||
|
* @brief Check whether all indexes only are a iterator
|
||||||
|
*
|
||||||
|
* @param tensorID
|
||||||
|
*/
|
||||||
|
bool isAllUniqueAccess(int tensorID) const;
|
||||||
|
const auto &getStrideConstraints() const { return strideConstraints; };
|
||||||
|
int calcPadding(const Tensor &tensor, int dim, Range rangeH, Range rangeR,
|
||||||
|
int offset) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MatmulPattern : public Pattern {
|
||||||
|
public:
|
||||||
|
static const Pattern &getMatmulPattern();
|
||||||
|
static pair<Expr, pair<Tensor, Tensor>> getExpr(bool transA, bool transB,
|
||||||
|
int b, int m, int n, int k);
|
||||||
|
|
||||||
|
Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
|
||||||
|
[[maybe_unused]] const PatternIterRangeMap &varRanges,
|
||||||
|
string outputName,
|
||||||
|
[[maybe_unused]] const IteratorTable &exprIT) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ConvPattern : public Pattern {
|
||||||
|
private:
|
||||||
|
static const Var n, c, h, w, f, r, s;
|
||||||
|
|
||||||
|
public:
|
||||||
|
static const Pattern &getPattern();
|
||||||
|
static Expr getExpr(Tensor A, Tensor K, int n, int c, int h, int w, int f,
|
||||||
|
int r, int s);
|
||||||
|
|
||||||
|
Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
|
||||||
|
[[maybe_unused]] const PatternIterRangeMap &varRanges,
|
||||||
|
string outputName,
|
||||||
|
[[maybe_unused]] const IteratorTable &exprIT) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ConvTransPattern : public Pattern {
|
||||||
|
private:
|
||||||
|
static const Var n, c, h, w, f, r, s;
|
||||||
|
|
||||||
|
public:
|
||||||
|
static const Pattern &getPattern() = delete;
|
||||||
|
static Expr getExpr(Tensor A, Tensor K, int N, int C, int H, int W, int F,
|
||||||
|
int R, int S);
|
||||||
|
|
||||||
|
Expr
|
||||||
|
buildExpr(const Expr &expr, const vector<Tensor> &tensors,
|
||||||
|
[[maybe_unused]] const PatternIterRangeMap &varRanges,
|
||||||
|
string outputName,
|
||||||
|
[[maybe_unused]] const IteratorTable &exprIT) const override {
|
||||||
|
nnet_unimplemented_halt();
|
||||||
|
return nullptr;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
class Sg2bmmPattern : public Pattern {
|
||||||
|
private:
|
||||||
|
static const Var b, m, w, k;
|
||||||
|
|
||||||
|
public:
|
||||||
|
static const Pattern &getPattern();
|
||||||
|
static pair<Expr, pair<Tensor, Tensor>> getExpr(int Batch, int M, int K,
|
||||||
|
int W, int D);
|
||||||
|
|
||||||
|
Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
|
||||||
|
[[maybe_unused]] const PatternIterRangeMap &varRanges,
|
||||||
|
string outputName,
|
||||||
|
[[maybe_unused]] const IteratorTable &exprIT) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class LongformerGBMMPattern : public Pattern {
|
||||||
|
private:
|
||||||
|
static const Var b, m, w, n;
|
||||||
|
|
||||||
|
public:
|
||||||
|
static const Pattern &getPattern();
|
||||||
|
static pair<Expr, pair<Tensor, Tensor>> getExpr(int Batch, int M, int W,
|
||||||
|
int K, int dilation);
|
||||||
|
|
||||||
|
Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
|
||||||
|
[[maybe_unused]] const PatternIterRangeMap &varRanges,
|
||||||
|
string outputName,
|
||||||
|
[[maybe_unused]] const IteratorTable &exprIT) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
const Pattern &getPattern(RoutineType targetOp);
|
||||||
|
string getPatternName(RoutineType targetOp);
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,57 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/mutator.h"
|
||||||
|
#include "nnet/expr.h"
|
||||||
|
|
||||||
|
#ifdef ABC
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class NMutator : public Mutator {
|
||||||
|
private:
|
||||||
|
// Suffix -N: NNet objects.
|
||||||
|
// Suffix -T: tpm objects.
|
||||||
|
// Map: NNet tensors -> tpm tensor.
|
||||||
|
std::map<std::string, Tensor> inputsNameNToTensorT;
|
||||||
|
enum class Mode { Normal, ToNaiveMembound, RuleBased } mode = Mode::Normal;
|
||||||
|
const double bandwidth = double(200) * 1024 * 1024 * 1024;
|
||||||
|
// If in RuleBased mode, use derivationRules in derivator
|
||||||
|
const std::vector<int> derivationRules;
|
||||||
|
|
||||||
|
public:
|
||||||
|
NMutator();
|
||||||
|
NMutator(const std::vector<int> &derivationRules);
|
||||||
|
~NMutator();
|
||||||
|
|
||||||
|
vector<Graph> run(const Graph &in_graph) override;
|
||||||
|
void setToNaiveMembound();
|
||||||
|
|
||||||
|
void setMaxDepth(int _maxDepth) { maxDepth = _maxDepth; }
|
||||||
|
long long cntStates = 0;
|
||||||
|
long long cntCandidates = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int maxDepth = 8;
|
||||||
|
nnet::Expr opToExpression(Operator op);
|
||||||
|
void runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Test helper. Converting a single OP to Membound Op for
|
||||||
|
* corretness check.
|
||||||
|
*/
|
||||||
|
void runSingleOpToNaiveMembound(Graph in_graph,
|
||||||
|
std::vector<Graph> &out_graphs);
|
||||||
|
void runMultipleOps(Graph in_graph, std::vector<Graph> &out_graphs);
|
||||||
|
Graph expressionToGraph(nnet::Expr expr, Graph in_graph);
|
||||||
|
Graph fuseHetConv(nnet::Expr expr, Graph in_graph);
|
||||||
|
double memboundTime(ssize_t cnt);
|
||||||
|
double memboundTime(const Shape &dims);
|
||||||
|
|
||||||
|
Graph transformTConv1x1(Operator op);
|
||||||
|
Graph transformTConv3x3(Operator op);
|
||||||
|
Graph transformDialtedConv(Operator op);
|
||||||
|
Graph transformConv1x1(Operator op);
|
||||||
|
Graph transformConv1xk(Operator op);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
||||||
|
#endif
|
|
@ -0,0 +1,38 @@
|
||||||
|
#pragma once
|
||||||
|
#include "common.h"
|
||||||
|
#include "expr.h"
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class PermutationGenerator {
|
||||||
|
vector<vector<Iterator>> from, to;
|
||||||
|
vector<vector<size_t>> mapping;
|
||||||
|
|
||||||
|
public:
|
||||||
|
PermutationGenerator(vector<vector<Iterator>> _from,
|
||||||
|
vector<vector<Iterator>> _to);
|
||||||
|
bool next();
|
||||||
|
PtrMap<Iterator, Iterator> get() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T> class SubsetGenerator {
|
||||||
|
vector<T> elements;
|
||||||
|
int n, bitmap;
|
||||||
|
|
||||||
|
public:
|
||||||
|
SubsetGenerator(vector<T> elements, bool nonEmpty = 1)
|
||||||
|
: elements(elements), n(elements.size()), bitmap((nonEmpty > 0)) {
|
||||||
|
assert(n < 10);
|
||||||
|
};
|
||||||
|
bool next() { return ((++bitmap) < (1 << n) - 1); }
|
||||||
|
vector<T> get() const {
|
||||||
|
vector<T> ret;
|
||||||
|
for (int i = 0; i < n; ++i)
|
||||||
|
if (bitmap & (1 << i))
|
||||||
|
ret.emplace_back(elements[i]);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,200 @@
|
||||||
|
#pragma once
|
||||||
|
#include "common.h"
|
||||||
|
#include <functional> // hash
|
||||||
|
#include <memory>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
template <typename T> struct is_ref;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Ref-counting pointer
|
||||||
|
*
|
||||||
|
* This class is thread-safe (For developers: concurrent accesses through
|
||||||
|
* different `std::shared_ptr`s to the same object is already thread-safe, while
|
||||||
|
* modifying the same `std::shared_ptr` is not. We never modify a `Ref`, so no
|
||||||
|
* locks are needed. See https://en.cppreference.com/w/cpp/memory/shared_ptr)
|
||||||
|
*/
|
||||||
|
template <class T> class Ref {
|
||||||
|
static_assert(is_ref<T>::value == false, "Ref should not be nested");
|
||||||
|
|
||||||
|
template <class U> friend class Ref;
|
||||||
|
|
||||||
|
std::shared_ptr<T> ptr_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
public:
|
||||||
|
typedef T Object;
|
||||||
|
|
||||||
|
Ref() = default;
|
||||||
|
// Ref(std::nullptr_t) : Ref() {}
|
||||||
|
constexpr Ref(nullptr_t) noexcept : Ref() {}
|
||||||
|
Ref(const Ref &) = default;
|
||||||
|
Ref(Ref &&) = default;
|
||||||
|
Ref(std::shared_ptr<T> &&ptr) : ptr_(std::move(ptr)) {}
|
||||||
|
// Ref(const std::shared_ptr<T> &ptr) : ptr_(ptr) {}
|
||||||
|
|
||||||
|
// /// NO NOT USE THIS CONSTRUCTOR IN PUBLIC
|
||||||
|
// /// It is public because Pybind11 needs it
|
||||||
|
// Ref(T *ptr) : ptr_(ptr) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Shared with any compatible references
|
||||||
|
*/
|
||||||
|
template <class U,
|
||||||
|
typename std::enable_if_t<std::is_base_of_v<T, U>> * = nullptr>
|
||||||
|
Ref(const Ref<U> &other) : ptr_(std::static_pointer_cast<T>(other.ptr_)) {}
|
||||||
|
|
||||||
|
template <class U,
|
||||||
|
typename std::enable_if_t<std::is_base_of_v<T, U>> * = nullptr>
|
||||||
|
Ref &operator=(const Ref<U> &other) {
|
||||||
|
ptr_ = std::static_pointer_cast<T>(other.ptr_);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ref &operator=(const Ref &) = default;
|
||||||
|
Ref &operator=(Ref &&) = default;
|
||||||
|
|
||||||
|
template <class U> Ref<U> as() const {
|
||||||
|
Ref<U> ret;
|
||||||
|
ret.ptr_ = std::dynamic_pointer_cast<U>(ptr_);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isValid() const { return ptr_ != nullptr; }
|
||||||
|
|
||||||
|
T &operator*() const {
|
||||||
|
nnet_assert(isValid(), "Empty pointer.");
|
||||||
|
return *ptr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
T *operator->() const {
|
||||||
|
nnet_assert(isValid(), "Empty pointer.");
|
||||||
|
return ptr_.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
T *get() const {
|
||||||
|
nnet_assert(isValid(), "Empty pointer.");
|
||||||
|
return ptr_.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
friend inline bool operator==(const Ref &lhs, nullptr_t) {
|
||||||
|
return !lhs.isValid();
|
||||||
|
}
|
||||||
|
friend inline bool operator!=(const Ref &lhs, nullptr_t) {
|
||||||
|
return !(lhs == nullptr);
|
||||||
|
}
|
||||||
|
explicit operator bool() const { return ptr_ != nullptr; }
|
||||||
|
bool operator!() { return ptr_ == nullptr; }
|
||||||
|
|
||||||
|
void swap(Ref &__b) noexcept { ptr_.swap(__b.ptr_); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class T, class U,
|
||||||
|
typename std::enable_if_t<std::is_base_of_v<U, T>> * = nullptr>
|
||||||
|
Ref<T> as(const Ref<U> &ref) {
|
||||||
|
return ref.template as<T>();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename... Params> Ref<T> make_ref(Params &&...params) {
|
||||||
|
return Ref(make_shared<T>(std::forward<Params>(params)...));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Comparator for Ref
|
||||||
|
template <typename T> struct is_ref : std::false_type {};
|
||||||
|
template <typename T> struct is_ref<Ref<T>> : std::true_type {};
|
||||||
|
|
||||||
|
template <class Tuple, std::size_t index = 0, bool address_based>
|
||||||
|
typename std::enable_if_t<not is_ref<std::tuple_element_t<index, Tuple>>::value,
|
||||||
|
bool>
|
||||||
|
__ref_less(const Tuple &lhs, const Tuple &rhs) {
|
||||||
|
if constexpr (index >=
|
||||||
|
std::tuple_size<std::remove_reference_t<Tuple>>::value - 1)
|
||||||
|
return std::get<index>(lhs) < std::get<index>(rhs);
|
||||||
|
else {
|
||||||
|
if (std::get<index>(lhs) != std::get<index>(rhs))
|
||||||
|
return std::get<index>(lhs) < std::get<index>(rhs);
|
||||||
|
else
|
||||||
|
return __ref_less<Tuple, index + 1, address_based>(lhs, rhs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Tuple, std::size_t index = 0, bool address_based>
|
||||||
|
typename std::enable_if_t<is_ref<std::tuple_element_t<index, Tuple>>::value and
|
||||||
|
not address_based,
|
||||||
|
bool>
|
||||||
|
__ref_less(const Tuple &lhs, const Tuple &rhs) {
|
||||||
|
if constexpr (index >=
|
||||||
|
std::tuple_size<std::remove_reference_t<Tuple>>::value - 1)
|
||||||
|
return std::get<index>(lhs)->less(std::get<index>(rhs));
|
||||||
|
else {
|
||||||
|
if (std::get<index>(lhs)->neq(std::get<index>(rhs)))
|
||||||
|
return std::get<index>(lhs)->less(std::get<index>(rhs));
|
||||||
|
else
|
||||||
|
return __ref_less<Tuple, index + 1, address_based>(lhs, rhs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Tuple, std::size_t index = 0, bool address_based>
|
||||||
|
typename std::enable_if_t<
|
||||||
|
is_ref<std::tuple_element_t<index, Tuple>>::value and address_based, bool>
|
||||||
|
__ref_less(const Tuple &lhs, const Tuple &rhs) {
|
||||||
|
if constexpr (index >=
|
||||||
|
std::tuple_size<std::remove_reference_t<Tuple>>::value - 1)
|
||||||
|
return std::get<index>(lhs).get() < std::get<index>(rhs).get();
|
||||||
|
else {
|
||||||
|
if (std::get<index>(lhs).get() != std::get<index>(rhs).get())
|
||||||
|
return std::get<index>(lhs).get() < std::get<index>(rhs).get();
|
||||||
|
else
|
||||||
|
return __ref_less<Tuple, index + 1, address_based>(lhs, rhs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Tuple> bool ref_addr_less(const Tuple &lhs, const Tuple &rhs) {
|
||||||
|
return __ref_less<Tuple, 0, true>(lhs, rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Tuple> bool ref_value_less(const Tuple &lhs, const Tuple &rhs) {
|
||||||
|
return __ref_less<Tuple, 0, false>(lhs, rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Tuple> class RefAddrLess {
|
||||||
|
public:
|
||||||
|
bool operator()(const Tuple &a, const Tuple &b) const {
|
||||||
|
return ref_addr_less(a, b);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class Tuple> class RefValueLess {
|
||||||
|
public:
|
||||||
|
bool operator()(const Tuple &a, const Tuple &b) const {
|
||||||
|
return ref_value_less(a, b);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// make_ref_from_tuple
|
||||||
|
template <typename _Tp, typename _Tuple, size_t... _Idx>
|
||||||
|
constexpr Ref<_Tp> make_ref_from_tuple_impl(_Tuple &&__t,
|
||||||
|
std::index_sequence<_Idx...>) {
|
||||||
|
return make_ref<_Tp>(std::get<_Idx>(std::forward<_Tuple>(__t))...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename _Tp, typename _Tuple>
|
||||||
|
constexpr Ref<_Tp> make_ref_from_tuple(_Tuple &&__t) {
|
||||||
|
return make_ref_from_tuple_impl<_Tp>(
|
||||||
|
std::forward<_Tuple>(__t),
|
||||||
|
std::make_index_sequence<std::tuple_size_v<std::decay_t<_Tuple>>>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
||||||
|
|
||||||
|
// namespace std {
|
||||||
|
|
||||||
|
// template <class T> struct hash<ir::Ref<T>> {
|
||||||
|
// hash<T *> hash_;
|
||||||
|
// size_t operator()(const ir::Ref<T> &ref) const { return hash_(ref.get());
|
||||||
|
// }
|
||||||
|
// };
|
||||||
|
|
||||||
|
// } // namespace nnet
|
|
@ -0,0 +1,158 @@
|
||||||
|
#pragma once
|
||||||
|
#include "common.h"
|
||||||
|
#include "expr.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
class RoutineNode;
|
||||||
|
class MatmulNode;
|
||||||
|
class ElementWiseNode;
|
||||||
|
using Routine = Ref<RoutineNode>;
|
||||||
|
using Matmul = Ref<MatmulNode>;
|
||||||
|
using ElementWise = Ref<ElementWiseNode>;
|
||||||
|
|
||||||
|
#define DEFINE_GETTYPE(CLASS) \
|
||||||
|
RoutineType getType() const override { return RoutineType::CLASS##Type; }
|
||||||
|
|
||||||
|
class RoutineNode {
|
||||||
|
protected:
|
||||||
|
Expr expr;
|
||||||
|
vector<Tensor> inputs;
|
||||||
|
|
||||||
|
public:
|
||||||
|
RoutineNode(Expr _expr, const vector<Tensor> &_inputs);
|
||||||
|
virtual string toReadable() const = 0;
|
||||||
|
const Expr &getExpr() const { return expr; }
|
||||||
|
const vector<Tensor> &getInputs() const { return inputs; }
|
||||||
|
virtual RoutineType getType() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
using MatmulArgs = tuple<int, // b
|
||||||
|
int, // m
|
||||||
|
int, // n
|
||||||
|
int, // k
|
||||||
|
bool, // transa
|
||||||
|
bool>; // transb
|
||||||
|
|
||||||
|
class MatmulNode : public RoutineNode {
|
||||||
|
int b, m, n, k;
|
||||||
|
bool transa, transb;
|
||||||
|
|
||||||
|
public:
|
||||||
|
MatmulNode(Expr _source, Tensor A, Tensor B, int _b, int _m, int _n, int _k,
|
||||||
|
bool _transa, bool _transb)
|
||||||
|
: RoutineNode(_source, {A, B}), b(_b), m(_m), n(_n), k(_k),
|
||||||
|
transa(_transa), transb(_transb) {}
|
||||||
|
DEFINE_GETTYPE(MatmulNode);
|
||||||
|
|
||||||
|
string toReadable() const override;
|
||||||
|
|
||||||
|
friend bool operator==(const MatmulNode &lhs, const MatmulNode &rhs);
|
||||||
|
MatmulArgs getArgs() { return tuple(b, m, n, k, transa, transb); }
|
||||||
|
};
|
||||||
|
|
||||||
|
using ConvArgs = tuple<int, // ph
|
||||||
|
int, // pw
|
||||||
|
int, // sh
|
||||||
|
int, // sw
|
||||||
|
int, // dh
|
||||||
|
int>; // dw
|
||||||
|
|
||||||
|
class ConvNode : public RoutineNode {
|
||||||
|
int ph, pw;
|
||||||
|
int sh, sw;
|
||||||
|
int dh, dw;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ConvNode(Expr _source, Tensor A, Tensor K, int _ph, int _pw, int _sh = 1,
|
||||||
|
int _sw = 1, int _dh = 1, int _dw = 1)
|
||||||
|
: RoutineNode(_source, {A, K}), ph(_ph), pw(_pw), sh(_sh), sw(_sw),
|
||||||
|
dh(_dh), dw(_dw) {}
|
||||||
|
DEFINE_GETTYPE(ConvNode);
|
||||||
|
|
||||||
|
string toReadable() const override;
|
||||||
|
vector<int> getShape() const;
|
||||||
|
friend bool operator==(const ConvNode &lhs, const ConvNode &rhs);
|
||||||
|
ConvArgs getArgs() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ElementWiseNode : public RoutineNode {
|
||||||
|
vector<int> outputShape;
|
||||||
|
|
||||||
|
public:
|
||||||
|
// _outputShape is redundent, but expr is still missing for DLT.
|
||||||
|
ElementWiseNode(Expr _source, vector<Tensor> _inputs,
|
||||||
|
vector<int> _outputShape)
|
||||||
|
: RoutineNode(_source, _inputs), outputShape(_outputShape) {}
|
||||||
|
DEFINE_GETTYPE(ElementWiseNode);
|
||||||
|
|
||||||
|
string toReadable() const override;
|
||||||
|
/**
|
||||||
|
* @brief Get the Estimated Time of mem bound OP.
|
||||||
|
*
|
||||||
|
* @return double Time in ms.
|
||||||
|
*/
|
||||||
|
double getEstimatedTime() const;
|
||||||
|
const vector<int> &getOutputShape() const { return outputShape; }
|
||||||
|
};
|
||||||
|
|
||||||
|
using G2bmmArgs = tuple<int, // b
|
||||||
|
int, // m
|
||||||
|
int, // w
|
||||||
|
int, // k
|
||||||
|
int>; // dilation
|
||||||
|
class G2bmmNode : public RoutineNode {
|
||||||
|
int b, m, w, k;
|
||||||
|
|
||||||
|
public:
|
||||||
|
G2bmmNode(Expr source, Tensor A, Tensor B, int b, int m, int w, int k,
|
||||||
|
int d = 1)
|
||||||
|
: RoutineNode(source, {A, B}), b(b), m(m), w(w), k(k) {
|
||||||
|
assert(d == 1);
|
||||||
|
}
|
||||||
|
DEFINE_GETTYPE(G2bmmNode);
|
||||||
|
|
||||||
|
vector<int> getShape() const;
|
||||||
|
string toReadable() const override;
|
||||||
|
G2bmmArgs getArgs() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
using GbmmArgs = tuple<int, // b
|
||||||
|
int, // m
|
||||||
|
int, // w
|
||||||
|
int, // n
|
||||||
|
int>; // dilation
|
||||||
|
class GbmmNode : public RoutineNode {
|
||||||
|
int b, m, w, n;
|
||||||
|
|
||||||
|
public:
|
||||||
|
GbmmNode(Expr source, Tensor A, Tensor B, int b, int m, int w, int n,
|
||||||
|
int d = 1)
|
||||||
|
: RoutineNode(source, {A, B}), b(b), m(m), w(w), n(n) {
|
||||||
|
assert(d == 1);
|
||||||
|
}
|
||||||
|
DEFINE_GETTYPE(GbmmNode);
|
||||||
|
|
||||||
|
vector<int> getShape() const;
|
||||||
|
string toReadable() const override;
|
||||||
|
GbmmArgs getArgs() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Pretty output for dbg with shared_ptr
|
||||||
|
template <typename T, typename std::enable_if_t<
|
||||||
|
std::is_base_of_v<RoutineNode, T>> *_ = nullptr>
|
||||||
|
std::ostream &operator<<(std::ostream &os, const shared_ptr<T> &a) {
|
||||||
|
os << ((!a) ? string("Null shared_ptr") : a->toReadable());
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pretty output for dbg with shared_ptr
|
||||||
|
template <typename T, typename std::enable_if_t<
|
||||||
|
std::is_base_of_v<RoutineNode, T>> *_ = nullptr>
|
||||||
|
std::ostream &operator<<(std::ostream &os, const Ref<T> &a) {
|
||||||
|
os << ((!a) ? string("Null shared_ptr") : a->toReadable());
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,28 @@
|
||||||
|
#pragma once
|
||||||
|
#include "common.h"
|
||||||
|
#include "derivator.h"
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
#define CAT(A, B) A##B
|
||||||
|
#define SELECT(NAME, NUM) CAT(NAME##_, NUM)
|
||||||
|
#define GET_COUNT( _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, COUNT, ... ) COUNT
|
||||||
|
#define VA_SIZE( ... ) GET_COUNT( __VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1 )
|
||||||
|
#define VA_SELECT( NAME, ... ) SELECT( NAME, VA_SIZE(__VA_ARGS__) )(__VA_ARGS__)
|
||||||
|
|
||||||
|
#define _DEFVAR_1(name) auto name = make_ref<VarNode>(#name);
|
||||||
|
#define _DEFVAR_2(name, ...) _DEFVAR_1(name); _DEFVAR_1(__VA_ARGS__)
|
||||||
|
#define _DEFVAR_3(name, ...) _DEFVAR_1(name); _DEFVAR_2(__VA_ARGS__)
|
||||||
|
#define _DEFVAR_4(name, ...) _DEFVAR_1(name); _DEFVAR_3(__VA_ARGS__)
|
||||||
|
#define _DEFVAR_5(name, ...) _DEFVAR_1(name); _DEFVAR_4(__VA_ARGS__)
|
||||||
|
#define _DEFVAR_6(name, ...) _DEFVAR_1(name); _DEFVAR_5(__VA_ARGS__)
|
||||||
|
#define _DEFVAR_7(name, ...) _DEFVAR_1(name); _DEFVAR_6(__VA_ARGS__)
|
||||||
|
#define _DEFVAR_8(name, ...) _DEFVAR_1(name); _DEFVAR_7(__VA_ARGS__)
|
||||||
|
#define _DEFVAR_9(name, ...) _DEFVAR_1(name); _DEFVAR_8(__VA_ARGS__)
|
||||||
|
#define DEFINE_VAR(...) VA_SELECT(_DEFVAR, __VA_ARGS__)
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
int matchExprResult(Derivator &derivator, string fn);
|
||||||
|
bool checkExprLogSame(string fnPrefix, int start, int end);
|
||||||
|
bool checkExprsEquvivalence(VecExpr exprs);
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,128 @@
|
||||||
|
#pragma once
|
||||||
|
#include "common.h"
|
||||||
|
#include "derivator.h"
|
||||||
|
#include "expr.h"
|
||||||
|
#include "routine.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
template <typename FType> class Functor;
|
||||||
|
|
||||||
|
template <typename R, typename... Args> class Functor<R(Args...)> {
|
||||||
|
protected:
|
||||||
|
int verbose;
|
||||||
|
|
||||||
|
// FIXME: scope should be protected
|
||||||
|
public:
|
||||||
|
Functor(int _verobse = 0) : verbose(_verobse) {}
|
||||||
|
virtual ~Functor() = default;
|
||||||
|
#define DISPATCH(CLASS) \
|
||||||
|
case NodeType::CLASS##Type: \
|
||||||
|
return this->visit_(as<CLASS>(c), std::forward<Args>(args)...); \
|
||||||
|
break
|
||||||
|
|
||||||
|
#define FUNCTOR_DEFAULT \
|
||||||
|
{ return visitDefault(c, std::forward<Args>(args)...); }
|
||||||
|
|
||||||
|
virtual R dispatch(const Expr &c, Args... args) {
|
||||||
|
switch (c->getType()) {
|
||||||
|
DISPATCH(ConstantNode);
|
||||||
|
DISPATCH(BinaryOpNode);
|
||||||
|
DISPATCH(RangeOpNode);
|
||||||
|
DISPATCH(SubscriptNode);
|
||||||
|
DISPATCH(TensorNode);
|
||||||
|
DISPATCH(VarNode);
|
||||||
|
DISPATCH(FuncNode);
|
||||||
|
default:
|
||||||
|
nnet_assert(0, "Unknown type");
|
||||||
|
return R();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual R visit_(const Constant &c, Args... args) FUNCTOR_DEFAULT;
|
||||||
|
virtual R visit_(const BinaryOp &c, Args... args) FUNCTOR_DEFAULT;
|
||||||
|
virtual R visit_(const RangeOp &c, Args... args) FUNCTOR_DEFAULT;
|
||||||
|
virtual R visit_(const Subscript &c, Args... args) FUNCTOR_DEFAULT;
|
||||||
|
virtual R visit_(const Var &c, Args... args) FUNCTOR_DEFAULT;
|
||||||
|
virtual R visit_(const Tensor &c, Args... args) FUNCTOR_DEFAULT;
|
||||||
|
virtual R visit_(const Func &c, Args... args) FUNCTOR_DEFAULT;
|
||||||
|
virtual R visitDefault(const Expr &c, [[maybe_unused]] Args... args) {
|
||||||
|
dbg(*c);
|
||||||
|
nnet_assert(0, "Reach unimplemented visit function.");
|
||||||
|
return R();
|
||||||
|
};
|
||||||
|
|
||||||
|
[[deprecated("Define explicit methods for public access.")]] R
|
||||||
|
operator()(const Expr &e, Args... args) {
|
||||||
|
return dispatch(e, std::forward<Args>(args)...);
|
||||||
|
}
|
||||||
|
#undef FUNCTOR_DEFAULT
|
||||||
|
#undef DISPATCH
|
||||||
|
};
|
||||||
|
|
||||||
|
class Mutator : public Functor<Expr()> {
|
||||||
|
public:
|
||||||
|
Mutator(int _verobse = 0) : Functor(_verobse) {}
|
||||||
|
Expr visit_(const Constant &c) override;
|
||||||
|
Expr visit_(const BinaryOp &c) override;
|
||||||
|
Expr visit_(const RangeOp &c) override;
|
||||||
|
Expr visit_(const Subscript &c) override;
|
||||||
|
Expr visit_(const Var &c) override;
|
||||||
|
Expr visit_(const Tensor &c) override;
|
||||||
|
Expr visit_(const Func &c) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
// template <typename... Args>
|
||||||
|
// class SingleStageVisitor : public Functor<void, Args...> {
|
||||||
|
// public:
|
||||||
|
// SingleStageVisitor(int _verobse = 0) : Functor<R, Args...>(_verobse) {}
|
||||||
|
// // R visit(const Constant &c) override ;
|
||||||
|
// R visit_(const BinaryOp &c) override {
|
||||||
|
// if (verbose)
|
||||||
|
// dbg(*c);
|
||||||
|
// this->dispatch(c->getLhs());
|
||||||
|
// this->dispatch(c->getRhs());
|
||||||
|
// }
|
||||||
|
// R visit_(const RangeOp &c) override {
|
||||||
|
// if (verbose)
|
||||||
|
// dbg(*c);
|
||||||
|
// this->dispatch(ret->getSummand());
|
||||||
|
// // NOT visit iterators and its ranges
|
||||||
|
// }
|
||||||
|
// R visit_(const Subscript &c) override {
|
||||||
|
// if (verbose)
|
||||||
|
// dbg(*c);
|
||||||
|
// this->dispatch(ret->getObject());
|
||||||
|
// for (size_t i = 0; i < ret->getDims(); ++i)
|
||||||
|
// this->dispatch(ret->getIndex(i));
|
||||||
|
// }
|
||||||
|
// // R visit(const Var &c) override;
|
||||||
|
// // R visit(const Tensor &c) override;
|
||||||
|
// };
|
||||||
|
|
||||||
|
// } // namespace nnet
|
||||||
|
// #include "nnet/Visitor/ReplaceVariable.h"
|
||||||
|
// #include "nnet/Visitor/StrideVisitor.h"
|
||||||
|
// namespace nnet {
|
||||||
|
|
||||||
|
class ExprTreeVisitor : public Functor<void(void)> {
|
||||||
|
private:
|
||||||
|
bool inBinary, inRange, inSub, inTensor;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ExprTreeVisitor(bool _inBinary = 1, bool _inRange = 1, bool _inSub = 1,
|
||||||
|
bool _inTensor = 1, int _verobse = 0)
|
||||||
|
: Functor(_verobse), inBinary(_inBinary), inRange(_inRange),
|
||||||
|
inSub(_inSub), inTensor(_inTensor) {}
|
||||||
|
void visit_(const Constant &c) override;
|
||||||
|
void visit_(const BinaryOp &c) override;
|
||||||
|
void visit_(const RangeOp &c) override;
|
||||||
|
void visit_(const Subscript &c) override;
|
||||||
|
void visit_(const Var &c) override;
|
||||||
|
void visit_(const Tensor &c) override;
|
||||||
|
void visit_(const Func &c) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,47 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class MatmulNode : public OperatorNode {
|
||||||
|
private:
|
||||||
|
// InfiniTensor assume a row-major tensor layout. transA=false means default
|
||||||
|
// dims, true means A should be transposed before matmul. This is in
|
||||||
|
// oppsite to column-major BLAS.
|
||||||
|
bool transA, transB;
|
||||||
|
ActType act;
|
||||||
|
|
||||||
|
// Auxiliary attributes
|
||||||
|
int b, m, n, k;
|
||||||
|
|
||||||
|
public:
|
||||||
|
MatmulNode(Tensor A, Tensor B, Tensor C, bool transA = false,
|
||||||
|
bool transB = false, Tensor bias = nullptr,
|
||||||
|
ActType act = ActType::None);
|
||||||
|
|
||||||
|
std::string toString() const override;
|
||||||
|
vector<Shape> computeShape() const override;
|
||||||
|
|
||||||
|
int numInputs() const override { return 2; }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
|
||||||
|
Tensor getBias() const { return inputs[2]; }
|
||||||
|
ActType getAct() const { return act; }
|
||||||
|
bool getTransA() const { return transA; }
|
||||||
|
bool getTransB() const { return transB; }
|
||||||
|
int getB() const { return b; }
|
||||||
|
int getM() const { return m; }
|
||||||
|
int getN() const { return n; }
|
||||||
|
int getK() const { return k; }
|
||||||
|
|
||||||
|
HashType hashWithShape() const override;
|
||||||
|
OpPerfKey getOpPerfKey() const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Q: whether to check the output? Since we can build an Op first and then
|
||||||
|
// assure output.
|
||||||
|
// Fix 1: make shape inference a static method. But OpPerfKey are required.
|
||||||
|
bool checkValid(const TensorVec &inputs) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,3 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/common.h"
|
||||||
|
#include "gtest/gtest.h"
|
|
@ -0,0 +1,14 @@
|
||||||
|
#include "core/common.h"
|
||||||
|
#include <chrono>
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
double timeit(const std::function<void()> &func) {
|
||||||
|
auto start = std::chrono::high_resolution_clock::now();
|
||||||
|
func();
|
||||||
|
auto end = std::chrono::high_resolution_clock::now();
|
||||||
|
return std::chrono::duration<double, std::milli>(end - start).count();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,20 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
void GraphNode::updateConnection() { IT_TODO_HALT(); }
|
||||||
|
|
||||||
|
string GraphNode::toString() const {
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "GraphNode operators:\n";
|
||||||
|
for (const auto &op : ops)
|
||||||
|
oss << op << "\n";
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
void GraphNode::dataMalloc() {
|
||||||
|
for (auto &tensor : tensors)
|
||||||
|
tensor->dataMalloc();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,32 @@
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
bool OperatorNode::isLinearOp() const {
|
||||||
|
return enum_to_underlying(type) >= 100 && enum_to_underlying(type) < 200;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OperatorNode::isElementWiseOp() const {
|
||||||
|
return enum_to_underlying(type) >= 200 && enum_to_underlying(type) < 300;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OperatorNode::isSplitOp() const { return type == OpType::Split; }
|
||||||
|
|
||||||
|
bool OperatorNode::isConcatOp() const { return type == OpType::Concat; }
|
||||||
|
|
||||||
|
bool OperatorNode::isComputeOp() const {
|
||||||
|
return type == OpType::Conv || type == OpType::Matmul ||
|
||||||
|
type == OpType::ConvTrans || type == OpType::G2BMM ||
|
||||||
|
type == OpType::GBMML;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OperatorNode::isTransposeOp() const { return type == OpType::Transpose; }
|
||||||
|
|
||||||
|
bool OperatorNode::isReshapeOp() const { return type == OpType::Reshape; }
|
||||||
|
|
||||||
|
bool OperatorNode::isMemBoundOp() const {
|
||||||
|
return type == OpType::MemBound || type == OpType::Activation ||
|
||||||
|
type == OpType::Transpose;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,105 @@
|
||||||
|
#include "core/run_enigne.h"
|
||||||
|
#include <chrono>
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
void RunEngine::run(const Graph &graph, bool tune, bool profiling) const {
|
||||||
|
if (!tune && profiling)
|
||||||
|
IT_TODO_HALT();
|
||||||
|
const auto &kernelRegistry = KernelRegistry::getInstance();
|
||||||
|
auto perfEngine = PerfEngine::getInstance();
|
||||||
|
// Statistics
|
||||||
|
double totalTime = 0;
|
||||||
|
std::map<OpType, double> opTime;
|
||||||
|
std::map<OpType, int> opCnt;
|
||||||
|
std::chrono::system_clock::time_point begin, end;
|
||||||
|
|
||||||
|
for (auto &op : graph->getOperators()) {
|
||||||
|
// HACK: set correct data type
|
||||||
|
auto kernelAttrs =
|
||||||
|
KernelAttrs{device, op->getOpType(), DataType::Int32};
|
||||||
|
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||||
|
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||||
|
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
|
||||||
|
|
||||||
|
// If no record and disable tuning, run with the default argument
|
||||||
|
if (!perfData && !tune) {
|
||||||
|
kernel->compute(op);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: The copy of record should be eliminated
|
||||||
|
PerfRecord record;
|
||||||
|
// Tune the kernel if there is no record
|
||||||
|
if (!perfData) {
|
||||||
|
record = kernel->tune(op);
|
||||||
|
perfEngine.setPerfData(perfKey, record);
|
||||||
|
} else
|
||||||
|
record = *perfData;
|
||||||
|
|
||||||
|
if (!profiling) {
|
||||||
|
kernel->compute(op, *perfData);
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
double t = timeit([&]() { kernel->compute(op, *perfData); });
|
||||||
|
op->print();
|
||||||
|
printf(" op_time %lf\n", t);
|
||||||
|
totalTime += t;
|
||||||
|
opTime[op->getOpType()] += t;
|
||||||
|
opCnt[op->getOpType()]++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (profiling)
|
||||||
|
printProfilingData(totalTime, opTime, opCnt);
|
||||||
|
}
|
||||||
|
|
||||||
|
double RunEngine::getPerfTime(const Graph &graph, bool profiling) const {
|
||||||
|
const auto &kernelRegistry = KernelRegistry::getInstance();
|
||||||
|
auto perfEngine = PerfEngine::getInstance();
|
||||||
|
// Statistics
|
||||||
|
double totalTime = 0;
|
||||||
|
std::map<OpType, double> opTime;
|
||||||
|
std::map<OpType, int> opCnt;
|
||||||
|
|
||||||
|
for (auto &op : graph->getOperators()) {
|
||||||
|
// HACK: set correct data type
|
||||||
|
auto kernelAttrs =
|
||||||
|
KernelAttrs{device, op->getOpType(), DataType::Int32};
|
||||||
|
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||||
|
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||||
|
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
|
||||||
|
|
||||||
|
PerfRecord record;
|
||||||
|
// Tune the kernel if there is no record
|
||||||
|
if (!perfData) {
|
||||||
|
record = kernel->tune(op);
|
||||||
|
perfEngine.setPerfData(perfKey, record);
|
||||||
|
} else
|
||||||
|
record = *perfData;
|
||||||
|
|
||||||
|
double t = record.time;
|
||||||
|
totalTime += t;
|
||||||
|
if (profiling) {
|
||||||
|
op->print();
|
||||||
|
printf(" op_time %lf\n", t);
|
||||||
|
opTime[op->getOpType()] += t;
|
||||||
|
opCnt[op->getOpType()]++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (profiling)
|
||||||
|
printProfilingData(totalTime, opTime, opCnt);
|
||||||
|
return totalTime;
|
||||||
|
}
|
||||||
|
|
||||||
|
void RunEngine::printProfilingData(double totalTime,
|
||||||
|
const std::map<OpType, double> &opTime,
|
||||||
|
const std::map<OpType, int> &opCnt) const {
|
||||||
|
printf("%11s %3s %7s %7s %7s\n", "Op", "Cnt", "T_tot", "Percent", "T_mean");
|
||||||
|
for (const auto &[type, t] : opTime) {
|
||||||
|
printf("%11s %3d %7.3f %7.1f %7.3f\n",
|
||||||
|
OpRegistry::getOpName(type).data(), opCnt.at(type), t,
|
||||||
|
t / totalTime * 100, t / opCnt.at(type));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,90 @@
|
||||||
|
#include <core/tensor.h>
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
TensorNode::TensorNode(const Shape &shape, DataType dtype)
|
||||||
|
: TensorBaseNode(shape.size(), dtype), shape(shape) {}
|
||||||
|
|
||||||
|
void TensorNode::dataMalloc() {
|
||||||
|
IT_ASSERT(data == nullptr);
|
||||||
|
// initialized to zero
|
||||||
|
data.reset(reinterpret_cast<VType *>(calloc(size(), sizeof(VType))));
|
||||||
|
}
|
||||||
|
|
||||||
|
VType TensorNode::getData(const Shape &pos) const {
|
||||||
|
return getData(getOffset(pos));
|
||||||
|
}
|
||||||
|
|
||||||
|
string TensorNode::toString() const {
|
||||||
|
return "TensorNode " + std::to_string(guid);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t TensorNode::getOffset(const Shape &pos) const {
|
||||||
|
auto nDim = pos.size();
|
||||||
|
IT_ASSERT(shape.size() == nDim);
|
||||||
|
if (pos.empty())
|
||||||
|
return 0;
|
||||||
|
for (size_t i = 0; i < nDim; ++i)
|
||||||
|
IT_ASSERT(pos[i] < 0 || pos[i] >= shape[i]);
|
||||||
|
size_t idx = pos[0];
|
||||||
|
size_t dm = 0;
|
||||||
|
while (++dm < nDim)
|
||||||
|
idx = idx * shape[dm] + pos[dm];
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t TensorNode::size() const {
|
||||||
|
size_t ret = 1;
|
||||||
|
for (const auto &d : shape)
|
||||||
|
ret *= d;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TensorNode::copyData(VType *dptr) {
|
||||||
|
IT_ASSERT(data != nullptr);
|
||||||
|
size_t sz = size();
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (size_t i = 0; i < sz; ++i) {
|
||||||
|
data[i] = dptr[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TensorNode::printData() const {
|
||||||
|
IT_ASSERT(data != nullptr);
|
||||||
|
std::cout << "Tensor: " << guid << std::endl;
|
||||||
|
auto numDims = shape.size();
|
||||||
|
auto dimSzVec = std::vector<int>(numDims, 1);
|
||||||
|
dimSzVec[numDims - 1] = shape[numDims - 1];
|
||||||
|
for (int i = numDims - 1; i != 0; --i)
|
||||||
|
dimSzVec[i - 1] = dimSzVec[i] * shape[i - 1];
|
||||||
|
for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
|
||||||
|
for (size_t j = 0; j < numDims; ++j) {
|
||||||
|
if (i % dimSzVec[j] == 0) {
|
||||||
|
std::cout << "[";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::cout << data[i];
|
||||||
|
for (size_t j = 0; j < numDims; ++j) {
|
||||||
|
if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) {
|
||||||
|
std::cout << "]";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (i != size() - 1)
|
||||||
|
std::cout << ", ";
|
||||||
|
if ((int)i % dimSzVec[numDims - 1] == dimSzVec[numDims - 1] - 1)
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TensorNode::equalData(const Tensor &rhs) const {
|
||||||
|
IT_ASSERT(data != nullptr);
|
||||||
|
IT_ASSERT(rhs->data != nullptr);
|
||||||
|
if (shape != rhs->getDims())
|
||||||
|
return false;
|
||||||
|
size_t sz = size();
|
||||||
|
for (size_t i = 0; i < sz; ++i)
|
||||||
|
if (data[i] != rhs->data[i])
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
}; // namespace infini
|
|
@ -0,0 +1,9 @@
|
||||||
|
#include <core/tensor_base.h>
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
TensorBaseNode::TensorBaseNode(int dim, DataType dtype)
|
||||||
|
: dim(dim), dtype(dtype) {}
|
||||||
|
|
||||||
|
VType TensorBaseNode::getData(size_t offset) const { return data[offset]; }
|
||||||
|
|
||||||
|
}; // namespace infini
|
|
@ -0,0 +1,38 @@
|
||||||
|
#include "operators/matmul.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
template <typename T> class NaiveMatmul : public Kernel {
|
||||||
|
void compute(const Operator &_op, const PerfRecord &record) const override {
|
||||||
|
auto op = as<MatmulNode>(_op);
|
||||||
|
T *A = reinterpret_cast<T *>(op->getInputs(0)->getDataPtr().get());
|
||||||
|
T *B = reinterpret_cast<T *>(op->getInputs(1)->getDataPtr().get());
|
||||||
|
T *C = reinterpret_cast<T *>(op->getOutput()->getDataPtr().get());
|
||||||
|
IT_ASSERT(op->getTransA() == false && op->getTransB() == false);
|
||||||
|
IT_ASSERT(op->getAct() == ActType::None);
|
||||||
|
IT_ASSERT(op->getB() == 1);
|
||||||
|
const int M = op->getM(), N = op->getN(), K = op->getK();
|
||||||
|
for (int i = 0; i < M; i++) {
|
||||||
|
for (int j = 0; j < N; j++) {
|
||||||
|
C[i * N + j] = 0;
|
||||||
|
for (int k = 0; k < K; k++) {
|
||||||
|
C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void compute(const Operator &op) const override { compute(op, {}); }
|
||||||
|
|
||||||
|
PerfRecord tune(const Operator &op) const override {
|
||||||
|
return PerfRecord{.time = timeit([this, &op]() { compute(op); })};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Int32,
|
||||||
|
NaiveMatmul<uint32_t>, "MatmulNaive_CPU_uint32");
|
||||||
|
REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Float32,
|
||||||
|
NaiveMatmul<float>, "MatmulNaive_CPU_float32");
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,25 @@
|
||||||
|
#include "nnet/Pass/MatchComputationKernel.h"
|
||||||
|
#include "nnet/Visitor/PatternMatcher.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
// RE: is this duplicate with Rule6KenerlMatching?
|
||||||
|
void MatchComputationKernel::transform(Formula &origin, int depth, Expr &rCur) {
|
||||||
|
nnet_assert(derivator.getSearchState() == 2, __LINE__);
|
||||||
|
auto cur = as<RangeOpNode>(rCur);
|
||||||
|
// Build wrapper stages for enforce axis starts from 0
|
||||||
|
PatternMatcher patternMatcher(derivator, cur);
|
||||||
|
cur = patternMatcher.getOffsetCur();
|
||||||
|
|
||||||
|
auto matches = patternMatcher.matchWithPattern(
|
||||||
|
cur, getPattern(derivator.getTargetOp()));
|
||||||
|
matches = patternMatcher.applyWrapper(matches);
|
||||||
|
|
||||||
|
for (auto newCur : matches) {
|
||||||
|
derivator.setSearchState(3);
|
||||||
|
nextStep(origin, depth, rCur, newCur);
|
||||||
|
derivator.setSearchState(2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,23 @@
|
||||||
|
#include "nnet/Pass/MatchMemBoundKernel.h"
|
||||||
|
#include "nnet/Visitor/InputVisitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void MatchMemBoundKernel::transform(Formula &origin, int depth, Expr &rCur) {
|
||||||
|
// FIXME: Whether the Formula is a Membound OP should be checked.
|
||||||
|
nnet_assert(derivator.getSearchState() == 3, __LINE__);
|
||||||
|
nnet_assert(origin.root.get() == rCur.get(),
|
||||||
|
"Only match the entire formula as a Membound Op");
|
||||||
|
auto rangeOp = as<RangeOpNode>(origin.root);
|
||||||
|
const auto &inputs = InputVisitor().getInputs(rangeOp);
|
||||||
|
auto source =
|
||||||
|
make_ref<ElementWiseNode>(rangeOp, inputs, rangeOp->getOutputShape());
|
||||||
|
auto tensor =
|
||||||
|
makeTensor(newTensorName(), rangeOp->getOutputShape(), {}, source);
|
||||||
|
// The original code directly appends candidate. But it seems should be done
|
||||||
|
// by the search.
|
||||||
|
// appendCanddiate(as<TensorNode>(tensor), depth);
|
||||||
|
nextStep(origin, depth, rCur, tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,58 @@
|
||||||
|
#include "nnet/Pass/Pass.h"
|
||||||
|
#include "nnet/Visitor/CloneMutator.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
Pass::Pass(Derivator &derivator, const string &passName)
|
||||||
|
: derivator(derivator), passName(passName),
|
||||||
|
enableLogging(derivator.getLogMode() != Derivator::LogMode::NoLog),
|
||||||
|
enableDebug(false) {}
|
||||||
|
|
||||||
|
Pass::~Pass() = default;
|
||||||
|
|
||||||
|
void Pass::setEnableLogging(bool value) { enableLogging = value; }
|
||||||
|
|
||||||
|
void Pass::setEnableDebug(bool value) { enableDebug = value; }
|
||||||
|
|
||||||
|
void Pass::run(Formula &origin, int dfsDepth, Expr &rCur) {
|
||||||
|
initialize(origin, rCur);
|
||||||
|
transform(origin, dfsDepth, rCur);
|
||||||
|
finalize();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Pass::initialize(Formula &origin, const Expr &rCur) {}
|
||||||
|
|
||||||
|
void Pass::finalize() {}
|
||||||
|
|
||||||
|
Var Pass::getNewVar() { return derivator.getNewVar(); }
|
||||||
|
|
||||||
|
string Pass::newTensorName() { return derivator.newTensorName(); }
|
||||||
|
|
||||||
|
void Pass::nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur,
|
||||||
|
const string &ruleMsg) {
|
||||||
|
// push rule action description
|
||||||
|
if (enableLogging) {
|
||||||
|
rCur.swap(newCur);
|
||||||
|
derivator.pushIntermediateState(origin.root);
|
||||||
|
rCur.swap(newCur);
|
||||||
|
derivator.pushRuleState(passName);
|
||||||
|
derivator.pushRuleMsg(ruleMsg);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (enableDebug) {
|
||||||
|
// In debug mode, do not recur but save the transformed state
|
||||||
|
transformations.emplace_back(CloneMutator().clone(newCur));
|
||||||
|
} else
|
||||||
|
derivator.nextStep(origin, depth, rCur, newCur);
|
||||||
|
|
||||||
|
// pop rule action description
|
||||||
|
if (enableLogging) {
|
||||||
|
derivator.popIntermediateState();
|
||||||
|
derivator.popRuleState();
|
||||||
|
derivator.popRuleMsg();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const VecExpr &Pass::getTransformations() { return transformations; }
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,134 @@
|
||||||
|
#include "nnet/Pass/Rule1VariableSplit.h"
|
||||||
|
#include "nnet/Visitor/ReplaceVariable.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void Rule1VariableSplit::transform(Formula &origin, int depth, Expr &rCur) {
|
||||||
|
auto cur = as<RangeOpNode>(rCur);
|
||||||
|
vector<Replace> replaces = getSplitableVar(cur);
|
||||||
|
// for (const auto &replace : replaces)
|
||||||
|
// dbg(replace.oldIters, replace.newIters, replace.psis,
|
||||||
|
// replace.newVarRanges);
|
||||||
|
for (const auto &replace : replaces) {
|
||||||
|
auto replacedSummand = replaceIters(cur->getSummand(), replace);
|
||||||
|
if (!replacedSummand) {
|
||||||
|
// TODO: if a real getMergableExprs is implemented, this case should
|
||||||
|
// be an error. Since the expr should appear in the AST.
|
||||||
|
dbg("Warning: No replacment happens.");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto inner =
|
||||||
|
ReplaceKit::replaceRangeOpIterator(cur, replace, replacedSummand);
|
||||||
|
// build the outerRange{innerRange}[indexForInner] to do DLT
|
||||||
|
Expr nextCur = nullptr;
|
||||||
|
if (replace.iteratorType == IterationType::Loop) {
|
||||||
|
auto subscriptedInner =
|
||||||
|
ReplaceKit::buildSubscirptForLoopVarReplace(inner, replace);
|
||||||
|
nextCur = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
|
||||||
|
} else
|
||||||
|
nextCur = inner;
|
||||||
|
|
||||||
|
string msg = "====== END rule1 VariableSplit: ";
|
||||||
|
dbg(msg, replace.oldIters, replace.newIters, replace.phis,
|
||||||
|
replace.psis);
|
||||||
|
msg = replace.toReadable();
|
||||||
|
nextStep(origin, depth, rCur, nextCur, msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<Replace> Rule1VariableSplit::getSplitableVar(const RangeOp &rangeOp) {
|
||||||
|
vector<Replace> ret;
|
||||||
|
// Split strategy
|
||||||
|
vector<int> SumFactors, LoopFactors;
|
||||||
|
if (derivator.getPassMode() == Derivator::PassMode::Debug) {
|
||||||
|
SumFactors = {3};
|
||||||
|
LoopFactors = {4};
|
||||||
|
} else if (derivator.getPassMode() == Derivator::PassMode::Full) {
|
||||||
|
SumFactors = {2, 3};
|
||||||
|
// LoopFactors = {3, 4};
|
||||||
|
LoopFactors = {4};
|
||||||
|
} else
|
||||||
|
nnet_unimplemented_halt();
|
||||||
|
|
||||||
|
// Split Sum variable
|
||||||
|
for (const int k : SumFactors) {
|
||||||
|
for (const auto &[var, range] : rangeOp->getSumVarRanges()) {
|
||||||
|
int len = range.second - range.first;
|
||||||
|
auto p1 = getNewVar(); // p1=i/k
|
||||||
|
auto p2 = getNewVar(); // p2=i%k
|
||||||
|
if (len > 10 || len <= k || len % k != 0)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Range range1, range2;
|
||||||
|
if (range.first < 0) {
|
||||||
|
nnet_unimplemented_halt();
|
||||||
|
// FIXME: this must be ERROR
|
||||||
|
range1.first = range.first / k;
|
||||||
|
range1.second = range1.first + len / k;
|
||||||
|
range2.first = -k / 2;
|
||||||
|
range2.second = range2.first + k;
|
||||||
|
} else if (range.first == 0) {
|
||||||
|
range1.first = 0;
|
||||||
|
range1.second = len / k;
|
||||||
|
range2.first = 0;
|
||||||
|
range2.second = k;
|
||||||
|
} else {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Replace replace{.iteratorType = IterationType::Sum,
|
||||||
|
.oldIters = {var},
|
||||||
|
.newIters = {p1, p2},
|
||||||
|
.phis = {},
|
||||||
|
.psis = {make_ref<ConstantNode>(k) * p1 + p2},
|
||||||
|
.newVarRanges = {{p1, range1}, {p2, range2}}};
|
||||||
|
ret.emplace_back(replace);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const int k : LoopFactors) {
|
||||||
|
// Split Loop variable
|
||||||
|
for (const auto &[var, range] : rangeOp->getLoopVarRanges()) {
|
||||||
|
const int len = range.second - range.first;
|
||||||
|
// Debug HACK for dilated SG2BMM
|
||||||
|
if (derivator.getPassMode() == Derivator::PassMode::Debug &&
|
||||||
|
!(var->getName() == "m" && len % k == 0))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// Illeagel conditions
|
||||||
|
if (range.second - range.first <= k ||
|
||||||
|
(range.second - range.first) % k != 0)
|
||||||
|
continue;
|
||||||
|
// Unsupport conditions
|
||||||
|
if (range.first != 0)
|
||||||
|
continue;
|
||||||
|
auto p1 = getNewVar(); // p1=i/k
|
||||||
|
auto p2 = getNewVar(); // p2=i%k
|
||||||
|
Range range1(0, len / k);
|
||||||
|
Range range2(0, k);
|
||||||
|
nnet_assert(range1.second > 0 && range2.second > 0,
|
||||||
|
"Empty loop dim");
|
||||||
|
Replace replace{.iteratorType = IterationType::Loop,
|
||||||
|
.oldIters = {var},
|
||||||
|
.newIters = {p1, p2},
|
||||||
|
.phis = {var / 4, var % 4},
|
||||||
|
.psis = {make_ref<ConstantNode>(k) * p1 + p2},
|
||||||
|
.newVarRanges = {{p1, range1}, {p2, range2}}};
|
||||||
|
ret.emplace_back(replace);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr Rule1VariableSplit::replaceIters(Expr cur, const Replace &replace) {
|
||||||
|
// TODO [feature]: support multiple replacements in one mutator
|
||||||
|
if (replace.oldIters.size() != 1) {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto replaceMutator =
|
||||||
|
ReplaceVariable(replace.oldIters.at(0), replace.psis.at(0));
|
||||||
|
auto ret = replaceMutator(cur);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,186 @@
|
||||||
|
#include "nnet/Pass/Rule2VariableMerging.h"
|
||||||
|
#include "nnet/Visitor/CheckOOBVisitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void Rule2VariableMerging::transform(Formula &origin, int depth, Expr &rCur) {
|
||||||
|
// Extract r and s
|
||||||
|
auto cur = as<RangeOpNode>(rCur);
|
||||||
|
vector<Replace> replaces = getMergableReplaces(cur, depth);
|
||||||
|
// dbg("Start rule2VariableMerging", depth, mergableExprs, *cur);
|
||||||
|
for (const auto &replace : replaces) {
|
||||||
|
if (replace.iteratorType != IterationType::Loop) {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// replace vars in summand
|
||||||
|
auto replacedSummand = ReplaceKit::replaceMultipleExprs(
|
||||||
|
cur->getSummand(), replace.oldIters, replace.psis, true);
|
||||||
|
// replace var in rangeOp
|
||||||
|
auto inner =
|
||||||
|
ReplaceKit::replaceRangeOpIterator(cur, replace, replacedSummand);
|
||||||
|
// If OOB happens, this transformation is skipped
|
||||||
|
if (CheckOOBVisitor().checkRangeOp(inner))
|
||||||
|
continue;
|
||||||
|
// build the outerRange{innerRange}[indexForInner] to do DLT
|
||||||
|
auto subscriptedInner =
|
||||||
|
ReplaceKit::buildSubscirptForLoopVarReplace(inner, replace);
|
||||||
|
auto outer = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
|
||||||
|
|
||||||
|
// next searching step
|
||||||
|
string msg = replace.toReadable();
|
||||||
|
nextStep(origin, depth, rCur, outer, msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<Replace> Rule2VariableMerging::getMergableReplaces(RangeOp rangeOp,
|
||||||
|
int depth) {
|
||||||
|
vector<Replace> ret;
|
||||||
|
IteratorTable exprIT;
|
||||||
|
if (!exprIT.analyzeExpr(rangeOp)) {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
exprIT.buildTableWithDefaultMap();
|
||||||
|
const auto &strideInAllDim = exprIT.getStrideInDim();
|
||||||
|
|
||||||
|
set<pair<Iterator, Iterator>, RefValueLess<pair<Iterator, Iterator>>>
|
||||||
|
checkedIterPairs{};
|
||||||
|
// strideInAllDim: [tensorID][dimOfTensor][Iterator]=stride
|
||||||
|
for (size_t tensorID = 0; tensorID < strideInAllDim.size(); ++tensorID) {
|
||||||
|
const auto &strideInDimsOfATensor = strideInAllDim[tensorID];
|
||||||
|
for (const PtrMap<Iterator, int> &strideInADim :
|
||||||
|
strideInDimsOfATensor) {
|
||||||
|
for (const auto &it1 : strideInADim) {
|
||||||
|
for (const auto &it2 : strideInADim) {
|
||||||
|
// Backdoor for rule-based search
|
||||||
|
if (substituteRules.count(depth)) {
|
||||||
|
if (substituteRules[depth].at(0)->neq(it1.first))
|
||||||
|
continue;
|
||||||
|
if (substituteRules[depth].at(1)->neq(it2.first))
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!(it1.first->equal(it2.first) &&
|
||||||
|
it1.second == it2.second) &&
|
||||||
|
rangeOp->hasLoopVar(it1.first) &&
|
||||||
|
rangeOp->hasLoopVar(it2.first)) {
|
||||||
|
// 2 iters -> 2 iters
|
||||||
|
if (auto opt = getReplaceMappingTwoLoopIters(rangeOp,
|
||||||
|
it1, it2))
|
||||||
|
ret.emplace_back(*opt);
|
||||||
|
|
||||||
|
// 2 iters -> 1 iter
|
||||||
|
const auto iterPair = pair(it1.first, it2.first);
|
||||||
|
if (!checkedIterPairs.count(iterPair)) {
|
||||||
|
checkedIterPairs.insert(iterPair);
|
||||||
|
if (auto opt = getReplaceMergingTwoLoopIters(
|
||||||
|
rangeOp, it1, it2, exprIT, tensorID))
|
||||||
|
ret.emplace_back(*opt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
optional<Replace> Rule2VariableMerging::getReplaceMergingTwoLoopIters(
|
||||||
|
const RangeOp &rangeOp, pair<Iterator, int> pairA,
|
||||||
|
pair<Iterator, int> pairB, const IteratorTable &exprIT, int tensorID) {
|
||||||
|
// 1*A + sb*B -> C
|
||||||
|
// A=C%sb, B=C/sb
|
||||||
|
// ax+by->z, a=1 or -1
|
||||||
|
// For a>0 and b>0 : x=z%b, y=z/b
|
||||||
|
auto x = pairA.first, y = pairB.first;
|
||||||
|
int a = pairA.second, b = pairB.second;
|
||||||
|
if (abs(a) != 1 || abs(a) * abs(b) <= 0)
|
||||||
|
return {};
|
||||||
|
if (a < 0 && b > 0) { // The only unhandled case
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
// negative substitution happens only if can be totally merged. So if the
|
||||||
|
// variable appears in another index, skip it.
|
||||||
|
if (a < 0 || b < 0) {
|
||||||
|
if (exprIT.getNumInputs() > 1) {
|
||||||
|
if (exprIT.getStridesInTensor(x, 1 - tensorID) != 0)
|
||||||
|
return {};
|
||||||
|
if (exprIT.getStridesInTensor(y, 1 - tensorID) != 0)
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Range rangeX = rangeOp->getVarRange(x).second,
|
||||||
|
rangeY = rangeOp->getVarRange(y).second;
|
||||||
|
if (rangeX.first != 0 || rangeY.first != 0)
|
||||||
|
return {};
|
||||||
|
int lenX = rangeX.second - rangeX.first;
|
||||||
|
if (abs(b) != lenX)
|
||||||
|
return {};
|
||||||
|
auto z = getNewVar();
|
||||||
|
|
||||||
|
Range rangeExpr{0, 1}; // 1 is the open interval compensation
|
||||||
|
auto calcRangeExpr = [&rangeExpr](int stride, const Range &r) {
|
||||||
|
if (stride > 0) {
|
||||||
|
rangeExpr.first += stride * r.first;
|
||||||
|
rangeExpr.second += stride * (r.second - 1);
|
||||||
|
} else {
|
||||||
|
rangeExpr.first += stride * (r.second - 1);
|
||||||
|
rangeExpr.second += stride * r.first;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
calcRangeExpr(a, rangeX);
|
||||||
|
calcRangeExpr(b, rangeY);
|
||||||
|
|
||||||
|
// build the phi/psi for index transformation
|
||||||
|
// phi: j_x=(i_x...), psi: i_x=(j_x...)
|
||||||
|
auto ret = optional<Replace>();
|
||||||
|
ret.emplace();
|
||||||
|
ret->iteratorType = IterationType::Loop;
|
||||||
|
ret->newIters = {z};
|
||||||
|
ret->oldIters = {x, y};
|
||||||
|
ret->phis = {a * x + b * y - rangeExpr.first};
|
||||||
|
// For b < 0, the psis are not an equavalent replace. Since it must be
|
||||||
|
// simplified (z/b and z%b will be merged), the only important thing is
|
||||||
|
// their strides should be mergable. To merge the strides, an extra minus
|
||||||
|
// are introduced if their stride is negative.
|
||||||
|
ret->psis = {a * (z % b) + a * rangeExpr.first, (b > 0 ? 1 : -1) * (z / b)};
|
||||||
|
ret->newVarRanges = {{z, {0, rangeExpr.second - rangeExpr.first}}};
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
optional<Replace>
|
||||||
|
Rule2VariableMerging::getReplaceMappingTwoLoopIters(const RangeOp &rangeOp,
|
||||||
|
pair<Iterator, int> pairA,
|
||||||
|
pair<Iterator, int> pairB) {
|
||||||
|
// the first iterator is replaced, the second remains
|
||||||
|
auto i1 = pairA.first, i2 = pairB.first;
|
||||||
|
int sa = pairA.second, sb = pairB.second;
|
||||||
|
// TODO: can be relaxed to sb|sb
|
||||||
|
if (sa != 1 || sb == 0)
|
||||||
|
return {};
|
||||||
|
if (sb < 0) {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
Range rangeA = rangeOp->getVarRange(i1).second;
|
||||||
|
Range rangeB = rangeOp->getVarRange(i2).second;
|
||||||
|
auto j1 = getNewVar(), j2 = getNewVar();
|
||||||
|
Range rangeJ1, rangeJ2 = rangeB;
|
||||||
|
assert(pairA.second == 1);
|
||||||
|
rangeJ1.first = rangeA.first + rangeB.first * sb;
|
||||||
|
rangeJ1.second = rangeA.second + (rangeB.second - 1) * sb;
|
||||||
|
// build the phi/psi for index transformation
|
||||||
|
// phi: j_x=(i_x...), psi: i_x=(j_x...)
|
||||||
|
auto ret = optional<Replace>();
|
||||||
|
ret.emplace();
|
||||||
|
ret->iteratorType = IterationType::Loop;
|
||||||
|
ret->newIters = {j1, j2};
|
||||||
|
ret->oldIters = {i1, i2};
|
||||||
|
ret->newVarRanges = {{j1, rangeJ1}, {j2, rangeJ2}};
|
||||||
|
ret->phis = {sa * i1 + sb * i2, i2};
|
||||||
|
ret->psis = {j1 - (sb / sa) * j2, j2};
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,82 @@
|
||||||
|
#include "nnet/Pass/Rule3StageSplit.h"
|
||||||
|
#include "nnet/permutation.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void Rule3StageSplit::transform(Formula &origin, int depth, Expr &rCur) {
|
||||||
|
auto cur = as<RangeOpNode>(rCur);
|
||||||
|
vector<vector<Iterator>> splitSchemes = getSplitSummationIters(cur);
|
||||||
|
|
||||||
|
for (const auto &varSplit : splitSchemes) {
|
||||||
|
|
||||||
|
bool isSplittable = false;
|
||||||
|
for (const auto &splitted : varSplit)
|
||||||
|
if (cur->hasSumVar(splitted))
|
||||||
|
isSplittable = true;
|
||||||
|
assert(isSplittable);
|
||||||
|
|
||||||
|
const vector<VarRangePair> loopVars = cur->getLoopVarRanges(),
|
||||||
|
sumVars = cur->getSumVarRanges();
|
||||||
|
// move iterators from Sigma to Loop
|
||||||
|
vector<VarRangePair> innerLoopVars, innerSumVars, outerSumVars;
|
||||||
|
VecExpr indexForInner;
|
||||||
|
for (const auto &kv : sumVars) {
|
||||||
|
bool isSplitted = false;
|
||||||
|
for (const auto &iter : varSplit)
|
||||||
|
if (iter == kv.first->getName())
|
||||||
|
isSplitted = true;
|
||||||
|
if (isSplitted) {
|
||||||
|
innerLoopVars.emplace_back(kv);
|
||||||
|
outerSumVars.emplace_back(kv);
|
||||||
|
} else
|
||||||
|
innerSumVars.emplace_back(kv);
|
||||||
|
}
|
||||||
|
innerLoopVars.insert(innerLoopVars.end(), loopVars.begin(),
|
||||||
|
loopVars.end());
|
||||||
|
for (const auto &[var, _] : innerLoopVars)
|
||||||
|
indexForInner.emplace_back(var);
|
||||||
|
|
||||||
|
// if no sum iterator, the stage is redundant
|
||||||
|
assert(!innerSumVars.empty());
|
||||||
|
auto inner =
|
||||||
|
makeRangeOperator(innerLoopVars, innerSumVars, cur->getSummand());
|
||||||
|
auto subscriptedInner = make_ref<SubscriptNode>(inner, indexForInner);
|
||||||
|
auto outer = makeRangeOperator(cur->getLoopVarRanges(), outerSumVars,
|
||||||
|
subscriptedInner);
|
||||||
|
outer->setPaddings(cur->getPaddings());
|
||||||
|
|
||||||
|
// next searching step
|
||||||
|
string msg = "Separate sum iters: " + serializeVec(varSplit);
|
||||||
|
nextStep(origin, depth, rCur, outer, msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<vector<Iterator>>
|
||||||
|
Rule3StageSplit::getSplitSummationIters(RangeOp rangeOp) {
|
||||||
|
// set<string> varSplit = {"r", "s", "i3", "i13"};
|
||||||
|
vector<vector<Iterator>> ret;
|
||||||
|
// Rule-based Hint
|
||||||
|
// vector<vector<Iterator>> heuristics = {{"r", "s"}, {"i3", "i13"}};
|
||||||
|
// for (const auto &iterSet : heuristics) {
|
||||||
|
// bool notExist = false;
|
||||||
|
// for (const auto &iter : iterSet)
|
||||||
|
// if (!rangeOp->hasSumVar(iter))
|
||||||
|
// notExist = true;
|
||||||
|
// if (!notExist)
|
||||||
|
// ret.emplace_back(iterSet);
|
||||||
|
// }
|
||||||
|
// if (!rulesOverall.empty())
|
||||||
|
// return ret;
|
||||||
|
vector<Iterator> sumIters;
|
||||||
|
for (const auto &[iter, range] : rangeOp->getSumVarRanges())
|
||||||
|
sumIters.emplace_back(iter);
|
||||||
|
if (sumIters.size() <= 1)
|
||||||
|
return ret;
|
||||||
|
SubsetGenerator gen(sumIters);
|
||||||
|
do {
|
||||||
|
ret.emplace_back(gen.get());
|
||||||
|
} while (gen.next());
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,91 @@
|
||||||
|
#include "nnet/Pass/Rule4StageMerging.h"
|
||||||
|
#include "nnet/Visitor/ReplaceNodeMutator.h"
|
||||||
|
#include "nnet/Visitor/ReplaceVariable.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void Rule4StageMerging::transform(Formula &origin, int depth, Expr &rCur) {
|
||||||
|
success = rule4StageMerging(origin, depth, rCur, mergeStageWithCalc);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Rule4StageMerging::isSuccessful() { return success; }
|
||||||
|
|
||||||
|
void Rule4StageMerging::setMergeStageWithCalc(bool value) {
|
||||||
|
mergeStageWithCalc = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Rule4StageMerging::rule4StageMerging(Formula &origin, int depth,
|
||||||
|
Expr &rCur, bool mergeStageWithCalc) {
|
||||||
|
auto rangeOp0 = as<RangeOpNode>(rCur);
|
||||||
|
const Subscript &sub0 = as<SubscriptNode>(rangeOp0->getSummand());
|
||||||
|
if (!sub0)
|
||||||
|
return false;
|
||||||
|
const auto &rangeOp1 = as<RangeOpNode>(sub0->getObject());
|
||||||
|
if (!rangeOp1)
|
||||||
|
return false;
|
||||||
|
const auto &sub1 = as<SubscriptNode>(rangeOp1->getSummand());
|
||||||
|
if (!sub1)
|
||||||
|
return false;
|
||||||
|
// merge stage with calculation only when mergeStageWithCalc=true
|
||||||
|
if (!mergeStageWithCalc && !rangeOp1->getSumVarRanges().empty())
|
||||||
|
return false;
|
||||||
|
// Only propogate paddings in perfect nested dimension
|
||||||
|
if (rangeOp1->hasPaddings()) {
|
||||||
|
auto oldTensor = as<TensorNode>(sub1->getObject());
|
||||||
|
if (!oldTensor) {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// repalce variables: iters of rangeOp1 repalced by indexes of sub0
|
||||||
|
map<string, pair<Expr, Expr>> varMapping;
|
||||||
|
assert(sub0->getDims() == rangeOp1->getLoopVarRanges().size());
|
||||||
|
for (size_t i = 0; i < sub0->getDims(); ++i) {
|
||||||
|
varMapping[rangeOp1->getLoopVar(i)->getName()] =
|
||||||
|
pair(rangeOp1->getLoopVar(i), sub0->getIndex(i));
|
||||||
|
}
|
||||||
|
ReplaceVariable replaceVariable{varMapping};
|
||||||
|
auto merged = make_ref<RangeOpNode>(*rangeOp0);
|
||||||
|
merged->setSummand(replaceVariable(sub1));
|
||||||
|
// a naive approach to propogate paddings
|
||||||
|
if (rangeOp1->hasPaddings()) {
|
||||||
|
auto oldTensor = as<TensorNode>(sub1->getObject());
|
||||||
|
auto newTensor = make_ref<TensorNode>(*oldTensor);
|
||||||
|
for (int i = 0; i < rangeOp1->getNumOutputDims(); ++i) {
|
||||||
|
if (rangeOp1->getPaddings(i) == 0)
|
||||||
|
continue;
|
||||||
|
auto loopVar = rangeOp1->getLoopVar(i);
|
||||||
|
// FIXME: in fact this var should not appear in other index as well,
|
||||||
|
// which may result in OOB
|
||||||
|
bool findSingleVarAsIndex = false;
|
||||||
|
for (size_t subIndexID = 0; subIndexID < sub1->getDims();
|
||||||
|
++subIndexID) {
|
||||||
|
auto index = sub1->getIndex(subIndexID);
|
||||||
|
if (auto indexVar = as<VarNode>(index);
|
||||||
|
indexVar && (indexVar->equal(loopVar))) {
|
||||||
|
newTensor->setPadding(subIndexID,
|
||||||
|
newTensor->getPadding(subIndexID) +
|
||||||
|
rangeOp1->getPaddings(i));
|
||||||
|
findSingleVarAsIndex = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!findSingleVarAsIndex) {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
merged = as<RangeOpNode>(
|
||||||
|
ReplaceNodeMutator().replace(merged, oldTensor.get(), newTensor));
|
||||||
|
assert(merged != nullptr);
|
||||||
|
}
|
||||||
|
// Merge inner stage sums
|
||||||
|
if (!rangeOp1->getSumVarRanges().empty())
|
||||||
|
merged->setSumIterator(rangeOp1->getSumVarRanges());
|
||||||
|
|
||||||
|
// next searching step
|
||||||
|
// if mergeStageWithCalc, depth counts for invocation in rule-based search
|
||||||
|
nextStep(origin, (mergeStageWithCalc) ? depth : depth - 1, rCur, merged);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,72 @@
|
||||||
|
#include "nnet/Pass/Rule5RangeRelaxation.h"
|
||||||
|
#include "nnet/Visitor/RangeRelaxFunctor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void Rule5RangeRelaxation::transform(Formula &origin, int depth, Expr &rCur) {
|
||||||
|
rule5RangeRelaxation(origin, depth, rCur);
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr Rule5RangeRelaxation::rule5RangeRelaxation(Formula &origin, int depth,
|
||||||
|
Expr &rCur) {
|
||||||
|
auto cur = as<RangeOpNode>(rCur);
|
||||||
|
if (cur->hasPaddings()) {
|
||||||
|
// string msg = "====== END rule5RangeRelaxation: Paddings exist \n";
|
||||||
|
// dbg(msg);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Infer meaningful calculation range
|
||||||
|
RangeRelaxFunctor rangeRexlaxtionFunctor{cur};
|
||||||
|
RangeMap rangeMap = rangeRexlaxtionFunctor(cur);
|
||||||
|
auto relaxedCur = make_ref<RangeOpNode>(*cur);
|
||||||
|
bool isRelaxed = false;
|
||||||
|
vector<int> paddings;
|
||||||
|
// check whether narrow the calculation range
|
||||||
|
for (size_t i = 0; i < cur->getLoopVarRanges().size(); ++i) {
|
||||||
|
const auto &[iter, iterRange] =
|
||||||
|
cur->getVarRange(IterationType::Loop, i);
|
||||||
|
if (auto it = rangeMap.find(iter); it != rangeMap.end()) {
|
||||||
|
// intersection of validRange and iterRange is necessary computation
|
||||||
|
// TODO: it is redundant with RangeRelaxFunctor::intersectRangeMaps.
|
||||||
|
// An independent Range class might be necessary.
|
||||||
|
const Range &validRange = it->second;
|
||||||
|
Range relaxedRange{max(iterRange.first, validRange.first),
|
||||||
|
min(iterRange.second, validRange.second)};
|
||||||
|
if (relaxedRange != iterRange) {
|
||||||
|
isRelaxed = true;
|
||||||
|
relaxedCur->setVarRange(IterationType::Loop, i,
|
||||||
|
{iter, relaxedRange});
|
||||||
|
paddings.emplace_back(
|
||||||
|
max(relaxedRange.first - iterRange.first,
|
||||||
|
iterRange.second - relaxedRange.second));
|
||||||
|
} else
|
||||||
|
paddings.emplace_back(0);
|
||||||
|
} else
|
||||||
|
paddings.emplace_back(0);
|
||||||
|
}
|
||||||
|
relaxedCur->setPaddings(paddings);
|
||||||
|
if (!isRelaxed) {
|
||||||
|
// string msg = "====== END rule5RangeRelaxation: Relaxation not
|
||||||
|
// found\n"; dbg(msg);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// next searching step
|
||||||
|
string detailedMsg;
|
||||||
|
for (size_t i = 0; i < cur->getLoopVarRanges().size(); ++i) {
|
||||||
|
const auto &[v, a] = cur->getVarRange(IterationType::Loop, i);
|
||||||
|
const auto &[_, b] = relaxedCur->getVarRange(IterationType::Loop, i);
|
||||||
|
if (a != b) {
|
||||||
|
detailedMsg += v->getName();
|
||||||
|
detailedMsg +=
|
||||||
|
" (" + to_string(a.first) + "," + to_string(a.second) + ") to";
|
||||||
|
detailedMsg +=
|
||||||
|
" (" + to_string(b.first) + "," + to_string(b.second) + "),";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nextStep(origin, depth, rCur, relaxedCur, detailedMsg);
|
||||||
|
return relaxedCur;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,57 @@
|
||||||
|
#include "nnet/Pass/Rule6KenerlMatching.h"
|
||||||
|
#include "nnet/Visitor/InputVisitor.h"
|
||||||
|
#include "nnet/Visitor/PatternMatcher.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void Rule6KenerlMatching::transform(Formula &origin, int depth, Expr &rCur) {
|
||||||
|
auto cur = as<RangeOpNode>(rCur);
|
||||||
|
// Build wrapper stages for enforce axis starts from 0
|
||||||
|
PatternMatcher patternMatcher(derivator, cur);
|
||||||
|
cur = patternMatcher.getOffsetCur();
|
||||||
|
|
||||||
|
// Match matchable routines
|
||||||
|
for (int i = 0; i < MatchableRoutineTypeCnt; ++i) {
|
||||||
|
auto targetOp = idToRoutineType(i);
|
||||||
|
// During guided search, only check the target OP
|
||||||
|
if (derivator.getTargetOp() != RoutineType::NoneType &&
|
||||||
|
derivator.getTargetOp() != targetOp)
|
||||||
|
continue;
|
||||||
|
auto replaces =
|
||||||
|
patternMatcher.matchWithPattern(cur, getPattern(targetOp));
|
||||||
|
replaces = patternMatcher.applyWrapper(replaces);
|
||||||
|
for (auto newCur : replaces)
|
||||||
|
nextStep(origin, depth, rCur, newCur);
|
||||||
|
}
|
||||||
|
{ // Match element-wise OP
|
||||||
|
auto replaces = matchElementWise(cur);
|
||||||
|
if (!replaces.empty())
|
||||||
|
dbg(rCur);
|
||||||
|
for (auto newCur : replaces)
|
||||||
|
nextStep(origin, depth, rCur, newCur);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
VecExpr Rule6KenerlMatching::matchElementWise(const RangeOp &rangeOp) {
|
||||||
|
// If the stage is compute bound, then do not convert it.
|
||||||
|
int64_t flops = rangeOp->getFlops(), outputSize = rangeOp->getOutputSize();
|
||||||
|
int64_t inputSize = rangeOp->getInputSize(rangeOp);
|
||||||
|
if (double(flops) / (inputSize + outputSize) > 3)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
vector<int> newShape;
|
||||||
|
for (const auto &[var, range] : rangeOp->getLoopVarRanges()) {
|
||||||
|
if (range.first != 0) {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
newShape.emplace_back(range.second - range.first);
|
||||||
|
}
|
||||||
|
const auto &inputs = InputVisitor().getInputs(rangeOp);
|
||||||
|
auto source =
|
||||||
|
make_ref<ElementWiseNode>(rangeOp, inputs, rangeOp->getOutputShape());
|
||||||
|
auto newTensor = makeTensor(newTensorName(), newShape, {}, source);
|
||||||
|
return {newTensor};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,78 @@
|
||||||
|
#include "nnet/Pass/Rule7DLT.h"
|
||||||
|
#include "nnet/Visitor/ReplaceNodeMutator.h"
|
||||||
|
#include "nnet/dlt.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void Rule7DLT::transform(Formula &origin, int depth, Expr &rCur) {
|
||||||
|
auto cur = as<RangeOpNode>(rCur);
|
||||||
|
if (!cur)
|
||||||
|
return;
|
||||||
|
auto op = as<BinaryOpNode>(cur->getSummand());
|
||||||
|
if (!op)
|
||||||
|
return;
|
||||||
|
auto subs = {op->getLhs(), op->getRhs()};
|
||||||
|
for (auto subExpr : subs) {
|
||||||
|
auto sub = as<SubscriptNode>(subExpr);
|
||||||
|
if (!sub)
|
||||||
|
continue;
|
||||||
|
auto tensor = as<TensorNode>(sub->getObject());
|
||||||
|
if (!tensor)
|
||||||
|
continue;
|
||||||
|
// // HACK for G2BMM
|
||||||
|
// if (tensor->getDims() != 3)
|
||||||
|
// continue;
|
||||||
|
for (const auto factor : getFactors()) {
|
||||||
|
for (int targetDim = 0; targetDim < tensor->getDims();
|
||||||
|
++targetDim) {
|
||||||
|
if (tensor->getShape(targetDim) % factor)
|
||||||
|
continue;
|
||||||
|
// Debug hint for G2BMM
|
||||||
|
if (derivator.getPassMode() == Derivator::PassMode::Debug) {
|
||||||
|
if (tensor->getShape(targetDim) != 10000)
|
||||||
|
continue;
|
||||||
|
assert(targetDim == 1);
|
||||||
|
}
|
||||||
|
DLT dlt;
|
||||||
|
dlt.split(targetDim, factor);
|
||||||
|
vector<int> newOrder(tensor->getDims() + 1);
|
||||||
|
for (int i = 0; i < tensor->getDims() + 1; ++i)
|
||||||
|
newOrder[i] = i;
|
||||||
|
newOrder[targetDim]++;
|
||||||
|
newOrder[targetDim + 1]--;
|
||||||
|
dlt.reorder(newOrder);
|
||||||
|
dlt.merge(targetDim, targetDim + 1);
|
||||||
|
if (auto opt = dlt.apply(cur, sub, newTensorName())) {
|
||||||
|
Expr newSummand = ReplaceNodeMutator().replace(
|
||||||
|
cur->getSummand(), sub.get(), *opt);
|
||||||
|
auto newCur = buildDLTSingleRangeOp(cur, newSummand);
|
||||||
|
|
||||||
|
// next searching step
|
||||||
|
string msg = "====== END rule7DLT\n";
|
||||||
|
dbg(msg);
|
||||||
|
nextStep(origin, depth, rCur, newCur);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr Rule7DLT::buildDLTSingleRangeOp(const RangeOp &original,
|
||||||
|
const Expr &newSummand) {
|
||||||
|
auto rangeOp = make_ref<RangeOpNode>(*original);
|
||||||
|
rangeOp->setSummand(newSummand);
|
||||||
|
return rangeOp;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> Rule7DLT::getFactors() {
|
||||||
|
if (derivator.getPassMode() == Derivator::PassMode::Debug) {
|
||||||
|
return {4};
|
||||||
|
} else if (derivator.getPassMode() == Derivator::PassMode::Full) {
|
||||||
|
return {3, 4};
|
||||||
|
} else {
|
||||||
|
nnet_unimplemented_halt();
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,317 @@
|
||||||
|
#include "nnet/Pass/Rule8GuidedDLT.h"
|
||||||
|
#include "nnet/Visitor/ReplaceNodeMutator.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
static int bitCount(unsigned int n) {
|
||||||
|
int count = 0;
|
||||||
|
while (n != 0) {
|
||||||
|
n = n & (n - 1);
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int bitPosition(unsigned int n) {
|
||||||
|
assert(bitCount(n) == 1);
|
||||||
|
int ret = 0;
|
||||||
|
for (n >>= 1; n; n >>= 1)
|
||||||
|
++ret;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Rule8GuidedDLT::transform(Formula &origin, int depth, Expr &rCur) {
|
||||||
|
guidedDLT(origin, depth, rCur);
|
||||||
|
}
|
||||||
|
|
||||||
|
VecExpr Rule8GuidedDLT::guidedDLT(Formula &origin, int depth, Expr &rCur,
|
||||||
|
bool debug) {
|
||||||
|
string detailedMsg;
|
||||||
|
VecExpr ret;
|
||||||
|
auto cur = as<RangeOpNode>(rCur);
|
||||||
|
// check cur satisfies T1[A]*T2[B]
|
||||||
|
|
||||||
|
if (!statisfyGuidedDLT(cur))
|
||||||
|
return ret;
|
||||||
|
|
||||||
|
IteratorTable exprIT;
|
||||||
|
if (!exprIT.analyzeExpr(cur))
|
||||||
|
return ret;
|
||||||
|
exprIT.buildTableWithDefaultMap();
|
||||||
|
|
||||||
|
bool setTargetOpHere = false;
|
||||||
|
for (int i = 0; i < MatchableRoutineTypeCnt; ++i) {
|
||||||
|
// if not correctly unset this variable
|
||||||
|
assert(setTargetOpHere == false);
|
||||||
|
// If the guide direction is set
|
||||||
|
if (derivator.getTargetOp() != RoutineType::NoneType &&
|
||||||
|
idToRoutineType(i) != derivator.getTargetOp())
|
||||||
|
continue;
|
||||||
|
// Warning: no continue befor unset the targetOp
|
||||||
|
if (derivator.getTargetOp() == RoutineType::NoneType) {
|
||||||
|
setTargetOpHere = true;
|
||||||
|
derivator.setTargetOp(idToRoutineType(i));
|
||||||
|
}
|
||||||
|
const Pattern &pattern = getPattern(derivator.getTargetOp());
|
||||||
|
auto mismatches = exprIT.matchPatternIT(pattern);
|
||||||
|
// Pruning less possible results
|
||||||
|
|
||||||
|
// std::cout << "mismatches= " << mismatches.size()
|
||||||
|
// << "; setTargetOpHere: " << setTargetOpHere << "; ";
|
||||||
|
// std::cout << "TargetOp = " <<
|
||||||
|
// static_cast<int>(derivator.getTargetOp())
|
||||||
|
// << "; mismatches : ";
|
||||||
|
// for (const auto i : mismatches)
|
||||||
|
// std::cout << static_cast<int>(i.type) << " ";
|
||||||
|
// std::cout << endl;
|
||||||
|
if (mismatches.size() == 0) {
|
||||||
|
derivator.setSearchState(2);
|
||||||
|
nextStep(origin, depth, rCur, rCur);
|
||||||
|
derivator.setSearchState(1);
|
||||||
|
}
|
||||||
|
if (mismatches.size() > 0 && mismatches.size() <= 2) {
|
||||||
|
for (const auto &mismatch : mismatches) {
|
||||||
|
Expr newCur;
|
||||||
|
if (mismatch.type == MismatchType::MoreVar) {
|
||||||
|
newCur = guidedDLTMoreVar2(cur, mismatch, exprIT, pattern);
|
||||||
|
detailedMsg += "guidedDLTMoreVar2 ";
|
||||||
|
} else if (mismatch.type == MismatchType::DLMismatch ||
|
||||||
|
mismatch.type == MismatchType::OutputDLMismatch) {
|
||||||
|
if (mismatches.size() > 1) {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
newCur =
|
||||||
|
guidedDLTDLMismatch(cur, mismatch, exprIT, pattern);
|
||||||
|
detailedMsg += "guidedDLTDLMismatch ";
|
||||||
|
}
|
||||||
|
// std::cout << "newCur= "
|
||||||
|
// << ((newCur == nullptr) ? "Nullptr"
|
||||||
|
// : newCur->toReadable())
|
||||||
|
// << endl;
|
||||||
|
if (!newCur)
|
||||||
|
continue;
|
||||||
|
if (debug)
|
||||||
|
ret.emplace_back(newCur);
|
||||||
|
// next searching step
|
||||||
|
detailedMsg = "Toward " +
|
||||||
|
getPatternName(derivator.getTargetOp()) + ". " +
|
||||||
|
detailedMsg;
|
||||||
|
nextStep(origin, depth, rCur, newCur, detailedMsg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Unset targetOp
|
||||||
|
if (setTargetOpHere) {
|
||||||
|
derivator.setTargetOp(RoutineType::NoneType);
|
||||||
|
setTargetOpHere = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr Rule8GuidedDLT::guidedDLTDLMismatch(
|
||||||
|
const RangeOp &cur, const Mismatch &mismatch,
|
||||||
|
[[maybe_unused]] const IteratorTable &exprIT, const Pattern &pattern) {
|
||||||
|
assert(mismatch.type == MismatchType::DLMismatch ||
|
||||||
|
mismatch.type == MismatchType::OutputDLMismatch);
|
||||||
|
// Currently only deal with ouput DLT
|
||||||
|
if (mismatch.bitmap != pattern.getNumInputs()) {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
vector<VarRangePair> newVarRanges;
|
||||||
|
for (const auto &[var, _] : pattern.getRangeOp()->getLoopVarRanges()) {
|
||||||
|
const auto &iterInExpr = mismatch.mappingIter_r.at(var);
|
||||||
|
newVarRanges.emplace_back(cur->getVarRange(iterInExpr));
|
||||||
|
}
|
||||||
|
auto inner = make_ref<RangeOpNode>(*cur);
|
||||||
|
inner->setLoopIterator(newVarRanges);
|
||||||
|
auto subscriptedInner =
|
||||||
|
ReplaceKit::buildSubscirptForLoopVarReplace(inner, {});
|
||||||
|
auto outer = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
|
||||||
|
return outer;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Rule8GuidedDLT::statisfyGuidedDLT(RangeOp cur) const {
|
||||||
|
auto mul = as<BinaryOpNode>(cur->getSummand());
|
||||||
|
if (!mul)
|
||||||
|
return false;
|
||||||
|
if (mul->getOpType() != OpType::Mul)
|
||||||
|
return false;
|
||||||
|
return as<SubscriptNode>(mul->getLhs()) && as<SubscriptNode>(mul->getRhs());
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr Rule8GuidedDLT::guidedDLTMoreVar2(const RangeOp &cur,
|
||||||
|
const Mismatch &mismatch,
|
||||||
|
const IteratorTable &exprIT,
|
||||||
|
const Pattern &pattern) {
|
||||||
|
int bitmap = mismatch.bitmap;
|
||||||
|
const auto &mergedItersDefaultOrder = exprIT.getPosTable(bitmap);
|
||||||
|
|
||||||
|
// Assure vars only appear in one input tensor
|
||||||
|
int bitmapOfInputs = bitmap & ((1 << exprIT.getNumInputs()) - 1);
|
||||||
|
if (bitCount(bitmapOfInputs) > 1)
|
||||||
|
return nullptr;
|
||||||
|
if (pattern.getPosTable(bitmap).size() != 1) {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (mergedItersDefaultOrder.size() < 1)
|
||||||
|
return nullptr;
|
||||||
|
int tensorID = bitPosition(bitmapOfInputs);
|
||||||
|
if (!checkElementsHaveOnlyOneAccessIteratorSet(exprIT, tensorID))
|
||||||
|
return nullptr;
|
||||||
|
vector<Var> oldVars; // i_1, ...
|
||||||
|
vector<Var> newVars; // j_1, ...
|
||||||
|
VecExpr psis; // i_1=\psi_1(j_1, ...)
|
||||||
|
VecExpr phis; // j_1=\phi_1(i_1, ...), not necessary for Sum iter
|
||||||
|
vector<VarRangePair> newVarRanges;
|
||||||
|
|
||||||
|
auto originalTensor = exprIT.getTensor(tensorID);
|
||||||
|
auto originalSub = exprIT.getSubscript(tensorID);
|
||||||
|
vector<bool> mergedDims(originalTensor->getDims());
|
||||||
|
|
||||||
|
// Heuristic: merge iters according to their appearance positions
|
||||||
|
std::multimap<int, Var> sortedMergedIters;
|
||||||
|
for (const auto &iter : mergedItersDefaultOrder) {
|
||||||
|
vector<int> dims = exprIT.getIterDimInTensor(tensorID, iter);
|
||||||
|
assert(dims.size() == 1);
|
||||||
|
sortedMergedIters.emplace(dims[0], iter);
|
||||||
|
}
|
||||||
|
vector<Var> mergedIters; // decides the order of fused dims
|
||||||
|
for (const auto &[_, v] : sortedMergedIters)
|
||||||
|
mergedIters.emplace_back(v);
|
||||||
|
|
||||||
|
// Add the merged iterators
|
||||||
|
const auto newVar = getNewVar();
|
||||||
|
newVars.emplace_back(newVar);
|
||||||
|
int newRange = 1;
|
||||||
|
for (const auto &iter : mergedIters) {
|
||||||
|
oldVars.emplace_back(iter);
|
||||||
|
auto range = cur->getRange(iter);
|
||||||
|
newRange *= (range.second - range.first);
|
||||||
|
// if (range.first == 0)
|
||||||
|
// nnet_unimplemented_halt();
|
||||||
|
}
|
||||||
|
newVarRanges.emplace_back(newVar, Range{0, newRange});
|
||||||
|
// Add psis for each old iterator
|
||||||
|
int remainingRange = newRange;
|
||||||
|
Expr phi = nullptr;
|
||||||
|
for (const auto &iter : mergedIters) {
|
||||||
|
auto oldVar = iter;
|
||||||
|
auto range = cur->getRange(iter);
|
||||||
|
int len = (range.second - range.first);
|
||||||
|
remainingRange /= len;
|
||||||
|
Expr psi = newVar;
|
||||||
|
if (remainingRange > 1)
|
||||||
|
psi = psi / remainingRange;
|
||||||
|
if (newRange > remainingRange * len)
|
||||||
|
psi = psi % len;
|
||||||
|
int start = cur->getRange(iter).first;
|
||||||
|
if (start != 0)
|
||||||
|
psi = psi + start;
|
||||||
|
psis.emplace_back(psi);
|
||||||
|
phi = phi + remainingRange * (oldVar - start);
|
||||||
|
}
|
||||||
|
Replace replace{.iteratorType = IterationType::Loop,
|
||||||
|
.oldIters = oldVars,
|
||||||
|
.newIters = newVars,
|
||||||
|
.phis = VecExpr{phi},
|
||||||
|
.psis = psis,
|
||||||
|
.newVarRanges = newVarRanges};
|
||||||
|
// HACK: decide the rebuild data shape order
|
||||||
|
// TODO: get a partial iter mapping and permutate them?
|
||||||
|
vector<Var> tensorDimAxes{newVars};
|
||||||
|
vector<int> newShape;
|
||||||
|
for (const auto &[var, range] : newVarRanges)
|
||||||
|
newShape.emplace_back(range.second - range.first);
|
||||||
|
for (int row = 0; row < exprIT.getNumRows(); ++row) {
|
||||||
|
// Deal with other dimensions of the current tensor
|
||||||
|
if (row == bitmap || ((row & (1 << tensorID)) == 0))
|
||||||
|
continue;
|
||||||
|
using StrideIter = tuple<int, int, Iterator>;
|
||||||
|
vector<StrideIter> strideIters;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < exprIT.getPosTable(row).size(); ++i) {
|
||||||
|
const auto &iter = exprIT.getPosTable(row)[i];
|
||||||
|
const Range range = cur->getRange(iter);
|
||||||
|
const int len = range.second - range.first;
|
||||||
|
|
||||||
|
// HACK Sort according to original stride. (keep original order)
|
||||||
|
strideIters.emplace_back(-exprIT.getStridesInTensor(iter, tensorID),
|
||||||
|
len, iter);
|
||||||
|
|
||||||
|
// // HACK for conv
|
||||||
|
// if (iter == "n")
|
||||||
|
// strideIters.emplace_back(2, len, iter);
|
||||||
|
// else if (iter == "c")
|
||||||
|
// strideIters.emplace_back(1, len, iter);
|
||||||
|
// else
|
||||||
|
// strideIters.emplace_back(0, len, iter);
|
||||||
|
}
|
||||||
|
// HACK: Assure the order of iterators
|
||||||
|
std::sort(strideIters.begin(), strideIters.end(),
|
||||||
|
ref_value_less<StrideIter>);
|
||||||
|
for (const auto &[_, len, oldIter] : strideIters) {
|
||||||
|
const auto &oldVar = oldIter;
|
||||||
|
tensorDimAxes.emplace_back(oldVar);
|
||||||
|
newShape.emplace_back(len);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// build DLT source
|
||||||
|
const auto sourceExpr =
|
||||||
|
buildGuidedDLTSource(originalSub, replace, tensorDimAxes, newShape);
|
||||||
|
const auto sourceRoutine = make_ref<ElementWiseNode>(
|
||||||
|
sourceExpr, vector<Tensor>{originalTensor}, newShape);
|
||||||
|
// build stage connections
|
||||||
|
const auto newTensor =
|
||||||
|
makeTensor(newTensorName(), newShape, {}, sourceRoutine);
|
||||||
|
const auto &newSub = makeSubscript(
|
||||||
|
newTensor, VecExpr(tensorDimAxes.begin(), tensorDimAxes.end()));
|
||||||
|
// TODO [1124]: get variable mapping and reorder L according to it
|
||||||
|
// dbg(cur, originalSub, newSub, newVarRanges, replace.toReadable(),
|
||||||
|
// tensorDimAxes, newShape);
|
||||||
|
|
||||||
|
// Replace the entire subscript(A[xxxxx,xxx]) in the summand
|
||||||
|
Expr newSummand = ReplaceNodeMutator().replace(cur->getSummand(),
|
||||||
|
originalSub.get(), newSub);
|
||||||
|
auto inner = ReplaceKit::replaceRangeOpIterator(cur, replace, newSummand);
|
||||||
|
auto subscriptedInner =
|
||||||
|
ReplaceKit::buildSubscirptForLoopVarReplace(inner, replace);
|
||||||
|
auto outer = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
|
||||||
|
return outer;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Rule8GuidedDLT::checkElementsHaveOnlyOneAccessIteratorSet(
|
||||||
|
const IteratorTable &exprIT, int tensorID) {
|
||||||
|
const auto &strideInDim = exprIT.getStrideInDim();
|
||||||
|
for (const auto &strideForOneDim : strideInDim[tensorID]) {
|
||||||
|
vector<pair<int, int>> strideLengthPairs;
|
||||||
|
for (const auto &[iter, s] : strideForOneDim) {
|
||||||
|
const auto &range = exprIT.getRangeOp()->getRange(iter);
|
||||||
|
strideLengthPairs.emplace_back(s, range.second - range.first);
|
||||||
|
}
|
||||||
|
std::sort(strideLengthPairs.begin(), strideLengthPairs.end());
|
||||||
|
for (size_t i = 0; i < strideLengthPairs.size() - 1; ++i) {
|
||||||
|
const auto &[stride, length] = strideLengthPairs[i];
|
||||||
|
if (stride * length > strideLengthPairs[i + 1].first)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr Rule8GuidedDLT::buildGuidedDLTSource(const Subscript &originalSub,
|
||||||
|
Replace replace,
|
||||||
|
vector<Var> tensorDimAxes,
|
||||||
|
vector<int> newShape) {
|
||||||
|
Expr newSub = ReplaceKit::replaceMultipleExprs(
|
||||||
|
originalSub, replace.oldIters, replace.psis, true);
|
||||||
|
vector<VarRangePair> loopVarRangePairs;
|
||||||
|
for (size_t i = 0; i < tensorDimAxes.size(); ++i)
|
||||||
|
loopVarRangePairs.emplace_back(tensorDimAxes[i], pair(0, newShape[i]));
|
||||||
|
return makeRangeOperator(loopVarRangePairs, {}, newSub);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,54 @@
|
||||||
|
#include "nnet/Pass/Rule90TwoStageElementWise.h"
|
||||||
|
#include "nnet/Visitor/InputVisitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void Rule90TwoStageElementWise::transform(Formula &origin, int depth,
|
||||||
|
Expr &rCur) {
|
||||||
|
auto cur = as<RangeOpNode>(rCur);
|
||||||
|
{ // Match element-wise OP
|
||||||
|
auto replaces = matchTwoStageElementWise(cur);
|
||||||
|
// if (!replaces.empty())
|
||||||
|
// dbg(rCur);
|
||||||
|
// dbg(replaces);
|
||||||
|
for (auto newCur : replaces)
|
||||||
|
nextStep(origin, depth, rCur, newCur);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
VecExpr
|
||||||
|
Rule90TwoStageElementWise::matchTwoStageElementWise(const RangeOp &rangeOp) {
|
||||||
|
// If the stage is compute bound, then do not convert it.
|
||||||
|
int64_t flops = rangeOp->getFlops(), outputSize = rangeOp->getOutputSize();
|
||||||
|
int64_t inputSize = rangeOp->getInputSize(rangeOp);
|
||||||
|
if (double(flops) / (inputSize + outputSize) > 3)
|
||||||
|
return {};
|
||||||
|
auto outerSub = as<SubscriptNode>(rangeOp->getSummand());
|
||||||
|
if (!outerSub)
|
||||||
|
return {};
|
||||||
|
auto innerRangeOp = as<RangeOpNode>(outerSub->getObject());
|
||||||
|
if (!innerRangeOp)
|
||||||
|
return {};
|
||||||
|
auto innerSub = as<SubscriptNode>(innerRangeOp->getSummand());
|
||||||
|
if (!innerSub)
|
||||||
|
return {};
|
||||||
|
auto innerTensor = as<TensorNode>(innerSub->getObject());
|
||||||
|
if (!innerTensor)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
vector<int> newShape;
|
||||||
|
for (const auto &[var, range] : rangeOp->getLoopVarRanges()) {
|
||||||
|
if (range.first != 0) {
|
||||||
|
nnet_unimplemented_continue();
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
newShape.emplace_back(range.second - range.first);
|
||||||
|
}
|
||||||
|
const auto &inputs = InputVisitor().getInputs(rangeOp);
|
||||||
|
auto source =
|
||||||
|
make_ref<ElementWiseNode>(rangeOp, inputs, rangeOp->getOutputShape());
|
||||||
|
auto newTensor = makeTensor(newTensorName(), newShape, {}, source);
|
||||||
|
return {newTensor};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,11 @@
|
||||||
|
#include "nnet/Pass/Rule91MergeStagesWithSum.h"
|
||||||
|
#include "nnet/Pass/Rule4StageMerging.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void Rule91MergeStagesWithSum::transform(Formula &origin, int depth,
|
||||||
|
Expr &rCur) {
|
||||||
|
Rule4StageMerging(derivator).rule4StageMerging(origin, depth, rCur, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,45 @@
|
||||||
|
#include "nnet/Pass/Rule9RangeMagnify.h"
|
||||||
|
#include "nnet/Visitor/RangeMagnifyVisitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void Rule9RangeMagnify::transform(Formula &origin, int depth, Expr &rCur) {
|
||||||
|
auto cur = as<RangeOpNode>(rCur);
|
||||||
|
if (cur->hasPaddings()) {
|
||||||
|
// string msg = "====== END rule9RangeMagnify: Paddings exist \n";
|
||||||
|
// dbg(msg);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// HACK for conv5x5
|
||||||
|
vector<VarRangePair> newSumVarRanges;
|
||||||
|
for (const auto &[var, range] : cur->getSumVarRanges()) {
|
||||||
|
if (range.first == 0 && range.second == 5) {
|
||||||
|
newSumVarRanges.emplace_back(
|
||||||
|
var, Range{range.first, (range.second + 2) / 3 * 3});
|
||||||
|
} else
|
||||||
|
newSumVarRanges.emplace_back(var, range);
|
||||||
|
}
|
||||||
|
if (newSumVarRanges.empty())
|
||||||
|
return;
|
||||||
|
auto magnifiedCur = RangeMagnifyVisitor().magnify(cur, newSumVarRanges);
|
||||||
|
if (!magnifiedCur)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// next searching step
|
||||||
|
string msg = "relax iterating ranges ";
|
||||||
|
for (size_t i = 0; i < cur->getSumVarRanges().size(); ++i) {
|
||||||
|
const auto &[v1, a] = cur->getVarRange(IterationType::Sum, i);
|
||||||
|
const auto &[v2, b] = magnifiedCur->getVarRange(IterationType::Sum, i);
|
||||||
|
assert(v1->getName() == v2->getName());
|
||||||
|
if (a != b) {
|
||||||
|
msg += v1->getName();
|
||||||
|
msg +=
|
||||||
|
" (" + to_string(a.first) + "," + to_string(a.second) + ") to";
|
||||||
|
msg += " (" + to_string(b.first) + "," + to_string(b.second) + "),";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nextStep(origin, depth, rCur, magnifiedCur, msg);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,165 @@
|
||||||
|
#include "nnet/Visitor/AsTVMVisitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
std::string AsTVMVisitor::visit_(const Constant &c) {
|
||||||
|
return std::to_string(c->getValue());
|
||||||
|
}
|
||||||
|
std::string AsTVMVisitor::visit_(const BinaryOp &c) {
|
||||||
|
switch (c->getOpType()) {
|
||||||
|
case OpType::Add:
|
||||||
|
return "(" + dispatch(c->getLhs()) + " + " + dispatch(c->getRhs()) +
|
||||||
|
")";
|
||||||
|
case OpType::Sub:
|
||||||
|
return "(" + dispatch(c->getLhs()) + " - " + dispatch(c->getRhs()) +
|
||||||
|
")";
|
||||||
|
case OpType::Mul:
|
||||||
|
return "(" + dispatch(c->getLhs()) + " * " + dispatch(c->getRhs()) +
|
||||||
|
")";
|
||||||
|
case OpType::Div:
|
||||||
|
return "(" + dispatch(c->getLhs()) + " // " + dispatch(c->getRhs()) +
|
||||||
|
")";
|
||||||
|
case OpType::Mod:
|
||||||
|
return "(" + dispatch(c->getLhs()) + " % " + dispatch(c->getRhs()) +
|
||||||
|
")";
|
||||||
|
default:
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::string AsTVMVisitor::visit_(const Func &c) {
|
||||||
|
switch (c->getFuncType()) {
|
||||||
|
case FuncType::Relu:
|
||||||
|
// TODO: Deduce the dtype
|
||||||
|
return "te.max(" + dispatch(c->getObject()) +
|
||||||
|
", tvm.tir.const(0, 'float32'))";
|
||||||
|
case FuncType::Tanh:
|
||||||
|
return "te.tanh(" + dispatch(c->getObject()) + ")";
|
||||||
|
default:
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::string AsTVMVisitor::visit_(const RangeOp &c) {
|
||||||
|
auto outerStage = curStage;
|
||||||
|
curStage = nStage++;
|
||||||
|
|
||||||
|
std::string stmt;
|
||||||
|
std::string stageName = "s" + std::to_string(curStage);
|
||||||
|
std::vector<std::string> reduceVars;
|
||||||
|
for (auto &&[var, range] : c->getSumVarRanges()) {
|
||||||
|
std::string varName = stageName + "_" + var->getName();
|
||||||
|
stmt += varName + " = " + "te.reduce_axis((" +
|
||||||
|
std::to_string(range.first) + ", " +
|
||||||
|
std::to_string(range.second) + "), name=\"" + varName + "\")\n";
|
||||||
|
reduceVars.emplace_back(varName);
|
||||||
|
pythonVars.emplace_back(varName);
|
||||||
|
}
|
||||||
|
std::vector<int> shape;
|
||||||
|
stmt += stageName + " = te.compute((";
|
||||||
|
for (size_t i = 0, n = c->getLoopVarRanges().size(); i < n; i++) {
|
||||||
|
auto &&[var, range] = c->getLoopVarRanges()[i];
|
||||||
|
std::string varName = stageName + "_" + var->getName();
|
||||||
|
offset[varName] = -range.first + c->getPaddings(i);
|
||||||
|
auto len = range.second - range.first + 2 * c->getPaddings(i);
|
||||||
|
stmt += std::to_string(len) + ", ";
|
||||||
|
shape.emplace_back(len);
|
||||||
|
}
|
||||||
|
stmt += "), lambda ";
|
||||||
|
bool first = true;
|
||||||
|
for (auto &&[var, range] : c->getLoopVarRanges()) {
|
||||||
|
std::string varName = stageName + "_" + var->getName();
|
||||||
|
stmt += (first ? "" : ", ") + varName;
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
std::string summand = dispatch(c->getSummand());
|
||||||
|
if (!reduceVars.empty()) {
|
||||||
|
summand = "te.sum(" + summand + ", axis=(";
|
||||||
|
for (auto &&var : reduceVars) {
|
||||||
|
summand += var + ", ";
|
||||||
|
}
|
||||||
|
summand += "))";
|
||||||
|
}
|
||||||
|
if (c->hasPaddings()) {
|
||||||
|
std::string guard = "tir.if_then_else(tir.all(";
|
||||||
|
bool first = true;
|
||||||
|
for (size_t i = 0, n = c->getLoopVarRanges().size(); i < n; i++) {
|
||||||
|
auto &&[var, range] = c->getLoopVarRanges()[i];
|
||||||
|
std::string varName = stageName + "_" + var->getName();
|
||||||
|
if (auto pad = c->getPaddings(i); pad > 0) {
|
||||||
|
guard += (first ? "" : ", ") + varName +
|
||||||
|
" >= " + std::to_string(range.first) + ", " + varName +
|
||||||
|
" < " + std::to_string(range.second);
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO: Deduce the dtype
|
||||||
|
guard += "), " + summand + ", tvm.tir.const(0.0, \"float32\"))";
|
||||||
|
summand = guard;
|
||||||
|
}
|
||||||
|
stmt += ": " + summand + ")";
|
||||||
|
stmts += stmt + "\n";
|
||||||
|
|
||||||
|
pythonVars.emplace_back(stageName);
|
||||||
|
output = stageName;
|
||||||
|
outputShape = std::move(shape);
|
||||||
|
curStage = outerStage;
|
||||||
|
return stageName;
|
||||||
|
}
|
||||||
|
std::string AsTVMVisitor::visit_(const Subscript &c) {
|
||||||
|
std::string str = dispatch(c->getObject()) + "[";
|
||||||
|
for (size_t i = 0, n = c->getIndex().size(); i < n; i++) {
|
||||||
|
const auto &idx = c->getIndex()[i];
|
||||||
|
str += (i == 0 ? "" : ", ") + dispatch(idx);
|
||||||
|
if (c->getObject()->getType() == NodeType::RangeOpNodeType) {
|
||||||
|
auto rangeOp = as<RangeOpNode>(c->getObject());
|
||||||
|
str += " - " +
|
||||||
|
std::to_string(rangeOp->getLoopVarRanges()[i].second.first -
|
||||||
|
rangeOp->getPaddings(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
str += "]";
|
||||||
|
return str;
|
||||||
|
}
|
||||||
|
std::string AsTVMVisitor::visit_(const Var &c) {
|
||||||
|
std::string stageName = "s" + std::to_string(curStage);
|
||||||
|
std::string varName = stageName + "_" + c->getName();
|
||||||
|
if (offset.count(varName)) {
|
||||||
|
return "(" + varName + " - " + std::to_string(offset.at(varName)) + ")";
|
||||||
|
} else {
|
||||||
|
return varName;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::string AsTVMVisitor::visit_(const Tensor &c) {
|
||||||
|
pythonVars.emplace_back(c->getName());
|
||||||
|
inputs.emplace_back(c->getName());
|
||||||
|
inputShapes.emplace_back(c->getShape());
|
||||||
|
std::string stmt = c->getName() + " = te.placeholder((";
|
||||||
|
for (auto &&dim : c->getShape()) {
|
||||||
|
stmt += std::to_string(dim) + ", ";
|
||||||
|
}
|
||||||
|
stmt += "), name='" + c->getName() + "')";
|
||||||
|
stmts += stmt + "\n";
|
||||||
|
return c->getName();
|
||||||
|
}
|
||||||
|
std::string AsTVMVisitor::getStmts() const {
|
||||||
|
std::string ret;
|
||||||
|
|
||||||
|
// Workaround because closure capturing does not work in an `exec`
|
||||||
|
// https://stackoverflow.com/questions/2749655/why-are-closures-broken-within-exec
|
||||||
|
ret += "global ";
|
||||||
|
bool first = true;
|
||||||
|
for (auto &&var : pythonVars) {
|
||||||
|
ret += (first ? "" : ", ") + var;
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
ret += "\n";
|
||||||
|
|
||||||
|
ret += stmts;
|
||||||
|
ret += "ret = [" + output;
|
||||||
|
for (auto &&input : inputs) {
|
||||||
|
ret += ", " + input;
|
||||||
|
}
|
||||||
|
ret += "]\n";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,35 @@
|
||||||
|
#include "nnet/Visitor/CheckOOBVisitor.h"
|
||||||
|
#include "nnet/Visitor/SimplifyExprVisitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void CheckOOBVisitor::visit_(const Subscript &c) {
|
||||||
|
const auto &objectRanges = c->getObjectRangesWithPaddings();
|
||||||
|
for (size_t dim = 0; dim < c->getDims(); ++dim) {
|
||||||
|
SimplifyExprVisitor simplifier;
|
||||||
|
auto optional = simplifier.getExprRange(c->getIndex(dim), rangeOp);
|
||||||
|
if (!optional.has_value())
|
||||||
|
continue;
|
||||||
|
const Range &exprRange = *optional;
|
||||||
|
if (exprRange.first < objectRanges[dim].first ||
|
||||||
|
exprRange.second > objectRanges[dim].second) {
|
||||||
|
// dbg("OOB detected!", c, dim, exprRange, objectRanges[dim]);
|
||||||
|
// std::cout << "OOB detected! " << c->toReadable() << ", dim=" <<
|
||||||
|
// dim
|
||||||
|
// << ", Range=(" << exprRange.first << ", "
|
||||||
|
// << exprRange.second << "), objRange=("
|
||||||
|
// << objectRanges[dim].first << ", "
|
||||||
|
// << objectRanges[dim].second << ")." << std::endl;
|
||||||
|
detect = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CheckOOBVisitor::checkRangeOp(const RangeOp &_rangeOp) {
|
||||||
|
detect = false;
|
||||||
|
rangeOp = _rangeOp;
|
||||||
|
dispatch(rangeOp);
|
||||||
|
return detect;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,9 @@
|
||||||
|
#include "nnet/Visitor/CloneMutator.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
Expr CloneMutator::visit_(const Constant &c) { return c; }
|
||||||
|
Expr CloneMutator::visit_(const Var &c) { return c; }
|
||||||
|
Expr CloneMutator::visit_(const Tensor &c) { return c; }
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,34 @@
|
||||||
|
#include "nnet/Visitor/CompareMultiFormulasVisitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
bool CompareMultiFormulasVisitor::compare(const VecExpr &roots) {
|
||||||
|
if (roots.empty())
|
||||||
|
return false;
|
||||||
|
vector<RangeOp> rangeOps;
|
||||||
|
for (const auto &root : roots) {
|
||||||
|
if (auto rangeOp = as<RangeOpNode>(root))
|
||||||
|
rangeOps.emplace_back(rangeOp);
|
||||||
|
else
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const auto pattern = rangeOps[0];
|
||||||
|
for (auto rangeOp : rangeOps) {
|
||||||
|
if (pattern->getNumOutputDims() != rangeOp->getNumOutputDims()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < pattern->getNumOutputDims(); ++i)
|
||||||
|
if (pattern->getVarRange(0, i).second !=
|
||||||
|
rangeOp->getVarRange(0, i).second) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < pattern->getSumVarRanges().size(); ++i)
|
||||||
|
if (pattern->getVarRange(1, i).second !=
|
||||||
|
rangeOp->getVarRange(1, i).second) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,38 @@
|
||||||
|
#include "nnet/Visitor/CountRoutineVisitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void CountRoutineVisitor::visit_(const Tensor &c) {
|
||||||
|
if (auto routine = c->getSource(); routine) {
|
||||||
|
cnts[routineTypeToId(routine->getType())]++;
|
||||||
|
}
|
||||||
|
ExprTreeVisitor::visit_(c);
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> CountRoutineVisitor::count(const Expr &root) {
|
||||||
|
cnts = vector<int>(RoutineTypeCnt, 0);
|
||||||
|
dispatch(root);
|
||||||
|
return cnts;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CountRoutineVisitor::match(const Expr &root, int nMatmul, int nConv,
|
||||||
|
int nElement, int nSg2bmm,
|
||||||
|
int nLongformerGBMM) {
|
||||||
|
auto opCount = count(root);
|
||||||
|
bool ret = true;
|
||||||
|
if (opCount[routineTypeToId(RoutineType::MatmulNodeType)] != nMatmul)
|
||||||
|
ret = false;
|
||||||
|
if (opCount[routineTypeToId(RoutineType::ConvNodeType)] != nConv)
|
||||||
|
ret = false;
|
||||||
|
if (opCount[routineTypeToId(RoutineType::ElementWiseNodeType)] != nElement)
|
||||||
|
ret = false;
|
||||||
|
if (opCount.at(routineTypeToId(RoutineType::G2bmmNodeType)) != nSg2bmm)
|
||||||
|
ret = false;
|
||||||
|
if (!ret) {
|
||||||
|
auto target =
|
||||||
|
vector<int>{nMatmul, nConv, nSg2bmm, nLongformerGBMM, nElement};
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,58 @@
|
||||||
|
#include "nnet/Visitor/FullPrinterVisitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void FullPrinterVisitor::visit_(const Tensor &c) {
|
||||||
|
q.emplace_back(c->getName(), c->getSource(), c);
|
||||||
|
}
|
||||||
|
|
||||||
|
string FullPrinterVisitor::print(const Expr &root) {
|
||||||
|
q.clear();
|
||||||
|
std::ostringstream oss;
|
||||||
|
dispatch(root);
|
||||||
|
oss << "==> ROOT\n" << root->toReadable() << "\n";
|
||||||
|
for (size_t i = 0; i < q.size(); ++i) {
|
||||||
|
const auto &[name, routine, tensor] = q[i];
|
||||||
|
oss << "==> " << name << " : ";
|
||||||
|
if (routine) {
|
||||||
|
oss << routine->toReadable() << "\n";
|
||||||
|
if (routine->getExpr()) {
|
||||||
|
oss << routine->getExpr()->toReadable() << "\n";
|
||||||
|
} else
|
||||||
|
oss << "[INFO] Source is nullptr \n";
|
||||||
|
if (!routine->getInputs().empty()) {
|
||||||
|
for (const auto &tensor : routine->getInputs())
|
||||||
|
q.emplace_back(tensor->getName(), tensor->getSource(),
|
||||||
|
tensor);
|
||||||
|
} else if (routine->getExpr())
|
||||||
|
dispatch(routine->getExpr());
|
||||||
|
} else
|
||||||
|
oss << "Input Tensor " << tensor->toOutputShape() << "\n";
|
||||||
|
}
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
const vector<tuple<string, Routine, Tensor>> &
|
||||||
|
FullPrinterVisitor::traverse(const Expr &root) {
|
||||||
|
q.clear();
|
||||||
|
dispatch(root);
|
||||||
|
for (size_t i = 0; i < q.size(); ++i) {
|
||||||
|
const auto &[name, routine, tensor] = q[i];
|
||||||
|
if (routine) {
|
||||||
|
// Matmul after DLT do not modify expression, so inputs has a higher
|
||||||
|
// priority. Some OPs such as DLT have not implement source. Then
|
||||||
|
// use inputs
|
||||||
|
if (!routine->getInputs().empty()) {
|
||||||
|
for (const auto &tensor : routine->getInputs())
|
||||||
|
dispatch(tensor);
|
||||||
|
} else if (routine->getExpr()) {
|
||||||
|
dispatch(routine->getExpr());
|
||||||
|
} else {
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return q;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
|
@ -0,0 +1,9 @@
|
||||||
|
#include "nnet/Visitor/GetTensorsVisitor.h"
|
||||||
|
|
||||||
|
namespace nnet {
|
||||||
|
|
||||||
|
void GetTensorsVisitor::visit_(const Tensor &c) {
|
||||||
|
tensors.try_emplace(c->getName(), c);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace nnet
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue