Merge pull request #1 from InfiniTensor/init

Initialization
This commit is contained in:
Haojie Wang 2022-08-09 20:21:20 +08:00 committed by GitHub
commit eda41b06a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
270 changed files with 71286 additions and 0 deletions

90
.clang-format Normal file
View File

@ -0,0 +1,90 @@
---
Language: Cpp
# BasedOnStyle: LLVM
AccessModifierOffset: -2
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: false
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: false
BinPackArguments: true
BinPackParameters: true
BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: false
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
IncludeCategories:
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
Priority: 2
- Regex: '^(<|"(gtest|isl|json)/)'
Priority: 3
- Regex: '.*'
Priority: 1
IndentCaseLabels: false
IndentWidth: 4
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: true
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakBeforeFirstCallParameter: 19
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 60
PointerAlignment: Right
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...

311
.cmake-format.json Normal file
View File

@ -0,0 +1,311 @@
{
"_help_parse": "Options affecting listfile parsing",
"parse": {
"_help_additional_commands": [
"Specify structure for custom cmake functions"
],
"additional_commands": {
"foo": {
"flags": [
"BAR",
"BAZ"
],
"kwargs": {
"HEADERS": "*",
"SOURCES": "*",
"DEPENDS": "*"
}
}
},
"_help_override_spec": [
"Override configurations per-command where available"
],
"override_spec": {},
"_help_vartags": [
"Specify variable tags."
],
"vartags": [],
"_help_proptags": [
"Specify property tags."
],
"proptags": []
},
"_help_format": "Options affecting formatting.",
"format": {
"_help_disable": [
"Disable formatting entirely, making cmake-format a no-op"
],
"disable": false,
"_help_line_width": [
"How wide to allow formatted cmake files"
],
"line_width": 1500,
"_help_tab_size": [
"How many spaces to tab for indent"
],
"tab_size": 2,
"_help_use_tabchars": [
"If true, lines are indented using tab characters (utf-8",
"0x09) instead of <tab_size> space characters (utf-8 0x20).",
"In cases where the layout would require a fractional tab",
"character, the behavior of the fractional indentation is",
"governed by <fractional_tab_policy>"
],
"use_tabchars": false,
"_help_fractional_tab_policy": [
"If <use_tabchars> is True, then the value of this variable",
"indicates how fractional indentions are handled during",
"whitespace replacement. If set to 'use-space', fractional",
"indentation is left as spaces (utf-8 0x20). If set to",
"`round-up` fractional indentation is replaced with a single",
"tab character (utf-8 0x09) effectively shifting the column",
"to the next tabstop"
],
"fractional_tab_policy": "use-space",
"_help_max_subgroups_hwrap": [
"If an argument group contains more than this many sub-groups",
"(parg or kwarg groups) then force it to a vertical layout."
],
"max_subgroups_hwrap": 2,
"_help_max_pargs_hwrap": [
"If a positional argument group contains more than this many",
"arguments, then force it to a vertical layout."
],
"max_pargs_hwrap": 6,
"_help_max_rows_cmdline": [
"If a cmdline positional group consumes more than this many",
"lines without nesting, then invalidate the layout (and nest)"
],
"max_rows_cmdline": 2,
"_help_separate_ctrl_name_with_space": [
"If true, separate flow control names from their parentheses",
"with a space"
],
"separate_ctrl_name_with_space": false,
"_help_separate_fn_name_with_space": [
"If true, separate function names from parentheses with a",
"space"
],
"separate_fn_name_with_space": false,
"_help_dangle_parens": [
"If a statement is wrapped to more than one line, than dangle",
"the closing parenthesis on its own line."
],
"dangle_parens": false,
"_help_dangle_align": [
"If the trailing parenthesis must be 'dangled' on its on",
"line, then align it to this reference: `prefix`: the start",
"of the statement, `prefix-indent`: the start of the",
"statement, plus one indentation level, `child`: align to",
"the column of the arguments"
],
"dangle_align": "prefix",
"_help_min_prefix_chars": [
"If the statement spelling length (including space and",
"parenthesis) is smaller than this amount, then force reject",
"nested layouts."
],
"min_prefix_chars": 4,
"_help_max_prefix_chars": [
"If the statement spelling length (including space and",
"parenthesis) is larger than the tab width by more than this",
"amount, then force reject un-nested layouts."
],
"max_prefix_chars": 10,
"_help_max_lines_hwrap": [
"If a candidate layout is wrapped horizontally but it exceeds",
"this many lines, then reject the layout."
],
"max_lines_hwrap": 2,
"_help_line_ending": [
"What style line endings to use in the output."
],
"line_ending": "unix",
"_help_command_case": [
"Format command names consistently as 'lower' or 'upper' case"
],
"command_case": "canonical",
"_help_keyword_case": [
"Format keywords consistently as 'lower' or 'upper' case"
],
"keyword_case": "unchanged",
"_help_always_wrap": [
"A list of command names which should always be wrapped"
],
"always_wrap": [],
"_help_enable_sort": [
"If true, the argument lists which are known to be sortable",
"will be sorted lexicographicall"
],
"enable_sort": true,
"_help_autosort": [
"If true, the parsers may infer whether or not an argument",
"list is sortable (without annotation)."
],
"autosort": false,
"_help_require_valid_layout": [
"By default, if cmake-format cannot successfully fit",
"everything into the desired linewidth it will apply the",
"last, most aggressive attempt that it made. If this flag is",
"True, however, cmake-format will print error, exit with non-",
"zero status code, and write-out nothing"
],
"require_valid_layout": false,
"_help_layout_passes": [
"A dictionary mapping layout nodes to a list of wrap",
"decisions. See the documentation for more information."
],
"layout_passes": {}
},
"_help_markup": "Options affecting comment reflow and formatting.",
"markup": {
"_help_bullet_char": [
"What character to use for bulleted lists"
],
"bullet_char": "*",
"_help_enum_char": [
"What character to use as punctuation after numerals in an",
"enumerated list"
],
"enum_char": ".",
"_help_first_comment_is_literal": [
"If comment markup is enabled, don't reflow the first comment",
"block in each listfile. Use this to preserve formatting of",
"your copyright/license statements."
],
"first_comment_is_literal": false,
"_help_literal_comment_pattern": [
"If comment markup is enabled, don't reflow any comment block",
"which matches this (regex) pattern. Default is `None`",
"(disabled)."
],
"literal_comment_pattern": ".*",
"_help_fence_pattern": [
"Regular expression to match preformat fences in comments",
"default= ``r'^\\s*([`~]{3}[`~]*)(.*)$'``"
],
"fence_pattern": "^\\s*([`~]{3}[`~]*)(.*)$",
"_help_ruler_pattern": [
"Regular expression to match rulers in comments default=",
"``r'^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$'``"
],
"ruler_pattern": "^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$",
"_help_explicit_trailing_pattern": [
"If a comment line matches starts with this pattern then it",
"is explicitly a trailing comment for the preceding argument.",
"Default is '#<'"
],
"explicit_trailing_pattern": "#<",
"_help_hashruler_min_length": [
"If a comment line starts with at least this many consecutive",
"hash characters, then don't lstrip() them off. This allows",
"for lazy hash rulers where the first hash char is not",
"separated by space"
],
"hashruler_min_length": 10,
"_help_canonicalize_hashrulers": [
"If true, then insert a space between the first hash char and",
"remaining hash chars in a hash ruler, and normalize its",
"length to fill the column"
],
"canonicalize_hashrulers": true,
"_help_enable_markup": [
"enable comment markup parsing and reflow"
],
"enable_markup": true
},
"_help_lint": "Options affecting the linter",
"lint": {
"_help_disabled_codes": [
"a list of lint codes to disable"
],
"disabled_codes": [],
"_help_function_pattern": [
"regular expression pattern describing valid function names"
],
"function_pattern": "[0-9a-z_]+",
"_help_macro_pattern": [
"regular expression pattern describing valid macro names"
],
"macro_pattern": "[0-9A-Z_]+",
"_help_global_var_pattern": [
"regular expression pattern describing valid names for",
"variables with global (cache) scope"
],
"global_var_pattern": "[A-Z][0-9A-Z_]+",
"_help_internal_var_pattern": [
"regular expression pattern describing valid names for",
"variables with global scope (but internal semantic)"
],
"internal_var_pattern": "_[A-Z][0-9A-Z_]+",
"_help_local_var_pattern": [
"regular expression pattern describing valid names for",
"variables with local scope"
],
"local_var_pattern": "[a-z][a-z0-9_]+",
"_help_private_var_pattern": [
"regular expression pattern describing valid names for",
"privatedirectory variables"
],
"private_var_pattern": "_[0-9a-z_]+",
"_help_public_var_pattern": [
"regular expression pattern describing valid names for public",
"directory variables"
],
"public_var_pattern": "[A-Z][0-9A-Z_]+",
"_help_argument_var_pattern": [
"regular expression pattern describing valid names for",
"function/macro arguments and loop variables."
],
"argument_var_pattern": "[a-z][a-z0-9_]+",
"_help_keyword_pattern": [
"regular expression pattern describing valid names for",
"keywords used in functions or macros"
],
"keyword_pattern": "[A-Z][0-9A-Z_]+",
"_help_max_conditionals_custom_parser": [
"In the heuristic for C0201, how many conditionals to match",
"within a loop in before considering the loop a parser."
],
"max_conditionals_custom_parser": 2,
"_help_min_statement_spacing": [
"Require at least this many newlines between statements"
],
"min_statement_spacing": 1,
"_help_max_statement_spacing": [
"Require no more than this many newlines between statements"
],
"max_statement_spacing": 2,
"max_returns": 6,
"max_branches": 12,
"max_arguments": 5,
"max_localvars": 15,
"max_statements": 50
},
"_help_encode": "Options affecting file encoding",
"encode": {
"_help_emit_byteorder_mark": [
"If true, emit the unicode byte-order mark (BOM) at the start",
"of the file"
],
"emit_byteorder_mark": false,
"_help_input_encoding": [
"Specify the encoding of the input file. Defaults to utf-8"
],
"input_encoding": "utf-8",
"_help_output_encoding": [
"Specify the encoding of the output file. Defaults to utf-8.",
"Note that cmake only claims to support utf-8 so be careful",
"when using anything else"
],
"output_encoding": "utf-8"
},
"_help_misc": "Miscellaneous configurations options.",
"misc": {
"_help_per_command": [
"A dictionary containing any per-command configuration",
"overrides. Currently only `command_case` is supported."
],
"per_command": {}
}
}

View File

@ -0,0 +1,19 @@
name: clang-format Check
on: [pull_request]
jobs:
formatting-check:
name: Formatting Check
runs-on: ubuntu-latest
strategy:
matrix:
path:
- 'include'
- 'src'
- 'test'
steps:
- uses: actions/checkout@v2
- name: Run clang-format style check for C/C++/Protobuf programs.
uses: jidicula/clang-format-action@v4.8.0
with:
clang-format-version: '14'
check-path: ${{ matrix.path }}

5
.gitignore vendored
View File

@ -30,3 +30,8 @@
*.exe
*.out
*.app
build/
build_debug/
.vscode/

9
.gitmodules vendored Normal file
View File

@ -0,0 +1,9 @@
[submodule "3rd-party/pybind11"]
path = 3rd-party/pybind11
url = git@github.com:pybind/pybind11.git
[submodule "3rd-party/nlohmann_json_cmake_fetchcontent"]
path = 3rd-party/nlohmann_json_cmake_fetchcontent
url = git@github.com:ArthurSonzogni/nlohmann_json_cmake_fetchcontent.git
[submodule "3rd-party/googletest"]
path = 3rd-party/googletest
url = git@github.com:google/googletest.git

1
3rd-party/googletest vendored Submodule

@ -0,0 +1 @@
Subproject commit e2239ee6043f73722e7aa812a459f54a28552929

@ -0,0 +1 @@
Subproject commit 6aebf09233951e4ce30a63919186a70b2b195756

1
3rd-party/pybind11 vendored Submodule

@ -0,0 +1 @@
Subproject commit 1e3400b6742288429f2069aaf5febf92d0662dae

91
CMakeLists.txt Normal file
View File

@ -0,0 +1,91 @@
# TODO: check the minimum cmake version
cmake_minimum_required(VERSION 3.9) # Required by find_package(OpenMP)
include(CMakeDependentOption)
project(InfiniTensor C CXX)
# Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them.
option(BUILD_TEST "Build tests" ON)
cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF)
cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST OFF)
set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wno-error=deprecated-declarations")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
find_package(
Python
COMPONENTS Interpreter Development
REQUIRED)
find_package(CUDA REQUIRED)
# OpenMP
find_package(OpenMP)
if(OpenMP_C_FOUND)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
endif()
if(OpenMP_CXX_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
endif()
include_directories(include)
# # Pybind11
# add_subdirectory(3rd-party/pybind11)
# include_directories(3rd-party/pybind11/include)
# nlohmann_json
add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent)
include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include)
if(BUILD_TEST)
set(BUILD_GMOCK
OFF
CACHE BOOL "Do not build gmock" FORCE)
set(INSTALL_GTEST
OFF
CACHE BOOL "Do not install gtest" FORCE)
add_subdirectory(3rd-party/googletest)
include_directories(3rd-party/googletest/googletest/include)
endif()
file(GLOB_RECURSE SRC src/*.cc src/*.cu)
# file(GLOB_RECURSE FFI src/ffi/ffi_pet.cc)
# list(REMOVE_ITEM SRC ${TEST} ${FFI})
add_library(InfiniTensor SHARED ${SRC})
# Target
# cuda_add_library(it SHARED ${SRC})
# cuda_add_cublas_to_target(it) # cublas
# # target_link_libraries(infini_cpp cudnn curand nlohmann_json::nlohmann_json pybind11::embed)
# # Python bindings
# pybind11_add_module(infini MODULE ${FFI})
# target_link_libraries(infini PRIVATE infini_cpp)
function(build_test files)
# Non-recursive glob for skip failed tests
file(GLOB TEST_SOURCES ${files})
foreach(testsourcefile ${TEST_SOURCES})
get_filename_component(testname ${testsourcefile} NAME_WE)
add_executable(${testname} ${testsourcefile})
target_link_libraries(${testname} InfiniTensor GTest::gtest_main)
add_test(NAME ${testname} COMMAND ${testname})
endforeach(testsourcefile ${TEST_SOURCES})
endfunction()
if(BUILD_TEST)
enable_testing()
if(BUILD_TEST_CORE)
build_test(test/core/*.cc)
endif()
if(BUILD_TEST_PET)
build_test(test/pet/*.cc)
endif()
if(BUILD_TEST_EINNET)
build_test(test/nnet/*.cc)
endif()
endif()

61
include/core/common.h Normal file
View File

@ -0,0 +1,61 @@
#pragma once
#include <cassert>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace infini {
using std::list;
using std::map;
using std::pair;
using std::set;
using std::string;
using std::tie;
using std::to_string;
using std::tuple;
using std::unordered_map;
using std::vector;
// Aliases
using dtype = float;
using HashType = size_t; // compatible with std::hash
// Metaprogramming utilities
#define _CAT(A, B) A##B
#define _SELECT(NAME, NUM) _CAT(NAME##_, NUM)
#define _GET_COUNT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, COUNT, ...) COUNT
#define _VA_SIZE(...) _GET_COUNT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1)
#define _VA_SELECT(NAME, ...) _SELECT(NAME, _VA_SIZE(__VA_ARGS__))(__VA_ARGS__)
// Assert: conditions should have no side effect
#define _IT_ASSERT_2(name, info) \
(static_cast<bool>(name) \
? void(0) \
: throw std::runtime_error( \
std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \
"] Assertion failed (" + #name + "): " + #info))
#define _IT_ASSERT_1(name) _IT_ASSERT_2(name, "");
#define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__)
#define IT_TODO_HALT() IT_ASSERT(false, "Unimplemented")
#define IT_TODO_SKIP() puts("Unimplemented " __FILE__ ":" __LINE__)
// Other utilities
// std::to_underlying is avaiable since C++23
template <typename T> auto enum_to_underlying(T e) {
return static_cast<std::underlying_type_t<T>>(e);
}
double timeit(const std::function<void()> &func);
} // namespace infini

47
include/core/graph.h Normal file
View File

@ -0,0 +1,47 @@
#pragma once
#include "core/operator.h"
#include "core/tensor.h"
namespace infini {
// TODO: graph should be attached to a context
class GraphNode : public Object {
protected:
TensorVec tensors;
TensorVec inputs;
TensorVec outputs;
OpVec ops;
public:
// Graph(OpVec oplist);
string toString() const override;
void addOp(Operator op) { ops.push_back(op); };
const TensorVec &getTensors() const { return tensors; }
const TensorVec &getInputs() const { return inputs; }
const TensorVec &getOutputs() const { return outputs; }
const OpVec &getOperators() const { return ops; }
// TensorVec &getInputs();
// TensorVec &getOutputs();
Tensor addTensor(Shape dim, DataType dtype = DataType::Int32) {
Tensor tensor = make_ref<TensorNode>(dim, dtype);
tensors.emplace_back(tensor);
return tensor;
}
void dataMalloc();
private:
// TODO: updateConnection
/**
* @brief Add reverse connections and Op relationship in ctor.
*/
void updateConnection();
// TODO: move to another class
// bool exportOnnx(const char *path);
// bool importOnnx(const char *net);
};
} // namespace infini

76
include/core/kernel.h Normal file
View File

@ -0,0 +1,76 @@
#pragma once
#include "core/common.h"
#include "core/operator.h"
#include "core/tensor.h"
namespace infini {
struct PerfRecord {
double time; // in milliseconds
};
class Kernel {
public:
Kernel() {}
virtual ~Kernel() {}
/**
* @param op The operator to be executed.
* @param record The parameters for kernel execution. If extra parameters
* are required, inherit from PerfRecord and add extra parameters.
* Otherwire, use PerfRecord directly.
*/
virtual void compute(const Operator &op,
const PerfRecord &record) const = 0;
/**
* @brief Executes an op with a default parameter.
*/
virtual void compute(const Operator &op) const = 0;
// Premise: op is idempotent since it is called multiple times.
virtual PerfRecord tune(const Operator &op) const = 0;
};
class KernelRegistry {
public:
using KernelRecord =
tuple<Kernel *const, const string, const int>; // Kernel, name, ID
private:
std::map<KernelAttrs, KernelRecord> kernels;
int nKernels = 0;
public:
~KernelRegistry() {
for (auto &[k, v] : kernels)
delete std::get<0>(v);
}
static KernelRegistry &getInstance() {
static KernelRegistry instance;
return instance;
}
bool registerKernel(const KernelAttrs &key, Kernel *kernel, string name) {
// TODO: mutliple kernels support: priority and check name
IT_ASSERT(kernels.find(key) == kernels.end(),
"Kernel already registered");
kernels.emplace(key, KernelRecord{kernel, name, ++nKernels});
return true;
}
Kernel *getKernel(const KernelAttrs &kernelAttrs) const {
return std::get<0>(kernels.at(kernelAttrs));
}
const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {
return kernels.at(kernelAttrs);
}
};
} // namespace infini
#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, cnt) \
namespace infini { \
static const bool _CAT(_register_kernel_, cnt) = \
KernelRegistry::getInstance().registerKernel( \
KernelAttrs{device, opType, dataType}, new kernel(), name); \
}
#define REGISTER_KERNEL(device, opType, dataType, kernel, name) \
_REGISTER_KERNEL_1(device, opType, dataType, kernel, name, __COUNTER__)

19
include/core/mutator.h Normal file
View File

@ -0,0 +1,19 @@
#pragma once
#include "core/graph.h"
namespace infini {
class Mutator {
private:
int candidatesLimit;
// // Statistical data
// int numTotalCandidates;
public:
Mutator(int candidatesLimit) : candidatesLimit(candidatesLimit){};
virtual ~Mutator(){};
virtual vector<Graph> run(const Graph &in_graph) = 0;
};
} // namespace infini

54
include/core/object.h Normal file
View File

@ -0,0 +1,54 @@
#pragma once
#include "core/common.h"
#include "ref.h"
namespace infini {
using GuidBaseType = int;
class Guid {
private:
GuidBaseType guid;
private:
GuidBaseType generateGuid() {
static GuidBaseType guidCnt = 0;
return ++guidCnt;
}
public:
Guid() { guid = generateGuid(); }
Guid(const Guid &rhs) { guid = generateGuid(); }
Guid &operator=(const Guid &rhs) {
guid = generateGuid();
return *this;
}
operator GuidBaseType() const { return guid; }
};
class Object {
protected:
Guid guid;
public:
virtual ~Object(){};
virtual string toString() const = 0;
void print() { std::cout << toString() << std::endl; }
Guid getGuid() const { return guid; }
};
inline std::ostream &operator<<(std::ostream &os, const Object &obj) {
os << obj.toString();
return os;
}
// Overload for Ref-wrapped Object
template <typename T,
typename std::enable_if_t<std::is_base_of_v<Object, T>> * = nullptr>
inline std::ostream &operator<<(std::ostream &os, const Ref<T> &obj) {
os << obj->toString();
return os;
}
} // namespace infini

180
include/core/operator.h Normal file
View File

@ -0,0 +1,180 @@
#pragma once
#include "core/tensor.h"
namespace infini {
enum class OpType {
Unknown = 0,
// linear
Conv = 100,
Matmul,
ConvTrans,
G2BMM,
GBMML,
Pad,
Slice,
Concat,
Split,
Transpose,
Extend,
MaxPool,
AvgPool,
Add,
Sub,
Mul,
Div,
Pow,
Gather,
ReduceMean,
Reshape,
Identity,
// element wise
BatchNorm = 200,
Softmax,
Activation,
Resize,
//
MemBound = 300,
};
enum class Device { CPU = 1, CUDA };
using KernelAttrs = std::tuple<Device, OpType, DataType>;
class OpRegistry {
public:
static std::string getOpName(OpType opType) {
#define FOP(op) \
case OpType::op: \
return #op
switch (opType) {
FOP(Unknown);
// linear
FOP(Conv);
FOP(Matmul);
FOP(ConvTrans);
FOP(G2BMM);
FOP(GBMML);
FOP(Pad);
FOP(Slice);
FOP(Concat);
FOP(Split);
FOP(Transpose);
FOP(Extend);
FOP(MaxPool);
FOP(AvgPool);
FOP(Add);
FOP(Sub);
FOP(Mul);
FOP(Div);
FOP(Pow);
FOP(Gather);
FOP(ReduceMean);
FOP(Reshape);
FOP(Identity);
// element wise
FOP(BatchNorm);
FOP(Softmax);
FOP(Activation);
//
FOP(MemBound);
default:
IT_ASSERT(false);
break;
}
#undef FOP
}
};
enum class ActType {
None,
Relu,
Sigmoid,
Tanh,
};
struct OpPerfKey {
HashType hash;
OpType opType;
vector<int> attrs;
public:
OpPerfKey(HashType hash, OpType opType, vector<int> attrs = {})
: hash(hash), opType(opType), attrs(attrs) {}
bool operator==(const OpPerfKey &rhs) const {
if (hash != rhs.hash)
return false;
if (opType != rhs.opType)
return false;
if (attrs != rhs.attrs)
return false;
return true;
}
// TODO: remove this function after we use unordered_map in PerfEngine
bool operator<(const OpPerfKey &rhs) const {
if (hash != rhs.hash)
return hash < rhs.hash;
if (opType != rhs.opType)
return opType < rhs.opType;
if (attrs.size() != rhs.attrs.size())
return attrs.size() < rhs.attrs.size();
for (size_t i = 0; i < attrs.size(); ++i)
if (attrs[i] != rhs.attrs[i])
return attrs[i] < rhs.attrs[i];
return false;
}
};
class OperatorNode : public Object {
friend class Kernel;
protected:
OpType type;
TensorVec inputs;
TensorVec outputs;
// vector<WRef<Operator>> predecessors;
// vector<WRef<Operator>> successors;
public:
OperatorNode(OpType opType, TensorVec inputs, TensorVec outputs)
: type(opType), inputs(inputs), outputs(outputs) {}
virtual vector<Shape> computeShape() const = 0;
virtual OpPerfKey getOpPerfKey() const = 0;
public: // check Op type
bool isLinearOp() const;
bool isElementWiseOp() const;
bool isSplitOp() const;
bool isConcatOp() const;
bool isComputeOp() const;
bool isTransposeOp() const;
bool isReshapeOp() const;
bool isMemBoundOp() const;
public: // getter and setter
// TensorVec getInputs() { return inputs; }
const TensorVec &getInputs() const { return inputs; }
// TensorVec getOutputs() { return outputs; }
const TensorVec &getOutputs() const { return outputs; }
Tensor getInputs(size_t i) { return inputs.at(i); }
Tensor getOutput() const {
IT_ASSERT(outputs.size() == 1, "Unimplemented");
return outputs[0];
}
OpType getOpType() const { return type; }
virtual int numInputs() const = 0;
virtual int numOutputs() const = 0;
virtual HashType hash() const { IT_TODO_HALT(); }
virtual HashType hashWithShape() const { IT_TODO_HALT(); }
};
} // namespace infini
namespace std {
template <> struct hash<infini::OpPerfKey> {
size_t operator()(const infini::OpPerfKey &key) const { return key.hash; }
};
} // namespace std

View File

@ -0,0 +1,36 @@
#pragma once
#include "core/graph.h"
#include "core/kernel.h"
namespace infini {
class PerfEngine {
public:
// TODO: Key should be OpPerfKey + Context(maybe implicat) to support
// multiple candiate kernels.
using Key = std::pair<KernelAttrs, OpPerfKey>;
private:
map<Key, PerfRecord> data;
public:
static PerfEngine &getInstance() {
static PerfEngine instance;
return instance;
}
std::optional<PerfRecord> getPerfData(const Key &key) {
auto it = data.find(key);
if (it != data.end()) // find previous evaluating results
return data.at(key);
else
return std::nullopt;
}
void setPerfData(const Key &key, const PerfRecord &record) {
IT_ASSERT(data.find(key) == data.end(), "Perf data already exist");
data.emplace(key, record);
}
};
} // namespace infini

35
include/core/ref.h Normal file
View File

@ -0,0 +1,35 @@
#pragma once
#include <functional> // hash
#include <memory>
#include <type_traits>
namespace infini {
template <typename T> using Ref = std::shared_ptr<T>;
template <typename T> using WRef = std::weak_ptr<T>;
template <typename T> struct is_ref : std::false_type {};
template <typename T> struct is_ref<Ref<T>> : std::true_type {};
template <typename T> struct is_ref<WRef<T>> : std::true_type {};
template <typename T, typename... Params> Ref<T> make_ref(Params &&...params) {
static_assert(is_ref<T>::value == false, "Ref should not be nested");
return std::make_shared<T>(std::forward<Params>(params)...);
}
template <class T, class U,
typename std::enable_if_t<std::is_base_of_v<U, T>> * = nullptr>
Ref<T> as(const Ref<U> &ref) {
return std::dynamic_pointer_cast<T>(ref);
}
template <typename T>
std::vector<WRef<T>> get_wref_vec(const std::vector<Ref<T>> &vec) {
std::vector<WRef<T>> wref_vec;
wref_vec.reserve(vec.size());
for (const auto &ref : vec)
wref_vec.emplace_back(ref);
return wref_vec;
}
} // namespace infini

26
include/core/run_enigne.h Normal file
View File

@ -0,0 +1,26 @@
#pragma once
#include "core/graph.h"
#include "core/kernel.h"
#include "core/perf_engine.h"
namespace infini {
class RunEngine {
private:
Device device;
public:
RunEngine(Device device) : device(device) {}
~RunEngine() {}
void run(const Graph &graph, bool tune = false,
bool profiling = false) const;
double getPerfTime(const Graph &graph, bool profiling = false) const;
private:
void printProfilingData(double totTime,
const std::map<OpType, double> &opTime,
const std::map<OpType, int> &opCnt) const;
};
} // namespace infini

180
include/core/tensor.h Normal file
View File

@ -0,0 +1,180 @@
#pragma once
#include "core/tensor_base.h"
namespace infini {
// TODO: how to deal with this
using ShapeElem = int;
using Shape = vector<ShapeElem>;
class TensorNode : public TensorBaseNode {
private:
Shape shape;
public:
TensorNode(const Shape &shape, DataType dtype);
virtual ~TensorNode() {}
string toString() const override;
size_t size() const;
void dataMalloc();
Shape getDims() const { return shape; }
size_t getOffset(const Shape &ds) const;
using TensorBaseNode::getData;
VType getData(const Shape &pos) const;
void copyData(VType *dptr);
void printData() const;
bool equalData(const Tensor &rhs) const;
// void setDims(const Dim &dms) { dims = dms; }
// bool dataRand(int seed = 0) {
// if (data == nullptr)
// data = new VType[size()];
// if (!random_inited)
// initFastrand();
// // srand(seed);
// // faster rand generator; parallel
// size_t iEnd = size();
// // std::cerr << "Init beginned " << std::endl;
// #pragma omp parallel for
// for (size_t i = 0; i < iEnd; ++i)
// data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) %
// 10000;
// // std::cerr << "Init finished" << std::endl;
// computed = ComputedFull;
// return true;
// }
// bool setScalar(VType val) {
// if (data == nullptr || !dims.empty())
// return false;
// data[0] = val;
// return true;
// }
// bool setData(const Dim &ds, VType val) {
// if (data == nullptr || ds.size() != dims.size())
// return false;
// data[getOffset(ds)] = val;
// return true;
// }
// bool setData(size_t pos, VType val) {
// if (data == nullptr || pos >= size())
// return false;
// data[pos] = val;
// return true;
// }
// VType getScalar() { return data == nullptr ? 0 : data[0]; }
// VType getBroadcastData(const Dim &ds) {
// assert(data != nullptr);
// auto offset = getBroadcastOffset(ds);
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
// }
// VType getBroadcastData(size_t pos) {
// assert(data != nullptr);
// return data[pos % size()];
// }
// size_t getBroadcastOffset(const Dim &ds) {
// assert(ds.size() >= dims.size());
// auto nDim = dims.size();
// auto nBroadcastDim = ds.size() - nDim;
// for (size_t i = 0; i < nDim; ++i)
// if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >=
// dims[i])
// return (size_t)-1;
// size_t idx = 0;
// for (size_t i = 0; i < nDim; ++i)
// idx = idx * dims[i] + ds[nBroadcastDim + i];
// return idx;
// }
// void itInit() { it = Dim(dims.size(), 0); }
// void itReset() {
// itInit();
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
// it[i] = 0;
// }
// bool itValid() {
// if (it.size() != dims.size())
// return false;
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
// if (it[i] >= dims[i])
// return false;
// return true;
// }
// const Dim &itGet() { return it; }
// void itNext() {
// auto p = it.size() - 1;
// it[p] += 1;
// while (p >= 1) {
// if (it[p] == dims[p]) {
// it[p] = 0;
// it[--p] += 1;
// } else
// break;
// }
// }
// TensorType getType() const { return type; }
// void setType(TensorType ty) { type = ty; }
// static inline void initFastrand() {
// assert(omp_get_max_threads() <= 256);
// // srand(0); // constant seed for test
// // align random_seed to avoid false sharing
// for (int i = 0; i < 256 * 16; ++i) {
// // random_seed[i] = rand();
// // constant random seed for test
// random_seed[i] = i;
// }
// random_inited = true;
// }
// static inline int fastrand(int &g_seed) {
// g_seed = (214013 * g_seed + 2531011);
// return (g_seed >> 16) & 0x7FFF;
// }
// std::vector<std::vector<int>> const *getSplittingPoints() const {
// assert(!splittingPoints.empty());
// return &splittingPoints;
// }
// bool setSplittingPoints(std::vector<std::vector<int>> value) {
// assert(!value.empty());
// splittingPoints = value;
// return true;
// }
// void printSplittingPoints() {
// if (splittingPoints.empty())
// printf("Empty SplittingPoints");
// else {
// printf("[");
// for (auto &vs : splittingPoints) {
// printf("[");
// for (auto v : vs)
// printf("%2d,", v);
// printf("],");
// }
// printf("]");
// }
// }
// void initSplittingPoints() {
// splittingPoints.resize(getDims().size()); }
// void printShape();
};
} // namespace infini

261
include/core/tensor_base.h Normal file
View File

@ -0,0 +1,261 @@
#pragma once
#include "core/object.h"
#include "core/ref.h"
namespace infini {
// class Tensor;
class TensorBaseNode;
class TensorNode;
class OperatorNode;
class GraphNode;
using TensorBase = Ref<TensorBaseNode>;
using Tensor = Ref<TensorNode>;
using Operator = Ref<OperatorNode>;
using Graph = Ref<GraphNode>;
using TensorVec = vector<Tensor>;
using OpVec = vector<Operator>;
using VType = uint32_t;
enum class DataType {
Float32,
Int32,
};
class TensorBaseNode : public Object {
public:
// enum TensorType {
// Input,
// Weight,
// Invalid,
// NotCounted,
// };
protected:
int dim;
DataType dtype;
vector<WRef<TensorBaseNode>> inputOf;
WRef<TensorBaseNode> outputOf;
// TODO: Ref<void> -> Ref<Blob>
Ref<VType[]> data;
// ComputeState computed;
// static int random_seed[256 * 16];
// static bool random_inited;
public:
TensorBaseNode(int dim, DataType dtype);
virtual ~TensorBaseNode() {}
Ref<VType[]> getDataPtr() const { return data; }
VType getData(size_t offset) const;
DataType getDType() const { return dtype; }
// uint64_t getHash() const { return hash; }
// void setInputOf(const OpVec &ops) {
// inputOf.clear();
// for (const auto &op : ops)
// inputOf.emplace_back(op);
// }
// void addInputOf(Operator op) { inputOf.emplace_back(op); }
// void setOutputOf(Operator op) { outputOf = op; }
// const OpVec &getInputOf() { return inputOf; }
// Operator *getOutputOf() { return outputOf; }
// std::pair<Operator *, int> getOutputOfWithIndex();
// const Dim &getDims() const { return dims; }
// void setDims(const Dim &dms) { dims = dms; }
// bool dataRand(int seed = 0) {
// if (data == nullptr)
// data = new VType[size()];
// if (!random_inited)
// initFastrand();
// // srand(seed);
// // faster rand generator; parallel
// size_t iEnd = size();
// // std::cerr << "Init beginned " << std::endl;
// #pragma omp parallel for
// for (size_t i = 0; i < iEnd; ++i)
// data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) %
// 10000;
// // std::cerr << "Init finished" << std::endl;
// computed = ComputedFull;
// return true;
// }
// bool setScalar(VType val) {
// if (data == nullptr || !dims.empty())
// return false;
// data[0] = val;
// return true;
// }
// bool setData(const Dim &ds, VType val) {
// if (data == nullptr || ds.size() != dims.size())
// return false;
// data[getOffset(ds)] = val;
// return true;
// }
// bool setData(size_t pos, VType val) {
// if (data == nullptr || pos >= size())
// return false;
// data[pos] = val;
// return true;
// }
// VType getScalar() { return data == nullptr ? 0 : data[0]; }
// VType getData(const Dim &ds) {
// assert(data != nullptr);
// auto offset = getOffset(ds);
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
// }
// VType getData(size_t pos) {
// assert(data != nullptr);
// assert(pos < size());
// return data[pos];
// }
// VType *getDataPtr() const { return data; }
// size_t getOffset(const Dim &ds) {
// auto nDim = ds.size();
// assert(dims.size() == nDim);
// if (ds.empty())
// return 0;
// for (size_t i = 0; i < nDim; ++i)
// if (ds[i] < 0 || ds[i] >= dims[i])
// return (size_t)-1;
// size_t idx = ds[0];
// size_t dm = 0;
// while (++dm < nDim)
// idx = idx * dims[dm] + ds[dm];
// return idx;
// }
// VType getBroadcastData(const Dim &ds) {
// assert(data != nullptr);
// auto offset = getBroadcastOffset(ds);
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
// }
// VType getBroadcastData(size_t pos) {
// assert(data != nullptr);
// return data[pos % size()];
// }
// size_t getBroadcastOffset(const Dim &ds) {
// assert(ds.size() >= dims.size());
// auto nDim = dims.size();
// auto nBroadcastDim = ds.size() - nDim;
// for (size_t i = 0; i < nDim; ++i)
// if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >=
// dims[i])
// return (size_t)-1;
// size_t idx = 0;
// for (size_t i = 0; i < nDim; ++i)
// idx = idx * dims[i] + ds[nBroadcastDim + i];
// return idx;
// }
// void itInit() { it = Dim(dims.size(), 0); }
// void itReset() {
// itInit();
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
// it[i] = 0;
// }
// bool itValid() {
// if (it.size() != dims.size())
// return false;
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
// if (it[i] >= dims[i])
// return false;
// return true;
// }
// const Dim &itGet() { return it; }
// void itNext() {
// auto p = it.size() - 1;
// it[p] += 1;
// while (p >= 1) {
// if (it[p] == dims[p]) {
// it[p] = 0;
// it[--p] += 1;
// } else
// break;
// }
// }
// size_t size() const {
// size_t sz = 1;
// auto dm = dims.size();
// while (dm > 0)
// sz *= dims[--dm];
// return sz;
// }
// TensorType getType() const { return type; }
// void setType(TensorType ty) { type = ty; }
// static inline void initFastrand() {
// assert(omp_get_max_threads() <= 256);
// // srand(0); // constant seed for test
// // align random_seed to avoid false sharing
// for (int i = 0; i < 256 * 16; ++i) {
// // random_seed[i] = rand();
// // constant random seed for test
// random_seed[i] = i;
// }
// random_inited = true;
// }
// static inline int fastrand(int &g_seed) {
// g_seed = (214013 * g_seed + 2531011);
// return (g_seed >> 16) & 0x7FFF;
// }
// std::vector<std::vector<int>> const *getSplittingPoints() const {
// assert(!splittingPoints.empty());
// return &splittingPoints;
// }
// bool setSplittingPoints(std::vector<std::vector<int>> value) {
// assert(!value.empty());
// splittingPoints = value;
// return true;
// }
// void printSplittingPoints() {
// if (splittingPoints.empty())
// printf("Empty SplittingPoints");
// else {
// printf("[");
// for (auto &vs : splittingPoints) {
// printf("[");
// for (auto v : vs)
// printf("%2d,", v);
// printf("],");
// }
// printf("]");
// }
// }
// void initSplittingPoints() {
// splittingPoints.resize(getDims().size()); }
// void printShape();
};
} // namespace infini

View File

@ -0,0 +1,15 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class MatchComputationKernel : public Pass {
public:
MatchComputationKernel(Derivator &derivator)
: Pass(derivator, "MatchComputationKernel") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
};
} // namespace nnet

View File

@ -0,0 +1,15 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class MatchMemBoundKernel : public Pass {
public:
MatchMemBoundKernel(Derivator &derivator)
: Pass(derivator, "MatchMemBoundKernel") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
};
} // namespace nnet

41
include/nnet/Pass/Pass.h Normal file
View File

@ -0,0 +1,41 @@
#pragma once
#include "nnet/derivator.h"
namespace nnet {
class Pass {
private:
VecExpr transformations;
protected:
Derivator &derivator;
string passName;
/**
* @brief // False if does not add log in Derivator. It should be false for
* single Pass test to avoid mismatch of passInfos and passMsgs due to
* different number of "run" and "nextStep".
*/
bool enableLogging, enableDebug;
virtual void transform(Formula &origin, int depth, Expr &rCur) = 0;
void nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur,
const string &ruleInfo = "");
Var getNewVar();
string newTensorName();
private:
void initialize(Formula &origin, const Expr &rCur);
void finalize();
public:
Pass(Derivator &derivator, const string &passName);
virtual ~Pass();
void run(Formula &origin, int dfsDepth, Expr &rCur);
void setEnableLogging(bool value);
void setEnableDebug(bool value);
const VecExpr &getTransformations();
};
} // namespace nnet

View File

@ -0,0 +1,18 @@
#pragma once
#include "nnet/Pass/Pass.h"
#include "nnet/ReplaceKit.h"
namespace nnet {
class Rule1VariableSplit : public Pass {
public:
Rule1VariableSplit(Derivator &derivator)
: Pass(derivator, "Rule1VariableSplit") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
vector<Replace> getSplitableVar(const RangeOp &rangeOp);
Expr replaceIters(Expr cur, const Replace &replace);
};
} // namespace nnet

View File

@ -0,0 +1,29 @@
#pragma once
#include "nnet/Pass/Pass.h"
#include "nnet/ReplaceKit.h"
namespace nnet {
class Rule2VariableMerging : public Pass {
private:
map<int, vector<Var>> substituteRules;
public:
Rule2VariableMerging(Derivator &derivator)
: Pass(derivator, "Rule2VariableMerging") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
vector<Replace> getMergableReplaces(RangeOp rangeOp, int depth);
optional<Replace> getReplaceMergingTwoLoopIters(const RangeOp &rangeOp,
pair<Iterator, int> pairA,
pair<Iterator, int> pairB,
const IteratorTable &exprIT,
int tensorID);
optional<Replace> getReplaceMappingTwoLoopIters(const RangeOp &rangeOp,
pair<Iterator, int> pa,
pair<Iterator, int> pb);
};
} // namespace nnet

View File

@ -0,0 +1,19 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule3StageSplit : public Pass {
private:
map<int, vector<Var>> substituteRules;
public:
Rule3StageSplit(Derivator &derivator)
: Pass(derivator, "Rule3StageSplit") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
vector<vector<Var>> getSplitSummationIters(RangeOp rangeOp);
};
} // namespace nnet

View File

@ -0,0 +1,22 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule4StageMerging : public Pass {
bool success, mergeStageWithCalc;
public:
Rule4StageMerging(Derivator &derivator)
: Pass(derivator, "Rule4StageMerging"), success(false),
mergeStageWithCalc(false) {}
bool rule4StageMerging(Formula &origin, int depth, Expr &rCur,
bool mergeStageWithCalc = false);
bool isSuccessful();
void setMergeStageWithCalc(bool value);
private:
virtual void transform(Formula &origin, int depth, Expr &rCur) override;
};
} // namespace nnet

View File

@ -0,0 +1,16 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule5RangeRelaxation : public Pass {
public:
Rule5RangeRelaxation(Derivator &derivator)
: Pass(derivator, "Rule5RangeRelaxation") {}
Expr rule5RangeRelaxation(Formula &origin, int depth, Expr &rCur);
private:
virtual void transform(Formula &origin, int depth, Expr &rCur) override;
};
} // namespace nnet

View File

@ -0,0 +1,17 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule6KenerlMatching : public Pass {
public:
Rule6KenerlMatching(Derivator &derivator)
: Pass(derivator, "Rule6KenerlMatching") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
// RE: seperating this func is a choice.
VecExpr matchElementWise(const RangeOp &rangeOp);
};
} // namespace nnet

View File

@ -0,0 +1,16 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule7DLT : public Pass {
public:
Rule7DLT(Derivator &derivator) : Pass(derivator, "Rule7DLT") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
Expr buildDLTSingleRangeOp(const RangeOp &original, const Expr &newSummand);
vector<int> getFactors();
};
} // namespace nnet

View File

@ -0,0 +1,48 @@
#pragma once
#include "nnet/Pass/Pass.h"
#include "nnet/ReplaceKit.h"
namespace nnet {
class Rule8GuidedDLT : public Pass {
public:
Rule8GuidedDLT(Derivator &derivator) : Pass(derivator, "Rule8GuidedDLT") {}
VecExpr guidedDLT(Formula &origin, int depth, Expr &rCur,
bool debug = false);
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
/**
* @brief If only one row miss match (more iterators mismatch), directly do
* data layout construction according to the IT.
*
* @return Expr Return nullptr if failed.
*/
Expr guidedDLTMoreVar2(const RangeOp &cur, const Mismatch &mismatch,
const IteratorTable &exprIT, const Pattern &pattern);
/**
* @brief Check whether two iterators overlap each other. If overlapping, we
* cannot simply reconstruct the tensor into a new one by seperate all
* iterators into different dimensions.
*/
bool checkElementsHaveOnlyOneAccessIteratorSet(const IteratorTable &exprIT,
int tensorID);
/**
* @brief Only product of two tensors can be guided DLTed.
*
* @param cur
* @return true
* @return false
*/
bool statisfyGuidedDLT(RangeOp cur) const;
/**
* @brief Deal with output DLT mismatch only.
*/
Expr guidedDLTDLMismatch(const RangeOp &cur, const Mismatch &mismatch,
const IteratorTable &exprIT,
const Pattern &pattern);
Expr buildGuidedDLTSource(const Subscript &originalSub, Replace replace,
vector<Var> tensorDimAxes, vector<int> newShape);
};
} // namespace nnet

View File

@ -0,0 +1,16 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule90TwoStageElementWise : public Pass {
public:
Rule90TwoStageElementWise(Derivator &derivator)
: Pass(derivator, "Rule90TwoStageElementWise") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
VecExpr matchTwoStageElementWise(const RangeOp &rangeOp);
};
} // namespace nnet

View File

@ -0,0 +1,15 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule91MergeStagesWithSum : public Pass {
public:
Rule91MergeStagesWithSum(Derivator &derivator)
: Pass(derivator, "Rule91MergeStagesWithSum") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
};
} // namespace nnet

View File

@ -0,0 +1,15 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule9RangeMagnify : public Pass {
public:
Rule9RangeMagnify(Derivator &derivator)
: Pass(derivator, "Rule9RangeMagnify") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
};
} // namespace nnet

46
include/nnet/ReplaceKit.h Normal file
View File

@ -0,0 +1,46 @@
#pragma once
#include "nnet/expr.h"
namespace nnet {
struct Replace {
int iteratorType;
vector<Var> oldIters; // i_1, ...
vector<Var> newIters; // j_1, ...
VecExpr phis; // j_1=\phi_1(i_1, ...), not necessary for Sum iter
VecExpr psis; // i_1=\psi_1(j_1, ...)
vector<VarRangePair> newVarRanges;
bool isReplaced(Var var) const {
for (const auto &iter : oldIters)
if (iter->equal(var))
return true;
return false;
}
string toReadable() const {
string ret = "Old iters: " + serializeVec(oldIters) +
", new iters: " + serializeVec(newIters);
ret += " phis: " + serializeVec(phis) + " psis: " + serializeVec(psis);
return ret;
}
};
class ReplaceKit {
public:
static RangeOp replaceRangeOpIterator(const RangeOp &rangeOp,
const Replace &replace,
const Expr &replacedSummand);
static Subscript buildSubscirptForLoopVarReplace(const RangeOp &inner,
const Replace &replace);
static RangeOp buildDLTOuterRangeOp(const RangeOp &original,
const Subscript &subscriptedNewRangeOp);
static Expr replaceMultipleExprs(const Expr &cur,
const vector<Var> &patterns,
const VecExpr &replacements,
bool simplify = true);
static Expr replaceExpr(const Expr &cur, const Expr &pattern,
const Expr &replacement);
};
} // namespace nnet

View File

@ -0,0 +1,38 @@
#pragma once
#include "nnet/Visitor/StrideVisitor.h"
#include "nnet/visitor.h"
namespace nnet {
class AsTVMVisitor : public Functor<std::string(void)> {
private:
int nStage = 0, curStage = -1;
std::unordered_map<std::string, int> offset;
std::vector<std::string> inputs;
std::string output;
std::vector<std::string> pythonVars;
std::vector<std::vector<int>> inputShapes;
std::vector<int> outputShape;
std::string stmts;
public:
std::string getStmts() const;
const std::vector<std::string> &getInputs() const { return inputs; }
const std::string &getOutput() const { return output; }
const std::vector<std::vector<int>> &getInputShapes() const {
return inputShapes;
}
const std::vector<int> &getOutputShape() const { return outputShape; }
std::string visit_(const Constant &c) override;
std::string visit_(const BinaryOp &c) override;
std::string visit_(const Func &c) override;
std::string visit_(const RangeOp &c) override;
std::string visit_(const Subscript &c) override;
std::string visit_(const Var &c) override;
std::string visit_(const Tensor &c) override;
};
} // namespace nnet

View File

@ -0,0 +1,22 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class CheckOOBVisitor : public ExprTreeVisitor {
RangeOp rangeOp;
bool detect = false;
public:
CheckOOBVisitor(int _verobse = 0) : ExprTreeVisitor(1, 1, 0, 0, _verobse) {}
void visit_(const Subscript &c) override;
/**
* @brief
* @return true If there is OOB
* @return false If there is no OOB
*/
bool checkRangeOp(const RangeOp &_rangeOp);
};
} // namespace nnet

View File

@ -0,0 +1,16 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
// Clone ExprNodes in a stage except Tensor, Var, and Constant nodes.
class CloneMutator : public Mutator {
public:
CloneMutator() : Mutator(false) {}
Expr visit_(const Constant &c) override;
Expr visit_(const Var &c) override;
Expr visit_(const Tensor &c) override;
Expr clone(const Expr &c) { return dispatch(c); }
};
} // namespace nnet

View File

@ -0,0 +1,15 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class CompareMultiFormulasVisitor : public ExprTreeVisitor {
vector<VarRangePair> newSumVarRanges;
RangeOp newRangeOp;
public:
CompareMultiFormulasVisitor() : ExprTreeVisitor() {}
bool compare(const VecExpr &roots);
};
} // namespace nnet

View File

@ -0,0 +1,18 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class CountRoutineVisitor : public ExprTreeVisitor {
private:
vector<int> cnts;
public:
CountRoutineVisitor(int _verobse = 0)
: ExprTreeVisitor(1, 1, 1, 1, _verobse) {}
void visit_(const Tensor &c) override;
vector<int> count(const Expr &root);
bool match(const Expr &root, int nMatmul = 0, int nConv = 0,
int nElement = 0, int nSg2bmm = 0, int nLongformerGBMM = 0);
};
} // namespace nnet

View File

@ -0,0 +1,25 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class FullPrinterVisitor : public ExprTreeVisitor {
private:
vector<tuple<string, Routine, Tensor>> q;
public:
FullPrinterVisitor(int _verobse = 0)
: ExprTreeVisitor(1, 1, 1, 0, _verobse) {}
void visit_(const Tensor &c) override;
string print(const Expr &root);
/**
* @brief Get all tensors & OPs in a reversed order
*
* @param root
* @return vector<<Output TensorName, RoutineNode, output tensor in NNet>>
*/
const vector<tuple<string, Routine, Tensor>> &traverse(const Expr &root);
};
} // namespace nnet

View File

@ -0,0 +1,22 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
// Get all tensors in the stage
class GetTensorsVisitor : public ExprTreeVisitor {
private:
unordered_map<string, Tensor> tensors;
void visit_(const Tensor &c) override;
public:
GetTensorsVisitor(int _verobse = 0)
: ExprTreeVisitor(1, 1, 1, 0, _verobse) {}
auto get(const Expr &c) {
dispatch(c);
return tensors;
}
};
} // namespace nnet

View File

@ -0,0 +1,31 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
// Calculate hash for a normal form, starting at a RangeOp
class HashVisitor : public Functor<HashType(void)> {
inline const static HashType BKDR_SEED[] = {131, 313, 10007, 65599};
PtrUmap<Iterator, int> varHash;
int nLoopVars = 0;
PtrUmap<Iterator, int> name2id;
vector<int> rootId;
vector<bool> haveAlias;
int nVars = 0;
vector<HashType> power;
private:
HashType visit_(const Constant &c) override;
HashType visit_(const BinaryOp &c) override;
HashType visit_(const RangeOp &c) override;
HashType visit_(const Subscript &c) override;
HashType visit_(const Tensor &c) override;
HashType visit_(const Var &c) override;
public:
HashVisitor(int _verobse = 0) : Functor(_verobse) {}
HashType getHash(const Expr &c);
};
} // namespace nnet

View File

@ -0,0 +1,23 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class InputVisitor : public ExprTreeVisitor {
vector<Tensor> inputs;
public:
int nInputs = 0;
InputVisitor(int _verobse = 0) : ExprTreeVisitor(1, 1, 1, 0, _verobse) {}
void visit_(const Tensor &c) override;
/**
* @brief Get the all inputs in the netsed stages
*/
vector<Tensor> getInputs(const RangeOp &_rangeOp) {
dispatch(_rangeOp);
return inputs;
}
};
} // namespace nnet

View File

@ -0,0 +1,55 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class Interpreter : public Functor<int()> {
public:
using ttype = int; // Test data type
using rtype = int; // Return data type
using Position = vector<int>;
using Inputs = unordered_map<string, Ref<vector<ttype>>>;
using Iteration = PtrUmap<Var, int>;
private:
// cache the input value
Inputs inputs;
vector<Iteration> iterations;
vector<Position> positions;
rtype visit_(const Constant &c) override;
rtype visit_(const BinaryOp &c) override;
rtype visit_(const RangeOp &c) override;
rtype visit_(const Subscript &c) override;
rtype visit_(const Var &c) override;
rtype visit_(const Tensor &c) override;
// int visit_(const Func &c); // Future work
static Inputs genInputStartingFromZero(const RangeOp &range);
public:
Interpreter(Inputs _inputs, int _verbose = 0)
: Functor(_verbose), inputs(_inputs) {}
Interpreter(RangeOp range, int _verbose = 0);
/**
* @brief Calculate the output at specified poistions
*
* @param expr The expression to be calculated.
* @param poses Positions of output.
* @return vector<int> Value of output.
*/
vector<rtype> interpret(const Expr &expr, const vector<Position> &poses);
/**
* @brief Calculate the output at equally spaced positions
*
* @param expr The expression to be calculated.
* @param nPoses The number of calculated output positions.
* @return vector<int> Value of output.
*/
vector<rtype> interpretUniformSample(const RangeOp &range,
int nPoses = 100);
vector<rtype> interpretAllOutput(const RangeOp &range);
};
} // namespace nnet

View File

@ -0,0 +1,14 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class MatchReshapeVisitor : public Functor<bool(void)> {
private:
PtrMap<Iterator, int> _coefficient;
public:
bool visit_(const RangeOp &c) override;
};
} // namespace nnet

View File

@ -0,0 +1,60 @@
#pragma once
#include "nnet/Visitor/StrideVisitor.h"
#include "nnet/visitor.h"
namespace nnet {
class MatchTableVisitor
: public Functor<void(const Tensor &, int dim, optional<int> stride)> {
private:
// Var -> {(tensor, dim)}
Appearance appearance;
vector<Tensor> tensors;
vector<Subscript> subscripts;
StrideTable strideTable;
PtrMap<Iterator, vector<vector<int>>>
strideInDim; // [Iterator][tensorID][dim]=stride
// Intermediate variable
// product of a sub-exprtree: Stride has to be done in two DFS
SubexprSride subexprStride;
bool hasUnsupportedOp = false;
public:
MatchTableVisitor(int _verobse = 0) : Functor(_verobse) {}
void visit_(const BinaryOp &c, const Tensor &tensor, int dim,
optional<int> stride) override;
void visit_(const Subscript &c, const Tensor &tensor, int dim,
optional<int> stride) override;
void visit_(const Var &c, const Tensor &tensor, int dim,
optional<int> stride) override;
void visit_(const Constant &c, const Tensor &tensor, int dim,
optional<int> stride) override;
// void visit_(const Tensor &c, const Tensor &tensor) override;
[[nodiscard]] bool operator()(const RangeOp &e) {
hasUnsupportedOp = false;
// get the location and stride of each iterator
auto mulOp = as<BinaryOpNode>(e->getSummand());
// TODO [feature]: support complex index exprs
if (!mulOp || mulOp->getOpType() != OpType::Mul) {
nnet_unimplemented_continue();
return false;
}
StrideVisitor strideVisitor(0);
subexprStride = strideVisitor.getFormulaStride(e);
dispatch(mulOp->getLhs(), nullptr, 0, 0);
dispatch(mulOp->getRhs(), nullptr, 0, 0);
subscripts.emplace_back(as<SubscriptNode>(mulOp->getLhs()));
subscripts.emplace_back(as<SubscriptNode>(mulOp->getRhs()));
assert(tensors.size() == subscripts.size());
assert(tensors.size() < 5);
return !hasUnsupportedOp;
}
auto getResult() const {
return tuple(appearance, tensors, strideTable, subscripts);
}
};
} // namespace nnet

View File

@ -0,0 +1,18 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class MatmulTransposeMutator : public Mutator {
Derivator &derivator;
public:
MatmulTransposeMutator(Derivator &derivator)
: Mutator(1), derivator(derivator) {}
VecExpr transpose(const Tensor &tensor);
private:
Tensor transposeInput(const Tensor &tensor);
};
} // namespace nnet

View File

@ -0,0 +1,20 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class MergeMemboundMutator : public Mutator {
VecExpr kernels;
int curDepth; // from the last one to the first one
Expr visit_(const Tensor &c) override;
// FIXME: duplicate code
Expr rule4StageMerging(Expr &rCur, bool mergeStageWithCalc);
bool checkEmpty();
public:
MergeMemboundMutator(const VecExpr &kernels)
: Mutator(), kernels(kernels), curDepth(kernels.size() - 1) {}
Expr merge(bool allowEmptyMembound = false);
};
} // namespace nnet

View File

@ -0,0 +1,43 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
/**
* @brief Since the output positions of operators always start from 0, we have
* to offset them if the the boundary expression of is not 0.
*/
class PatternMatcher : public Functor<void(void)> {
private:
Derivator &derivator;
bool hasNonZeroRange;
const RangeOp originalCur;
public:
PatternMatcher(Derivator &derivator, const RangeOp &cur);
/**
* @brief Get the Cur whose loop vars are all offset to [0, x). Since
* operator outputs start from 0, RangeOp has to be aligned.
*/
RangeOp getOffsetCur();
/**
* @brief Add outer RangeOp to map the original positions to the new
* positions staring from 0.
*
* @param exprs Tensors from matched exprs
*/
VecExpr applyWrapper(const VecExpr &exprs);
VecExpr matchWithPattern(const RangeOp &rangeOp, const Pattern &pattern);
private:
VecExpr matchKernel(const Pattern &pattern, const RangeOp &rangeOp,
IteratorTable &exprIT);
// get reverse tensor and iterator map ([pattern tensor/iter ID] ->
// real)
Expr matchKernelWithTensorMap(const Pattern &pattern,
const RangeOp &rangeOp,
IteratorTable &exprIT);
};
} // namespace nnet

View File

@ -0,0 +1,25 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class RangeMagnifyVisitor : public Mutator {
vector<VarRangePair> newSumVarRanges;
RangeOp newRangeOp;
public:
RangeMagnifyVisitor() : Mutator(0) {}
Expr visit_(const RangeOp &c) override;
Expr visit_(const Subscript &c) override;
/**
* @brief
*
* @param root
* @param _newSumVarRanges
* @return RangeOp nullptr if failed to magnify
*/
RangeOp magnify(const RangeOp &root,
const vector<VarRangePair> &_newSumVarRanges);
};
} // namespace nnet

View File

@ -0,0 +1,18 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
using RangeMap = PtrMap<Iterator, Range>;
class RangeRelaxFunctor : public Functor<RangeMap()> {
RangeOp rangeOp;
public:
RangeRelaxFunctor(RangeOp _rangeOp) : Functor(false), rangeOp(_rangeOp) {}
RangeMap visit_(const BinaryOp &c) override;
RangeMap visit_(const RangeOp &c) override;
RangeMap visit_(const Subscript &c) override;
RangeMap intersectRangeMaps(const RangeMap &a, const RangeMap &b);
};
} // namespace nnet

View File

@ -0,0 +1,20 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
// Replace node according to its address the summand
// Only subscript and tensor are supported now.
class ReplaceNodeMutator : public Mutator {
int nSubscripts = 0;
ExprNode *target;
Expr replacement;
public:
ReplaceNodeMutator() : Mutator(0) {}
Expr visit_(const Subscript &c) override;
Expr visit_(const Tensor &c) override;
Expr replace(const Expr &root, ExprNode *_target, const Expr &_replace);
};
} // namespace nnet

View File

@ -0,0 +1,33 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class ReplaceVariable : public Mutator {
VecExpr patterns, replacements;
map<HashType, int> patternHash;
public:
ReplaceVariable(Expr _pattern, Expr _replacement) : Mutator(false) {
set({_pattern}, {_replacement});
}
ReplaceVariable(const map<string, pair<Expr, Expr>> &mapping)
: Mutator(false) {
VecExpr _patterns, _replacements;
for (const auto &[_, v] : mapping) {
_patterns.emplace_back(v.first);
_replacements.emplace_back(v.second);
}
set(_patterns, _replacements);
}
Expr visit_(const BinaryOp &c) override;
// NOT recur to the next stage
Expr visit_(const RangeOp &c) override;
Expr visit_(const Var &c) override;
private:
void set(VecExpr _pattern, VecExpr _replacement);
Expr match(const Expr &c);
};
} // namespace nnet

View File

@ -0,0 +1,52 @@
#pragma once
#include "nlohmann/json_fwd.hpp"
#include "nnet/visitor.h"
#include <memory>
namespace nnet {
class Serializer : public Functor<string()> {
using json = nlohmann::ordered_json;
private:
static constexpr int VERSION{1};
std::unique_ptr<json> jPtr;
json &j;
static int id;
string visit_(const Constant &c) override;
string visit_(const BinaryOp &c) override;
string visit_(const RangeOp &c) override;
string visit_(const Subscript &c) override;
string visit_(const Var &c) override;
string visit_(const Tensor &c) override;
string dispatchRoutine(const Routine &c);
Expr buildExprTree(string key);
Routine buildRoutine(string key);
public:
Serializer(int _verobse = 0);
virtual ~Serializer();
/**
* @brief Serialize the given expression to json file
*
* @param expr The expression to be serialized
* @param filePath The path of json file to be output
* @param msg Message of derivation
* @return bool Whether the serialization succeed
*/
bool serialize(const Expr &expr, const string &filePath,
const string &msg = "");
/**
* @brief Deserialize the given json file to expression
*
* @param filePath The path to file to be deserialized
* @return Expression deserialized from the given json file
*/
Expr deserialize(const string &filePath);
};
} // namespace nnet

View File

@ -0,0 +1,38 @@
#pragma once
#include "nnet/Visitor/StrideVisitor.h"
#include "nnet/visitor.h"
namespace nnet {
// Simplify a index expression tree
class SimplifyExprVisitor : public Functor<void(optional<int> stride)> {
private:
SubexprSride subexprStride;
int constant;
PtrMap<Iterator, int> strides; // [var]=strides
map<pair<Iterator, int>, int, RefValueLess<pair<Iterator, int>>> divStrides,
modStrides; // 3*(i%8): [<i,8>]=3
// For divde and modulo with expr as dividend: 3*((i+1)%8): [<i+1,8>]=3
map<pair<Expr, int>, int, RefAddrLess<pair<Expr, int>>> divExprStrides,
modExprStrides;
public:
SimplifyExprVisitor() : Functor(0) {}
void visit_(const BinaryOp &c, optional<int> stride) override;
void visit_(const Var &c, optional<int> stride) override;
void visit_(const Constant &c, optional<int> stride) override;
PtrMap<Iterator, int> getStrides(const Expr &expr);
// TODO [refactor]: move this to SimplifyFormulaMutator as a member func
// this class should be get coefficients in a expr
Expr simplify(const Expr &expr);
int getConstant(const Expr &expr);
pair<PtrMap<Iterator, int>, int> getStridesConstant(const Expr &expr);
optional<Range> getExprRange(const Expr &expr, const RangeOp &rangeOp);
PtrMap<Iterator, int> getStrides() { return strides; }
const auto &getDivStrides() { return divStrides; }
const auto &getModStrides() { return modStrides; }
};
} // namespace nnet

View File

@ -0,0 +1,18 @@
#pragma once
#include "nnet/Visitor/StrideVisitor.h"
#include "nnet/visitor.h"
namespace nnet {
// Simplify all indexes in subscripts in an expression tree
class SimplifyFormulaMutator : public Mutator {
int nSubscripts = 0;
public:
SimplifyFormulaMutator() : Mutator(0) {}
Expr visit_(const Subscript &c) override;
// Expr visit_(const BinaryOp &c) override;
Expr simplify(const Expr &expr);
};
} // namespace nnet

View File

@ -0,0 +1,38 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
using SubexprSride = map<const ExprNode *, optional<int>>;
class StrideVisitor : public Functor<optional<int>(void)> {
private:
SubexprSride subexprStride;
public:
StrideVisitor(int _verobse = 0) : Functor(_verobse) {}
optional<int> visit_(const BinaryOp &c) override;
optional<int> visit_(const Subscript &c) override;
optional<int> visit_(const Var &c) override;
optional<int> visit_(const Constant &c) override;
// void visit_(const Tensor &c, const Tensor &tensor) override;
auto getFormulaStride(const RangeOp &e) {
subexprStride.clear();
// get the location and stride of each iterator
auto mulOp = as<BinaryOpNode>(e->getSummand());
// TODO [feature]: support complex index exprs
if (!mulOp || mulOp->getOpType() != OpType::Mul)
nnet_unimplemented_continue();
dispatch(mulOp->getLhs());
dispatch(mulOp->getRhs());
return subexprStride;
}
[[nodiscard]] auto getExprStride(const Expr &e) {
subexprStride.clear();
dispatch(e);
return subexprStride;
}
};
} // namespace nnet

77
include/nnet/common.h Normal file
View File

@ -0,0 +1,77 @@
#pragma once
#include "dbg.h"
#include <cassert>
#include <list>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace nnet {
using std::dynamic_pointer_cast;
using std::endl;
using std::list;
using std::make_pair;
using std::make_shared;
using std::make_tuple;
using std::map;
using std::max;
using std::min;
using std::nullopt;
using std::optional;
using std::pair;
using std::set;
using std::shared_ptr;
using std::string;
using std::tie;
using std::to_string;
using std::tuple;
using std::unique_ptr;
using std::unordered_map;
template <typename T> using uset = std::unordered_set<T>;
using std::vector;
using std::weak_ptr;
// Aliases
using dtype = float;
using HashType = int;
template <typename T> struct ptr_less {
bool operator()(const T &lhs, const T &rhs) const { return *lhs < *rhs; }
};
template <typename T> struct ptr_hash {
size_t operator()(const T &lhs) const {
return std::hash<decltype(*lhs)>()(*lhs);
}
};
template <typename T> struct ptr_equal {
bool operator()(const T &lhs, const T &rhs) const { return *lhs == *rhs; }
};
static inline HashType genhash(HashType a, HashType b) {
return (a * 10007 + b + 12345) % 1000000007;
}
static inline HashType genhash(string s) {
HashType ret = 0;
for (auto c : s)
ret = genhash(ret, c);
return ret;
}
#define nnet_unimplemented_halt() \
{ assert(!"Unimplemented"); }
#define nnet_unimplemented_continue() \
{ dbg("Unimplemented"); }
#define nnet_assert(expr, msg) assert(((void)(msg), (expr)))
std::string pointer_to_hex(void *i);
} // namespace nnet

858
include/nnet/dbg.h Normal file
View File

@ -0,0 +1,858 @@
/*****************************************************************************
dbg(...) macro
License (MIT):
Copyright (c) 2019 David Peter <mail@david-peter.de>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to
deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*****************************************************************************/
#ifndef DBG_MACRO_DBG_H
#define DBG_MACRO_DBG_H
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
#define DBG_MACRO_UNIX
#elif defined(_MSC_VER)
#define DBG_MACRO_WINDOWS
#endif
// #ifndef DBG_MACRO_NO_WARNING
// #pragma message("WARNING: the 'dbg.h' header is included in your code base")
// #endif // DBG_MACRO_NO_WARNING
#include <algorithm>
#include <chrono>
#include <ctime>
#include <iomanip>
#include <ios>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>
#ifdef DBG_MACRO_UNIX
#include <unistd.h>
#endif
#if __cplusplus >= 201703L
#define DBG_MACRO_CXX_STANDARD 17
#elif __cplusplus >= 201402L
#define DBG_MACRO_CXX_STANDARD 14
#else
#define DBG_MACRO_CXX_STANDARD 11
#endif
#if DBG_MACRO_CXX_STANDARD >= 17
#include <optional>
#include <variant>
#endif
namespace dbg {
#ifdef DBG_MACRO_UNIX
inline bool isColorizedOutputEnabled() { return isatty(fileno(stderr)); }
#else
inline bool isColorizedOutputEnabled() { return true; }
#endif
struct time {};
namespace pretty_function {
// Compiler-agnostic version of __PRETTY_FUNCTION__ and constants to
// extract the template argument in `type_name_impl`
#if defined(__clang__)
#define DBG_MACRO_PRETTY_FUNCTION __PRETTY_FUNCTION__
static constexpr size_t PREFIX_LENGTH =
sizeof("const char *dbg::type_name_impl() [T = ") - 1;
static constexpr size_t SUFFIX_LENGTH = sizeof("]") - 1;
#elif defined(__GNUC__) && !defined(__clang__)
#define DBG_MACRO_PRETTY_FUNCTION __PRETTY_FUNCTION__
static constexpr size_t PREFIX_LENGTH =
sizeof("const char* dbg::type_name_impl() [with T = ") - 1;
static constexpr size_t SUFFIX_LENGTH = sizeof("]") - 1;
#elif defined(_MSC_VER)
#define DBG_MACRO_PRETTY_FUNCTION __FUNCSIG__
static constexpr size_t PREFIX_LENGTH =
sizeof("const char *__cdecl dbg::type_name_impl<") - 1;
static constexpr size_t SUFFIX_LENGTH = sizeof(">(void)") - 1;
#else
#error "This compiler is currently not supported by dbg_macro."
#endif
} // namespace pretty_function
// Formatting helpers
template <typename T> struct print_formatted {
static_assert(std::is_integral<T>::value,
"Only integral types are supported.");
print_formatted(T value, int numeric_base)
: inner(value), base(numeric_base) {}
operator T() const { return inner; }
const char *prefix() const {
switch (base) {
case 8:
return "0o";
case 16:
return "0x";
case 2:
return "0b";
default:
return "";
}
}
T inner;
int base;
};
template <typename T> print_formatted<T> hex(T value) {
return print_formatted<T>{value, 16};
}
template <typename T> print_formatted<T> oct(T value) {
return print_formatted<T>{value, 8};
}
template <typename T> print_formatted<T> bin(T value) {
return print_formatted<T>{value, 2};
}
// Implementation of 'type_name<T>()'
template <typename T> const char *type_name_impl() {
return DBG_MACRO_PRETTY_FUNCTION;
}
template <typename T> struct type_tag {};
template <int &...ExplicitArgumentBarrier, typename T>
std::string get_type_name(type_tag<T>) {
namespace pf = pretty_function;
std::string type = type_name_impl<T>();
return type.substr(pf::PREFIX_LENGTH,
type.size() - pf::PREFIX_LENGTH - pf::SUFFIX_LENGTH);
}
template <typename T> std::string type_name() {
if (std::is_volatile<T>::value) {
if (std::is_pointer<T>::value) {
return type_name<typename std::remove_volatile<T>::type>() +
" volatile";
} else {
return "volatile " +
type_name<typename std::remove_volatile<T>::type>();
}
}
if (std::is_const<T>::value) {
if (std::is_pointer<T>::value) {
return type_name<typename std::remove_const<T>::type>() + " const";
} else {
return "const " + type_name<typename std::remove_const<T>::type>();
}
}
if (std::is_pointer<T>::value) {
return type_name<typename std::remove_pointer<T>::type>() + "*";
}
if (std::is_lvalue_reference<T>::value) {
return type_name<typename std::remove_reference<T>::type>() + "&";
}
if (std::is_rvalue_reference<T>::value) {
return type_name<typename std::remove_reference<T>::type>() + "&&";
}
return get_type_name(type_tag<T>{});
}
inline std::string get_type_name(type_tag<short>) { return "short"; }
inline std::string get_type_name(type_tag<unsigned short>) {
return "unsigned short";
}
inline std::string get_type_name(type_tag<long>) { return "long"; }
inline std::string get_type_name(type_tag<unsigned long>) {
return "unsigned long";
}
inline std::string get_type_name(type_tag<std::string>) {
return "std::string";
}
template <typename T>
std::string get_type_name(type_tag<std::vector<T, std::allocator<T>>>) {
return "std::vector<" + type_name<T>() + ">";
}
template <typename T1, typename T2>
std::string get_type_name(type_tag<std::pair<T1, T2>>) {
return "std::pair<" + type_name<T1>() + ", " + type_name<T2>() + ">";
}
template <typename... T> std::string type_list_to_string() {
std::string result;
auto unused = {(result += type_name<T>() + ", ", 0)..., 0};
static_cast<void>(unused);
#if DBG_MACRO_CXX_STANDARD >= 17
if constexpr (sizeof...(T) > 0) {
#else
if (sizeof...(T) > 0) {
#endif
result.pop_back();
result.pop_back();
}
return result;
}
template <typename... T> std::string get_type_name(type_tag<std::tuple<T...>>) {
return "std::tuple<" + type_list_to_string<T...>() + ">";
}
template <typename T>
inline std::string get_type_name(type_tag<print_formatted<T>>) {
return type_name<T>();
}
// Implementation of 'is_detected' to specialize for container-like types
namespace detail_detector {
struct nonesuch {
nonesuch() = delete;
~nonesuch() = delete;
nonesuch(nonesuch const &) = delete;
void operator=(nonesuch const &) = delete;
};
template <typename...> using void_t = void;
template <class Default, class AlwaysVoid, template <class...> class Op,
class... Args>
struct detector {
using value_t = std::false_type;
using type = Default;
};
template <class Default, template <class...> class Op, class... Args>
struct detector<Default, void_t<Op<Args...>>, Op, Args...> {
using value_t = std::true_type;
using type = Op<Args...>;
};
} // namespace detail_detector
template <template <class...> class Op, class... Args>
using is_detected =
typename detail_detector::detector<detail_detector::nonesuch, void, Op,
Args...>::value_t;
namespace detail {
namespace {
using std::begin;
using std::end;
#if DBG_MACRO_CXX_STANDARD < 17
template <typename T> constexpr auto size(const T &c) -> decltype(c.size()) {
return c.size();
}
template <typename T, std::size_t N>
constexpr std::size_t size(const T (&)[N]) {
return N;
}
#else
using std::size;
#endif
} // namespace
template <typename T>
using detect_begin_t = decltype(detail::begin(std::declval<T>()));
template <typename T>
using detect_end_t = decltype(detail::end(std::declval<T>()));
template <typename T>
using detect_size_t = decltype(detail::size(std::declval<T>()));
template <typename T> struct is_container {
static constexpr bool value =
is_detected<detect_begin_t, T>::value &&
is_detected<detect_end_t, T>::value &&
is_detected<detect_size_t, T>::value &&
!std::is_same<std::string,
typename std::remove_cv<typename std::remove_reference<
T>::type>::type>::value;
};
template <typename T>
using ostream_operator_t =
decltype(std::declval<std::ostream &>() << std::declval<T>());
template <typename T>
struct has_ostream_operator : is_detected<ostream_operator_t, T> {};
} // namespace detail
// Helper to dbg(…)-print types
template <typename T> struct print_type {};
template <typename T> print_type<T> type() { return print_type<T>{}; }
// Forward declarations of "pretty_print"
template <typename T>
inline void pretty_print(std::ostream &stream, const T &value, std::true_type);
template <typename T>
inline void pretty_print(std::ostream &, const T &, std::false_type);
template <typename T>
inline typename std::enable_if<!detail::is_container<const T &>::value &&
!std::is_enum<T>::value,
bool>::type
pretty_print(std::ostream &stream, const T &value);
inline bool pretty_print(std::ostream &stream, const bool &value);
inline bool pretty_print(std::ostream &stream, const char &value);
template <typename P>
inline bool pretty_print(std::ostream &stream, P *const &value);
template <typename T, typename Deleter>
inline bool pretty_print(std::ostream &stream,
std::unique_ptr<T, Deleter> &value);
// template <typename T>
// inline bool pretty_print(std::ostream& stream, std::shared_ptr<T>& value);
template <size_t N>
inline bool pretty_print(std::ostream &stream, const char (&value)[N]);
template <>
inline bool pretty_print(std::ostream &stream, const char *const &value);
template <typename... Ts>
inline bool pretty_print(std::ostream &stream, const std::tuple<Ts...> &value);
template <>
inline bool pretty_print(std::ostream &stream, const std::tuple<> &);
template <> inline bool pretty_print(std::ostream &stream, const time &);
template <typename T>
inline bool pretty_print(std::ostream &stream, const print_formatted<T> &value);
template <typename T>
inline bool pretty_print(std::ostream &stream, const print_type<T> &);
template <typename Enum>
inline typename std::enable_if<std::is_enum<Enum>::value, bool>::type
pretty_print(std::ostream &stream, Enum const &value);
inline bool pretty_print(std::ostream &stream, const std::string &value);
#if DBG_MACRO_CXX_STANDARD >= 17
inline bool pretty_print(std::ostream &stream, const std::string_view &value);
#endif
template <typename T1, typename T2>
inline bool pretty_print(std::ostream &stream, const std::pair<T1, T2> &value);
#if DBG_MACRO_CXX_STANDARD >= 17
template <typename T>
inline bool pretty_print(std::ostream &stream, const std::optional<T> &value);
template <typename... Ts>
inline bool pretty_print(std::ostream &stream,
const std::variant<Ts...> &value);
#endif
template <typename Container>
inline typename std::enable_if<detail::is_container<const Container &>::value,
bool>::type
pretty_print(std::ostream &stream, const Container &value);
// Specializations of "pretty_print"
template <typename T>
inline void pretty_print(std::ostream &stream, const T &value, std::true_type) {
stream << value;
}
template <typename T>
inline void pretty_print(std::ostream &, const T &, std::false_type) {
static_assert(detail::has_ostream_operator<const T &>::value,
"Type does not support the << ostream operator");
}
template <typename T>
inline typename std::enable_if<!detail::is_container<const T &>::value &&
!std::is_enum<T>::value,
bool>::type
pretty_print(std::ostream &stream, const T &value) {
pretty_print(stream, value,
typename detail::has_ostream_operator<const T &>::type{});
return true;
}
inline bool pretty_print(std::ostream &stream, const bool &value) {
stream << std::boolalpha << value;
return true;
}
inline bool pretty_print(std::ostream &stream, const char &value) {
const bool printable = value >= 0x20 && value <= 0x7E;
if (printable) {
stream << "'" << value << "'";
} else {
stream << "'\\x" << std::setw(2) << std::setfill('0') << std::hex
<< std::uppercase << (0xFF & value) << "'";
}
return true;
}
template <typename P>
inline bool pretty_print(std::ostream &stream, P *const &value) {
if (value == nullptr) {
stream << "nullptr";
} else {
stream << value;
}
return true;
}
template <typename T, typename Deleter>
inline bool pretty_print(std::ostream &stream,
std::unique_ptr<T, Deleter> &value) {
pretty_print(stream, value.get());
return true;
}
// template <typename T>
// inline bool pretty_print(std::ostream& stream, std::shared_ptr<T>& value) {
// pretty_print(stream, value.get());
// stream << " (use_count = " << value.use_count() << ")";
// return true;
// }
template <size_t N>
inline bool pretty_print(std::ostream &stream, const char (&value)[N]) {
stream << value;
return false;
}
template <>
inline bool pretty_print(std::ostream &stream, const char *const &value) {
stream << '"' << value << '"';
return true;
}
template <size_t Idx> struct pretty_print_tuple {
template <typename... Ts>
static void print(std::ostream &stream, const std::tuple<Ts...> &tuple) {
pretty_print_tuple<Idx - 1>::print(stream, tuple);
stream << ", ";
pretty_print(stream, std::get<Idx>(tuple));
}
};
template <> struct pretty_print_tuple<0> {
template <typename... Ts>
static void print(std::ostream &stream, const std::tuple<Ts...> &tuple) {
pretty_print(stream, std::get<0>(tuple));
}
};
template <typename... Ts>
inline bool pretty_print(std::ostream &stream, const std::tuple<Ts...> &value) {
stream << "{";
pretty_print_tuple<sizeof...(Ts) - 1>::print(stream, value);
stream << "}";
return true;
}
template <>
inline bool pretty_print(std::ostream &stream, const std::tuple<> &) {
stream << "{}";
return true;
}
template <> inline bool pretty_print(std::ostream &stream, const time &) {
using namespace std::chrono;
const auto now = system_clock::now();
const auto us =
duration_cast<microseconds>(now.time_since_epoch()).count() % 1000000;
const auto hms = system_clock::to_time_t(now);
const std::tm *tm = std::localtime(&hms);
stream << "current time = " << std::put_time(tm, "%H:%M:%S") << '.'
<< std::setw(6) << std::setfill('0') << us;
return false;
}
// Converts decimal integer to binary string
template <typename T> std::string decimalToBinary(T n) {
const size_t length = 8 * sizeof(T);
std::string toRet;
toRet.resize(length);
for (size_t i = 0; i < length; ++i) {
const auto bit_at_index_i = static_cast<char>((n >> i) & 1);
toRet[length - 1 - i] = bit_at_index_i + '0';
}
return toRet;
}
template <typename T>
inline bool pretty_print(std::ostream &stream,
const print_formatted<T> &value) {
if (value.inner < 0) {
stream << "-";
}
stream << value.prefix();
// Print using setbase
if (value.base != 2) {
stream << std::setw(sizeof(T)) << std::setfill('0')
<< std::setbase(value.base) << std::uppercase;
if (value.inner >= 0) {
// The '+' sign makes sure that a uint_8 is printed as a number
stream << +value.inner;
} else {
using unsigned_type = typename std::make_unsigned<T>::type;
stream << +(static_cast<unsigned_type>(-(value.inner + 1)) + 1);
}
} else {
// Print for binary
if (value.inner >= 0) {
stream << decimalToBinary(value.inner);
} else {
using unsigned_type = typename std::make_unsigned<T>::type;
stream << decimalToBinary<unsigned_type>(
static_cast<unsigned_type>(-(value.inner + 1)) + 1);
}
}
return true;
}
template <typename T>
inline bool pretty_print(std::ostream &stream, const print_type<T> &) {
stream << type_name<T>();
stream << " [sizeof: " << sizeof(T) << " byte, ";
stream << "trivial: ";
if (std::is_trivial<T>::value) {
stream << "yes";
} else {
stream << "no";
}
stream << ", standard layout: ";
if (std::is_standard_layout<T>::value) {
stream << "yes";
} else {
stream << "no";
}
stream << "]";
return false;
}
template <typename Enum>
inline typename std::enable_if<std::is_enum<Enum>::value, bool>::type
pretty_print(std::ostream &stream, Enum const &value) {
using UnderlyingType = typename std::underlying_type<Enum>::type;
stream << static_cast<UnderlyingType>(value);
return true;
}
inline bool pretty_print(std::ostream &stream, const std::string &value) {
stream << '"' << value << '"';
return true;
}
#if DBG_MACRO_CXX_STANDARD >= 17
inline bool pretty_print(std::ostream &stream, const std::string_view &value) {
stream << '"' << std::string(value) << '"';
return true;
}
#endif
template <typename T1, typename T2>
inline bool pretty_print(std::ostream &stream, const std::pair<T1, T2> &value) {
stream << "{";
pretty_print(stream, value.first);
stream << ", ";
pretty_print(stream, value.second);
stream << "}";
return true;
}
#if DBG_MACRO_CXX_STANDARD >= 17
template <typename T>
inline bool pretty_print(std::ostream &stream, const std::optional<T> &value) {
if (value) {
stream << '{';
pretty_print(stream, *value);
stream << '}';
} else {
stream << "nullopt";
}
return true;
}
template <typename... Ts>
inline bool pretty_print(std::ostream &stream,
const std::variant<Ts...> &value) {
stream << "{";
std::visit([&stream](auto &&arg) { pretty_print(stream, arg); }, value);
stream << "}";
return true;
}
#endif
template <typename Container>
inline typename std::enable_if<detail::is_container<const Container &>::value,
bool>::type
pretty_print(std::ostream &stream, const Container &value) {
stream << "{";
const size_t size = detail::size(value);
const size_t n = std::min(size_t{10}, size);
size_t i = 0;
using std::begin;
using std::end;
for (auto it = begin(value); it != end(value) && i < n; ++it, ++i) {
pretty_print(stream, *it);
if (i != n - 1) {
stream << ", ";
}
}
if (size > n) {
stream << ", ...";
stream << " size:" << size;
}
stream << "}";
return true;
}
template <typename T, typename... U> struct last {
using type = typename last<U...>::type;
};
template <typename T> struct last<T> { using type = T; };
template <typename... T> using last_t = typename last<T...>::type;
class DebugOutput {
public:
// Helper alias to avoid obscure type `const char* const*` in signature.
using expr_t = const char *;
DebugOutput(const char *filepath, int line, const char *function_name)
: m_use_colorized_output(isColorizedOutputEnabled()) {
std::string path = filepath;
const std::size_t path_length = path.length();
if (path_length > MAX_PATH_LENGTH) {
path = ".." +
path.substr(path_length - MAX_PATH_LENGTH, MAX_PATH_LENGTH);
}
std::stringstream ss;
ss << ansi(ANSI_DEBUG) << "[" << path << ":" << line << " ("
<< function_name << ")] " << ansi(ANSI_RESET);
m_location = ss.str();
}
template <typename... T>
auto print(std::initializer_list<expr_t> exprs,
std::initializer_list<std::string> types, T &&...values)
-> last_t<T...> {
if (exprs.size() != sizeof...(values)) {
std::cerr << m_location << ansi(ANSI_WARN)
<< "The number of arguments mismatch, please check "
"unprotected comma"
<< ansi(ANSI_RESET) << std::endl;
}
return print_impl(exprs.begin(), types.begin(),
std::forward<T>(values)...);
}
private:
template <typename T>
T &&print_impl(const expr_t *expr, const std::string *type, T &&value) {
const T &ref = value;
std::stringstream stream_value;
const bool print_expr_and_type = pretty_print(stream_value, ref);
std::stringstream output;
output << m_location;
if (print_expr_and_type) {
output << ansi(ANSI_EXPRESSION) << *expr << ansi(ANSI_RESET)
<< " = ";
}
output << ansi(ANSI_VALUE) << stream_value.str() << ansi(ANSI_RESET);
if (print_expr_and_type) {
output << " (" << ansi(ANSI_TYPE) << *type << ansi(ANSI_RESET)
<< ")";
}
output << std::endl;
std::cerr << output.str();
return std::forward<T>(value);
}
template <typename T, typename... U>
auto print_impl(const expr_t *exprs, const std::string *types, T &&value,
U &&...rest) -> last_t<T, U...> {
print_impl(exprs, types, std::forward<T>(value));
return print_impl(exprs + 1, types + 1, std::forward<U>(rest)...);
}
const char *ansi(const char *code) const {
if (m_use_colorized_output) {
return code;
} else {
return ANSI_EMPTY;
}
}
const bool m_use_colorized_output;
std::string m_location;
static constexpr std::size_t MAX_PATH_LENGTH = 20;
static constexpr const char *const ANSI_EMPTY = "";
static constexpr const char *const ANSI_DEBUG = "\x1b[02m";
static constexpr const char *const ANSI_WARN = "\x1b[33m";
static constexpr const char *const ANSI_EXPRESSION = "\x1b[36m";
static constexpr const char *const ANSI_VALUE = "\x1b[01m";
static constexpr const char *const ANSI_TYPE = "\x1b[32m";
static constexpr const char *const ANSI_RESET = "\x1b[0m";
};
// Identity function to suppress "-Wunused-value" warnings in DBG_MACRO_DISABLE
// mode
template <typename T> T &&identity(T &&t) { return std::forward<T>(t); }
template <typename T, typename... U>
auto identity(T &&, U &&...u) -> last_t<U...> {
return identity(std::forward<U>(u)...);
}
} // namespace dbg
#ifndef DBG_MACRO_DISABLE
// Force expanding argument with commas for MSVC, ref:
// https://stackoverflow.com/questions/35210637/macro-expansion-argument-with-commas
// Note that "args" should be a tuple with parentheses, such as "(e1, e2, ...)".
#define DBG_IDENTITY(x) x
#define DBG_CALL(fn, args) DBG_IDENTITY(fn args)
#define DBG_CAT_IMPL(_1, _2) _1##_2
#define DBG_CAT(_1, _2) DBG_CAT_IMPL(_1, _2)
#define DBG_16TH_IMPL(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \
_14, _15, _16, ...) \
_16
#define DBG_16TH(args) DBG_CALL(DBG_16TH_IMPL, args)
#define DBG_NARG(...) \
DBG_16TH( \
(__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
// DBG_VARIADIC_CALL(fn, data, e1, e2, ...) => fn_N(data, (e1, e2, ...))
#define DBG_VARIADIC_CALL(fn, data, ...) \
DBG_CAT(fn##_, DBG_NARG(__VA_ARGS__))(data, (__VA_ARGS__))
// (e1, e2, e3, ...) => e1
#define DBG_HEAD_IMPL(_1, ...) _1
#define DBG_HEAD(args) DBG_CALL(DBG_HEAD_IMPL, args)
// (e1, e2, e3, ...) => (e2, e3, ...)
#define DBG_TAIL_IMPL(_1, ...) (__VA_ARGS__)
#define DBG_TAIL(args) DBG_CALL(DBG_TAIL_IMPL, args)
#define DBG_MAP_1(fn, args) DBG_CALL(fn, args)
#define DBG_MAP_2(fn, args) fn(DBG_HEAD(args)), DBG_MAP_1(fn, DBG_TAIL(args))
#define DBG_MAP_3(fn, args) fn(DBG_HEAD(args)), DBG_MAP_2(fn, DBG_TAIL(args))
#define DBG_MAP_4(fn, args) fn(DBG_HEAD(args)), DBG_MAP_3(fn, DBG_TAIL(args))
#define DBG_MAP_5(fn, args) fn(DBG_HEAD(args)), DBG_MAP_4(fn, DBG_TAIL(args))
#define DBG_MAP_6(fn, args) fn(DBG_HEAD(args)), DBG_MAP_5(fn, DBG_TAIL(args))
#define DBG_MAP_7(fn, args) fn(DBG_HEAD(args)), DBG_MAP_6(fn, DBG_TAIL(args))
#define DBG_MAP_8(fn, args) fn(DBG_HEAD(args)), DBG_MAP_7(fn, DBG_TAIL(args))
#define DBG_MAP_9(fn, args) fn(DBG_HEAD(args)), DBG_MAP_8(fn, DBG_TAIL(args))
#define DBG_MAP_10(fn, args) fn(DBG_HEAD(args)), DBG_MAP_9(fn, DBG_TAIL(args))
#define DBG_MAP_11(fn, args) fn(DBG_HEAD(args)), DBG_MAP_10(fn, DBG_TAIL(args))
#define DBG_MAP_12(fn, args) fn(DBG_HEAD(args)), DBG_MAP_11(fn, DBG_TAIL(args))
#define DBG_MAP_13(fn, args) fn(DBG_HEAD(args)), DBG_MAP_12(fn, DBG_TAIL(args))
#define DBG_MAP_14(fn, args) fn(DBG_HEAD(args)), DBG_MAP_13(fn, DBG_TAIL(args))
#define DBG_MAP_15(fn, args) fn(DBG_HEAD(args)), DBG_MAP_14(fn, DBG_TAIL(args))
#define DBG_MAP_16(fn, args) fn(DBG_HEAD(args)), DBG_MAP_15(fn, DBG_TAIL(args))
// DBG_MAP(fn, e1, e2, e3, ...) => fn(e1), fn(e2), fn(e3), ...
#define DBG_MAP(fn, ...) DBG_VARIADIC_CALL(DBG_MAP, fn, __VA_ARGS__)
#define DBG_STRINGIFY_IMPL(x) #x
#define DBG_STRINGIFY(x) DBG_STRINGIFY_IMPL(x)
#define DBG_TYPE_NAME(x) dbg::type_name<decltype(x)>()
#define dbg(...) \
dbg::DebugOutput(__FILE__, __LINE__, __func__) \
.print({DBG_MAP(DBG_STRINGIFY, __VA_ARGS__)}, \
{DBG_MAP(DBG_TYPE_NAME, __VA_ARGS__)}, __VA_ARGS__)
#else
#define dbg(...) dbg::identity(__VA_ARGS__)
#endif // DBG_MACRO_DISABLE
#endif // DBG_MACRO_DBG_H

156
include/nnet/derivator.h Normal file
View File

@ -0,0 +1,156 @@
#pragma once
#include "common.h"
#include "expr.h"
#include "iterator_table.h"
#include "routine.h"
#include <iostream>
#include <sstream>
#include <unordered_set>
namespace nnet {
class Formula {
public:
Expr root;
const int bfsDepth;
public:
Formula(Expr _root, int _bfsDepth) : root(_root), bfsDepth(_bfsDepth) {}
string toReadable() const;
friend std::ostream &operator<<(std::ostream &ios, const Formula &expr) {
ios << expr.toReadable();
return ios;
}
bool isVariable() const { return as<VarNode>(root) != nullptr; }
};
class MultiFormulas {
public:
VecExpr roots;
const int bfsDepth;
public:
MultiFormulas(VecExpr roots, int _bfsDepth)
: roots(roots), bfsDepth(_bfsDepth) {}
// string toReadable() const;
// friend std::ostream &operator<<(std::ostream &ios, const Formula &expr) {
// ios << expr.toReadable();
// return ios;
// }
};
class Derivator {
public:
enum class LogMode { Normal, DumpFristCandiate, NoLog };
enum class PassMode { Debug, Full };
private:
list<Formula> candidates;
const int maxDepth;
int nIteratorNames = 0;
int nTensorNames = 0;
vector<vector<int>> rulesOverall;
enum class Strategy { DFS, Rule, RuleAndDFS } searchStrategy;
LogMode logMode;
PassMode passMode;
bool enableEquivalenceCheck = false;
string logFnPrefix;
const bool enableHashPruning;
int searchedMaxDepth = 0;
RoutineType targetOp = RoutineType::NoneType;
map<int, vector<Var>> substituteRules;
vector<int> cntAppliedRules;
int cntRule3 = 0;
std::unordered_set<HashType> visited;
VecExpr intermediateStates;
vector<string> ruleStates, ruleMsgs;
int cntStates = 0; // the number of intermediate states
int searchState = 0; // search state in guided search
public:
Derivator(int maxDepth = 8, bool enableHashPruning = true,
LogMode mode = LogMode::NoLog,
PassMode passMode = PassMode::Debug);
void search(Formula &origin, int depth);
void ruleBasedDFS(Formula &origin, int depth, vector<int> _rules,
map<int, vector<Var>> _substituteRules = {},
bool searchAfterRules = false);
void guidedSearch(Formula &origin, int depth);
void print();
int getNumCandidates() const { return candidates.size(); }
const auto &getCandidates() const { return candidates; }
void appendCanddiate(const Tensor &tensor, int depth);
int getSearchedMaxDepth() const { return searchedMaxDepth; };
bool stageCombination(MultiFormulas &origin, int depth);
bool checkOOB(const RangeOp &rangeOp, bool halt = true);
string newTensorName();
Var getNewVar();
Expr mergeMemboundStages(VecExpr stages);
private:
void dfs(Formula &origin, int depth);
void ruleBasedDerivate(Formula &origin, int depth);
void rule1VariableSplit(Formula &origin, int depth, Expr &rCur);
void rule2VariableMerging(Formula &origin, int depth, Expr &rCur);
void rule3StageSplit(Formula &origin, int dfsDepth, Expr &rCur);
void rule5RangeRelaxation(Formula &origin, int depth, Expr &rCur);
bool rule4StageMerging(Formula &origin, int depth, Expr &rCur,
bool mergeStageWithCalc = false);
void rule6KenerlMatching(Formula &origin, int depth, Expr &rCur);
void rule7DLT(Formula &origin, int depth, Expr &rCur);
// Rule 8: guidedDLT
void rule8GuidedDLT(Formula &origin, int depth, Expr &rCur);
void rule9RangeMagnify(Formula &origin, int depth, Expr &rCur);
void rule90TwoStageElementWise(Formula &origin, int depth, Expr &rCur);
void rule91MergeStagesWithSum(Formula &origin, int depth, Expr &rCur);
/**
* @brief For searchState=2, wrap the RangeOp to add offset, if the boundary
* does not start from 0. Then match the inner offset RangeOp.
*/
void matchComputationKernel(Formula &origin, int depth, Expr &rcur);
/**
* @brief For searchState=3, the Formula must be a MemBound kernel?
*/
void matchMemBoundKernel(Formula &origin, int depth, Expr &rcur);
/**
* @brief Check the equivalence for exprs in intermediateStates.
*/
void checkDerivationEquivalence();
public:
void pushIntermediateState(const Expr &expr);
void pushRuleState(const string &state);
void pushRuleMsg(const string &state);
void popIntermediateState();
void popRuleState();
void popRuleMsg();
// void pushTransformInfo(const Expr &expr, const string &state,
// const string &msg);
void nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur);
RoutineType getTargetOp();
void setTargetOp(RoutineType _targetOp);
int getSearchState();
void setSearchState(int _searchState);
int getNumIntermediateStates();
void printStatistics();
void printIntermediateStates();
/**
* @brief Enable dumping the first results. Verification is enabled so it is
* slow.
*
* @param _logFnPrefix Prefix of output filename
*/
void setDumpFirstSuccess(const string &_logFnPrefix);
void setEquivalenceCheck();
PassMode getPassMode();
LogMode getLogMode();
};
} // namespace nnet

53
include/nnet/dlt.h Normal file
View File

@ -0,0 +1,53 @@
#pragma once
#include "common.h"
#include "expr.h"
#include <iostream>
namespace nnet {
// enum class DLTType { Split, Merge, Reorder };
struct DLTOperation {
// DLTType type;
virtual ~DLTOperation() {}
};
struct DLTSplit : DLTOperation {
int dim, factor;
DLTSplit(int _dim, int _factor) : dim(_dim), factor(_factor) {}
};
struct DLTMerge : DLTOperation {
int dim0, dim1;
DLTMerge(int _dim0, int _dim1) : dim0(_dim0), dim1(_dim1) {}
};
struct DLTReorder : DLTOperation {
vector<int> dims;
DLTReorder(vector<int> _dims) : dims(_dims) {}
};
class DLT {
vector<Ref<DLTOperation>> ops;
public:
/**
* @brief dim -> (dim/factor, factor)
*/
void split(int dim, int factor);
/**
* @brief Merge dim1 into dim0 -> (dim0, dim1)
*/
void merge(int dim0, int dim1);
/**
* @brief
*
* @param dims dims[new_dim]=old_dim
*/
void reorder(vector<int> dims);
optional<Expr> apply(const RangeOp &rangeOp, const Subscript &subscript,
string newTensorName);
private:
optional<pair<Expr, Expr>> splitIndex(Expr expr, int factor,
RangeOp rangeOp);
};
} // namespace nnet

416
include/nnet/expr.h Normal file
View File

@ -0,0 +1,416 @@
#pragma once
#include "common.h"
#include "ref.h"
#include <iostream>
#include <numeric>
#include <type_traits>
namespace nnet {
class ExprNode;
class VarNode;
class TensorNode;
class OperatorNode;
class RangeOpNode;
class SubscriptNode;
class BinaryOpNode;
class ConstantNode;
class FuncNode;
using Expr = Ref<ExprNode>;
using Var = Ref<VarNode>;
using Tensor = Ref<TensorNode>;
using Operator = Ref<OperatorNode>;
using RangeOp = Ref<RangeOpNode>;
using Subscript = Ref<SubscriptNode>;
using BinaryOp = Ref<BinaryOpNode>;
using Constant = Ref<ConstantNode>;
using Func = Ref<FuncNode>;
class RoutineNode;
using Routine = Ref<RoutineNode>;
enum class RoutineType {
NoneType = 100,
MatmulNodeType,
ConvNodeType,
G2bmmNodeType,
GbmmNodeType,
ElementWiseNodeType // unmatchable
};
constexpr inline int MatchableRoutineTypeCnt = 4;
constexpr inline int RoutineTypeCnt = MatchableRoutineTypeCnt + 1;
inline RoutineType idToRoutineType(int i) {
return static_cast<RoutineType>(i + 1 +
static_cast<int>(RoutineType::NoneType));
}
inline int routineTypeToId(const RoutineType &routineType) {
return static_cast<int>(routineType) -
static_cast<int>(RoutineType::NoneType) - 1;
}
using VecExpr = vector<Expr>;
// common data structure
using Iterator = Var; // RE: remove this alias
template <typename T, typename U> using PtrMap = std::map<T, U, ptr_less<T>>;
template <typename T, typename U>
// When keys are pointers, compare keys according to its value instead of
// address Specially, the name of Var are compared due to the overload of op=
// and hash.
using PtrUmap = std::unordered_map<T, U, ptr_hash<T>, ptr_equal<T>>;
template <typename T>
using PtrUset = std::unordered_set<T, ptr_hash<T>, ptr_equal<T>>;
using Appearance = PtrMap<Var, vector<pair<Tensor, int>>>;
using StrideTable =
PtrMap<Var, vector<tuple<TensorNode *, int, int>>>; // Tensor, dim, stride
// AST node opeartor
bool operator==(const Var &lhs, const string &rhs);
bool operator==(const string &lhs, const Var &rhs);
Expr operator+(const Expr &lhs, const Expr &rhs);
BinaryOp operator-(const Expr &lhs, const Expr &rhs);
BinaryOp operator*(const Expr &lhs, const Expr &rhs);
BinaryOp operator/(const Expr &lhs, const Expr &rhs);
BinaryOp operator%(const Expr &lhs, const Expr &rhs);
Expr operator+(const Expr &lhs, const int &rhs);
Expr operator+(const int &lhs, const Expr &rhs);
Expr operator-(const Expr &lhs, const int &rhs);
Expr operator-(const int &lhs, const Expr &rhs);
Expr operator*(const Expr &lhs, const int &rhs);
Expr operator*(const int &lhs, const Expr &rhs);
Expr operator%(const Expr &lhs, const int rhs);
Expr operator/(const Expr &lhs, const int rhs);
string serializeVec(vector<Expr> v);
string serializeVec(vector<Var> v);
template <typename T> inline string serializeVec(vector<T> v) {
if (v.empty())
return "[]";
return "[" +
std::accumulate(
v.begin() + 1, v.end(), to_string(v[0]),
[](const string &a, int b) { return a + ',' + to_string(b); }) +
"]";
}
// For RTTI and visitor pattern
enum class NodeType {
ConstantNodeType,
BinaryOpNodeType,
RangeOpNodeType,
SubscriptNodeType,
TensorNodeType,
VarNodeType,
FuncNodeType
};
enum class FuncType { Relu, Tanh };
#define DEFINE_GETTYPE(CLASS) \
NodeType getType() const override { return NodeType::CLASS##Type; }
class ExprNode {
public:
virtual ~ExprNode() {}
ExprNode &operator=(const ExprNode &rhs) = delete;
virtual HashType hash() const = 0; // RE: remove?
virtual string toReadable() const = 0;
friend std::ostream &operator<<(std::ostream &ios, const ExprNode &expr);
virtual NodeType getType() const = 0;
};
class VarNode : public ExprNode {
std::string name;
public:
VarNode(std::string _name) : name(_name){};
virtual ~VarNode() {}
DEFINE_GETTYPE(VarNode);
const std::string &getName() const { return name; }
HashType hash() const override { return genhash(name); };
string toReadable() const override { return name; };
bool equal(const Var &rhs) const { return name == rhs->getName(); }
bool neq(const Var &rhs) const { return !equal(rhs); }
bool less(const Var &rhs) const { return name < rhs->getName(); }
bool equal(const string &rhs) const { return name == rhs; }
bool operator==(const VarNode &rhs) const { return name == rhs.getName(); }
bool operator<(const VarNode &rhs) const { return name < rhs.getName(); }
};
enum class TensorType { Input, Weight, Intermediate };
class TensorNode : public ExprNode {
string name;
vector<int> shape, paddings;
TensorType type;
Routine source; // if NO source, then this is a input/weight tensor
public:
TensorNode(string _name, vector<int> _shape, vector<int> _paddings = {},
Routine _source = nullptr);
virtual ~TensorNode() {}
DEFINE_GETTYPE(TensorNode);
bool operator==(const string &rhs) { return name == rhs; }
friend bool operator==(const string &lhs, const TensorNode &rhs) {
return lhs == rhs.name;
}
HashType hash() const override { return genhash(name); }
string toReadable() const override;
string toOutputShape() const;
const std::string &getName() const { return name; }
std::vector<int> &getPadding() { return paddings; }
int getPadding(int i) const { return paddings[i]; }
const vector<int> &getPaddings() const { return paddings; }
void setPadding(int i, int p) { paddings[i] = p; }
const vector<int> &getShape() const { return shape; }
int getShape(int i) const { return shape[i]; }
int64_t getSize() const;
int getDims() const { return shape.size(); }
const Routine &getSource() const { return source; }
int getData(const Ref<vector<int>> &data, const vector<int> &idx);
size_t getOffset(const vector<int> &idx);
};
enum class OpType { Range, Add, Mul, Div, Mod, Sub };
const char opSymbols[] = "#+*/%-";
class OperatorNode : public ExprNode {
protected:
const OpType opType;
VecExpr subExprs;
public:
OperatorNode(OpType _opType) : opType(_opType){};
OperatorNode(OpType _opType, VecExpr _subExprs)
: opType(_opType), subExprs(_subExprs){};
int getSubExprsNum() { return subExprs.size(); };
const VecExpr &getSubExprs() { return subExprs; }
const Expr &getSubExprs(int i) const { return subExprs[i]; }
OpType getOpType() const { return opType; };
void setOperands(int i, Expr e) { subExprs[i] = e; }
};
using Range = pair<int, int>;
using VarRangePair = pair<Var, Range>;
inline int getLength(const Range &range) { return range.second - range.first; }
struct IterationType {
enum { Loop, Sum };
constexpr static int NumIterationType = 2;
};
class RangeOpNode : public OperatorNode {
public:
enum { Summand, END_POS };
constexpr static int Loop = IterationType::Loop;
constexpr static int Sum = IterationType::Sum;
private:
vector<VarRangePair> vars[IterationType::NumIterationType];
vector<int> paddings;
public:
RangeOpNode(Expr _summand) : OperatorNode(OpType::Range, {_summand}){};
RangeOpNode(const vector<VarRangePair> &_loopIters,
const vector<VarRangePair> &_sumIters, Expr _summand,
const vector<int> &paddings)
: OperatorNode(OpType::Range, {_summand}), vars{_loopIters, _sumIters},
paddings(paddings){};
DEFINE_GETTYPE(RangeOpNode);
virtual HashType hash() const override {
nnet_unimplemented_halt();
return 0;
};
string toReadable() const override;
const Expr &getSummand() const { return subExprs[Summand]; }
const vector<VarRangePair> &getVarRanges(int _index) const {
return vars[_index];
}
const vector<VarRangePair> &getLoopVarRanges() const {
return vars[IterationType::Loop];
}
const vector<VarRangePair> &getSumVarRanges() const {
return vars[IterationType::Sum];
}
int getNumOutputDims() const;
bool hasVar(int index, Var name) const;
bool hasLoopVar(Var name) const { return hasVar(Loop, name); }
bool hasSumVar(Var name) const { return hasVar(Sum, name); }
bool hasLoopVar(string name) const {
return hasVar(Loop, make_ref<VarNode>(name));
}
bool hasSumVar(string name) const {
return hasVar(Sum, make_ref<VarNode>(name));
}
int getVarIndex(int type, string name);
void setSummand(Expr e) { subExprs[Summand] = e; }
void setLoopIterator(const vector<VarRangePair> &vecExpr) {
vars[Loop] = vecExpr;
}
void setSumIterator(const vector<VarRangePair> &vecExpr) {
vars[Sum] = vecExpr;
}
void setIterator(const vector<VarRangePair> &loop,
const vector<VarRangePair> &sum) {
setLoopIterator(loop);
setSumIterator(sum);
}
const VarRangePair &getVarRange(int _index, int i) const {
return vars[_index][i];
}
const Var &getLoopVar(int i) const { return vars[Loop][i].first; }
Range getRange(const Var &var) const;
VarRangePair getVarRange(const Var &var) const;
bool hasPaddings() const;
int getPaddings(int dim) const;
vector<int> getPaddings() const;
void setPaddings(vector<int> _paddings);
void setVarRange(int _index, int i, VarRangePair pair) {
vars[_index][i] = pair;
}
int64_t getFlops() const;
int64_t getInputSize(const RangeOp &self) const;
int64_t getOutputSize() const;
vector<int> getOutputShape() const;
// Including paddings
vector<Range> getOutputRanges() const;
};
class BinaryOpNode : public OperatorNode {
enum { LHS, RHS, END_POS };
public:
BinaryOpNode(OpType _opType, Expr _lhs, Expr _rhs)
: OperatorNode(_opType, {_lhs, _rhs}){};
virtual ~BinaryOpNode() {}
DEFINE_GETTYPE(BinaryOpNode);
virtual HashType hash() const override {
return genhash((HashType)opType,
genhash(subExprs[LHS]->hash(), subExprs[RHS]->hash()));
};
virtual string toReadable() const override;
const Expr &getLhs() const { return getSubExprs(LHS); };
const Expr &getRhs() const { return getSubExprs(RHS); };
void setLhs(Expr e) { setOperands(LHS, e); };
void setRhs(Expr e) { setOperands(RHS, e); };
// If Var/constant, use this one
optional<pair<Var, int>> getModDivParameter() const;
// If (Var+constant)/constant, use this one
pair<Expr, int> getModDivExpr() const;
bool isSwapable() const;
};
class ConstantNode : public ExprNode {
int val;
public:
ConstantNode(int _val) : val(_val){};
ConstantNode(const ConstantNode &rhs) : ExprNode(rhs), val(rhs.val){};
virtual ~ConstantNode() {}
DEFINE_GETTYPE(ConstantNode);
int getValue() const { return val; }
virtual HashType hash() const override { return genhash(val, 6214587); };
virtual string toReadable() const override {
string ret;
ret += std::to_string(val);
return ret;
};
};
class SubscriptNode : public ExprNode {
protected:
Expr indexed;
VecExpr subExprs;
public:
SubscriptNode(Expr _indexed, vector<Expr> _subExprs) : subExprs(_subExprs) {
setObject(_indexed);
};
DEFINE_GETTYPE(SubscriptNode);
virtual HashType hash() const override {
nnet_unimplemented_continue();
return -1;
};
virtual string toReadable() const override;
size_t getDims() const { return subExprs.size(); }
const VecExpr &getIndex() const { return subExprs; }
const Expr &getIndex(size_t i) const { return subExprs[i]; }
void setIndex(size_t i, Expr e) { subExprs[i] = e; }
Expr *getObjectPtr() { return &indexed; }
Expr getObject() const { return indexed; }
void setObject(Expr e);
bool isRangeOpSubscripted() const;
bool isTensorSubscripted() const { return !isRangeOpSubscripted(); }
// Get the ranges of objects including paddings
vector<Range> getObjectRangesWithPaddings() const;
vector<Range> getObjectRangesWithoutPaddings() const;
};
class FuncNode : public ExprNode {
protected:
Subscript object;
FuncType funcType;
public:
FuncNode(Expr object, FuncType funcType) : funcType(funcType) {
setObject(object);
}
DEFINE_GETTYPE(FuncNode);
virtual HashType hash() const override {
nnet_unimplemented_continue();
return -1;
};
virtual string toReadable() const override;
const Subscript &getObject() const { return object; }
void setObject(Expr e);
FuncType getFuncType() const { return funcType; }
};
// Wrappers for type deduction
Subscript makeSubscript(const Expr &tensor, const VecExpr &subscripts);
RangeOp makeRangeOperator(const vector<VarRangePair> &_loopIters,
const vector<VarRangePair> &_sumIters, Expr _summand,
const vector<int> &paddings = {});
Tensor makeTensor(const string &name, const vector<int> &shape,
const vector<int> &paddings = {},
const Routine &source = nullptr);
// Pretty output for dbg with shared_ptr
template <typename T, typename std::enable_if_t<std::is_base_of_v<ExprNode, T>>
*_ = nullptr>
std::ostream &operator<<(std::ostream &os, const shared_ptr<T> &a) {
os << ((!a) ? string("nullptr") : a->toReadable());
return os;
}
// Pretty output for dbg with shared_ptr
template <typename T, typename std::enable_if_t<std::is_base_of_v<ExprNode, T>>
*_ = nullptr>
std::ostream &operator<<(std::ostream &os, const Ref<T> &a) {
os << ((!a) ? string("nullptr") : a->toReadable());
return os;
}
#undef DEFINE_GETTYPE
} // namespace nnet
namespace std {
template <> struct hash<nnet::VarNode &> {
size_t operator()(const nnet::VarNode &t) const {
return std::hash<string>()(t.getName());
}
};
} // namespace std

View File

@ -0,0 +1,234 @@
#pragma once
#include "common.h"
#include "expr.h"
#include <iostream>
namespace nnet {
using PatternTensorMap = vector<Tensor>;
using PatternIterRangeMap = PtrMap<Iterator, VarRangePair>;
enum class MismatchType {
// Search required (undetermined)
MoreVar,
LessVar,
StrideMismatch,
// guided DLT (determined)
DLMismatch,
OutputDLMismatch,
OutputDimismatch
};
struct Mismatch {
MismatchType type;
int bitmap; // Row ID of IT
PtrMap<Iterator, Iterator>
mappingIter_r; // For DLT mismatch, iters are mapped
Mismatch(MismatchType _type, int _bitmap) : type(_type), bitmap(_bitmap) {}
Mismatch(MismatchType _type, int _bitmap,
PtrMap<Iterator, Iterator> _mappingIter_r)
: type(_type), bitmap(_bitmap), mappingIter_r(_mappingIter_r) {}
};
class Pattern;
class IteratorTable {
protected:
// using Appearance = map<string, vector<pair<Tensor, int>>>;
// using StrideTable = map<TensorNode *, vector<tuple<string, int, int>>>;
// // Var, dim, stride
RangeOp rangeOp;
// To real tensor
// FIXME: redundent
Appearance appearance;
vector<Tensor> tensors; // original tensor sequence
vector<Subscript> subscripts; // original subscripts sequence
StrideTable strideTable; // TODO [Refactor]: rename strideTable
PatternIterRangeMap iterToRange;
// mapping
vector<int> tensorMap; // [index for tensors] -> tensorID in pattern
PtrMap<Iterator, Iterator> iterMap; // [expr iter] -> pattern iter
// final data
vector<vector<Iterator>> posTable; // [Tensor bitmap]=[Iterator]
vector<vector<vector<Iterator>>>
iterInTensorDim; // [tensorID][dimOfTensor]=[Iterator],
// stride in each dim may be add
vector<vector<PtrMap<Iterator, int>>>
strideInDim; // [tensorID][dimOfTensor][Iterator]=stride,
// stride in each dim may be add
PtrMap<Iterator, vector<int>> strideInTensor; // [Iterator][tensorID]=stride
// final data: auxiliary data
vector<int> tensorIDMap_r;
PatternTensorMap tensorMap_r;
PatternIterRangeMap iterToRange_r;
public:
virtual ~IteratorTable() {}
IteratorTable() {}
IteratorTable(const IteratorTable &) = delete;
[[nodiscard]] bool analyzeExpr(const RangeOp &rangeOp);
// mapTensors
void buildTable(const vector<int> &_tensorMap);
void buildTableWithDefaultMap();
/**
* @brief Check whether the expression match a pattern. If not, return the
* detailed reason for guided search.
*
* @param patternIT
* @return vector<int> mismatched IT rows/tensors for guided DLT.
*/
vector<Mismatch> matchPatternIT(const Pattern &patternIT);
void matchIterators();
int getNumInputs() const { return tensors.size(); }
int getNumTensors() const { return tensors.size() + 1; }
int getNumRows() const { return 1 << getNumTensors(); }
int getNumIterators() const { return strideTable.size(); }
// vector<Tensor> tensorMap_r(
// pattern.nInputs); // [pattern tensor ID] -> real tensor
// map<string, VarRangePair> iterToRange_r; // [pattern iter] -> iter &
// range
auto getTables() const {
return tuple(posTable, iterInTensorDim, strideInTensor);
}
const auto &getStrideInDim() const { return strideInDim; }
vector<vector<Iterator>> getIterInTensorDim(int tensorID) const {
return iterInTensorDim[tensorID];
}
const vector<Iterator> &getPosTable(int bitmap) const {
return posTable[bitmap];
}
pair<PatternTensorMap, PatternIterRangeMap> getReverseMap() const;
int getStridesInTensor(Iterator iter, int tensorID) const;
vector<int> getIterDimInTensor(int tensorID, const Iterator &iter) const;
Tensor getTensor(int tensorID) const { return tensorMap_r[tensorID]; }
Subscript getSubscript(int tensorID) const {
return subscripts[tensorIDMap_r[tensorID]];
}
Range getIterRange(const Iterator &iter) const {
return rangeOp->getRange(iter);
}
/**
* @brief Check strides of each iterators and there position in tensors.
* Since many-to-many iterators matching exist, we take this procudure as a
* seperate function to deal with different iterator mapping solution.
*
* @param patternIT
* @param mappingIter_r
* @return vector<Mismatch>
*/
vector<Mismatch>
matchPatternITCheckStrides(const Pattern &patternIT,
PtrMap<Iterator, Iterator> mappingIter_r);
RangeOp getRangeOp() const;
};
struct StrideConstraint {
int tensorID;
Var v0, v1;
enum class Constraint { SAME, PROPOTIONAL } type;
};
class Pattern : public IteratorTable {
vector<StrideConstraint> strideConstraints;
public:
virtual Expr
buildExpr(const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges,
string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const = 0;
/**
* @brief Check whether all indexes only are a iterator
*
* @param tensorID
*/
bool isAllUniqueAccess(int tensorID) const;
const auto &getStrideConstraints() const { return strideConstraints; };
int calcPadding(const Tensor &tensor, int dim, Range rangeH, Range rangeR,
int offset) const;
};
class MatmulPattern : public Pattern {
public:
static const Pattern &getMatmulPattern();
static pair<Expr, pair<Tensor, Tensor>> getExpr(bool transA, bool transB,
int b, int m, int n, int k);
Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges,
string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const override;
};
class ConvPattern : public Pattern {
private:
static const Var n, c, h, w, f, r, s;
public:
static const Pattern &getPattern();
static Expr getExpr(Tensor A, Tensor K, int n, int c, int h, int w, int f,
int r, int s);
Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges,
string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const override;
};
class ConvTransPattern : public Pattern {
private:
static const Var n, c, h, w, f, r, s;
public:
static const Pattern &getPattern() = delete;
static Expr getExpr(Tensor A, Tensor K, int N, int C, int H, int W, int F,
int R, int S);
Expr
buildExpr(const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges,
string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const override {
nnet_unimplemented_halt();
return nullptr;
};
};
class Sg2bmmPattern : public Pattern {
private:
static const Var b, m, w, k;
public:
static const Pattern &getPattern();
static pair<Expr, pair<Tensor, Tensor>> getExpr(int Batch, int M, int K,
int W, int D);
Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges,
string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const override;
};
class LongformerGBMMPattern : public Pattern {
private:
static const Var b, m, w, n;
public:
static const Pattern &getPattern();
static pair<Expr, pair<Tensor, Tensor>> getExpr(int Batch, int M, int W,
int K, int dilation);
Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges,
string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const override;
};
const Pattern &getPattern(RoutineType targetOp);
string getPatternName(RoutineType targetOp);
} // namespace nnet

57
include/nnet/nmutator.h Normal file
View File

@ -0,0 +1,57 @@
#pragma once
#include "core/mutator.h"
#include "nnet/expr.h"
#ifdef ABC
namespace infini {
class NMutator : public Mutator {
private:
// Suffix -N: NNet objects.
// Suffix -T: tpm objects.
// Map: NNet tensors -> tpm tensor.
std::map<std::string, Tensor> inputsNameNToTensorT;
enum class Mode { Normal, ToNaiveMembound, RuleBased } mode = Mode::Normal;
const double bandwidth = double(200) * 1024 * 1024 * 1024;
// If in RuleBased mode, use derivationRules in derivator
const std::vector<int> derivationRules;
public:
NMutator();
NMutator(const std::vector<int> &derivationRules);
~NMutator();
vector<Graph> run(const Graph &in_graph) override;
void setToNaiveMembound();
void setMaxDepth(int _maxDepth) { maxDepth = _maxDepth; }
long long cntStates = 0;
long long cntCandidates = 0;
private:
int maxDepth = 8;
nnet::Expr opToExpression(Operator op);
void runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs);
/**
* @brief Test helper. Converting a single OP to Membound Op for
* corretness check.
*/
void runSingleOpToNaiveMembound(Graph in_graph,
std::vector<Graph> &out_graphs);
void runMultipleOps(Graph in_graph, std::vector<Graph> &out_graphs);
Graph expressionToGraph(nnet::Expr expr, Graph in_graph);
Graph fuseHetConv(nnet::Expr expr, Graph in_graph);
double memboundTime(ssize_t cnt);
double memboundTime(const Shape &dims);
Graph transformTConv1x1(Operator op);
Graph transformTConv3x3(Operator op);
Graph transformDialtedConv(Operator op);
Graph transformConv1x1(Operator op);
Graph transformConv1xk(Operator op);
};
} // namespace infini
#endif

View File

@ -0,0 +1,38 @@
#pragma once
#include "common.h"
#include "expr.h"
#include <iostream>
namespace nnet {
class PermutationGenerator {
vector<vector<Iterator>> from, to;
vector<vector<size_t>> mapping;
public:
PermutationGenerator(vector<vector<Iterator>> _from,
vector<vector<Iterator>> _to);
bool next();
PtrMap<Iterator, Iterator> get() const;
};
template <typename T> class SubsetGenerator {
vector<T> elements;
int n, bitmap;
public:
SubsetGenerator(vector<T> elements, bool nonEmpty = 1)
: elements(elements), n(elements.size()), bitmap((nonEmpty > 0)) {
assert(n < 10);
};
bool next() { return ((++bitmap) < (1 << n) - 1); }
vector<T> get() const {
vector<T> ret;
for (int i = 0; i < n; ++i)
if (bitmap & (1 << i))
ret.emplace_back(elements[i]);
return ret;
}
};
} // namespace nnet

200
include/nnet/ref.h Normal file
View File

@ -0,0 +1,200 @@
#pragma once
#include "common.h"
#include <functional> // hash
#include <memory>
#include <type_traits>
namespace nnet {
template <typename T> struct is_ref;
/**
* Ref-counting pointer
*
* This class is thread-safe (For developers: concurrent accesses through
* different `std::shared_ptr`s to the same object is already thread-safe, while
* modifying the same `std::shared_ptr` is not. We never modify a `Ref`, so no
* locks are needed. See https://en.cppreference.com/w/cpp/memory/shared_ptr)
*/
template <class T> class Ref {
static_assert(is_ref<T>::value == false, "Ref should not be nested");
template <class U> friend class Ref;
std::shared_ptr<T> ptr_;
private:
public:
typedef T Object;
Ref() = default;
// Ref(std::nullptr_t) : Ref() {}
constexpr Ref(nullptr_t) noexcept : Ref() {}
Ref(const Ref &) = default;
Ref(Ref &&) = default;
Ref(std::shared_ptr<T> &&ptr) : ptr_(std::move(ptr)) {}
// Ref(const std::shared_ptr<T> &ptr) : ptr_(ptr) {}
// /// NO NOT USE THIS CONSTRUCTOR IN PUBLIC
// /// It is public because Pybind11 needs it
// Ref(T *ptr) : ptr_(ptr) {}
/**
* Shared with any compatible references
*/
template <class U,
typename std::enable_if_t<std::is_base_of_v<T, U>> * = nullptr>
Ref(const Ref<U> &other) : ptr_(std::static_pointer_cast<T>(other.ptr_)) {}
template <class U,
typename std::enable_if_t<std::is_base_of_v<T, U>> * = nullptr>
Ref &operator=(const Ref<U> &other) {
ptr_ = std::static_pointer_cast<T>(other.ptr_);
return *this;
}
Ref &operator=(const Ref &) = default;
Ref &operator=(Ref &&) = default;
template <class U> Ref<U> as() const {
Ref<U> ret;
ret.ptr_ = std::dynamic_pointer_cast<U>(ptr_);
return ret;
}
bool isValid() const { return ptr_ != nullptr; }
T &operator*() const {
nnet_assert(isValid(), "Empty pointer.");
return *ptr_;
}
T *operator->() const {
nnet_assert(isValid(), "Empty pointer.");
return ptr_.get();
}
T *get() const {
nnet_assert(isValid(), "Empty pointer.");
return ptr_.get();
}
friend inline bool operator==(const Ref &lhs, nullptr_t) {
return !lhs.isValid();
}
friend inline bool operator!=(const Ref &lhs, nullptr_t) {
return !(lhs == nullptr);
}
explicit operator bool() const { return ptr_ != nullptr; }
bool operator!() { return ptr_ == nullptr; }
void swap(Ref &__b) noexcept { ptr_.swap(__b.ptr_); }
};
template <class T, class U,
typename std::enable_if_t<std::is_base_of_v<U, T>> * = nullptr>
Ref<T> as(const Ref<U> &ref) {
return ref.template as<T>();
}
template <typename T, typename... Params> Ref<T> make_ref(Params &&...params) {
return Ref(make_shared<T>(std::forward<Params>(params)...));
}
// Comparator for Ref
template <typename T> struct is_ref : std::false_type {};
template <typename T> struct is_ref<Ref<T>> : std::true_type {};
template <class Tuple, std::size_t index = 0, bool address_based>
typename std::enable_if_t<not is_ref<std::tuple_element_t<index, Tuple>>::value,
bool>
__ref_less(const Tuple &lhs, const Tuple &rhs) {
if constexpr (index >=
std::tuple_size<std::remove_reference_t<Tuple>>::value - 1)
return std::get<index>(lhs) < std::get<index>(rhs);
else {
if (std::get<index>(lhs) != std::get<index>(rhs))
return std::get<index>(lhs) < std::get<index>(rhs);
else
return __ref_less<Tuple, index + 1, address_based>(lhs, rhs);
}
}
template <class Tuple, std::size_t index = 0, bool address_based>
typename std::enable_if_t<is_ref<std::tuple_element_t<index, Tuple>>::value and
not address_based,
bool>
__ref_less(const Tuple &lhs, const Tuple &rhs) {
if constexpr (index >=
std::tuple_size<std::remove_reference_t<Tuple>>::value - 1)
return std::get<index>(lhs)->less(std::get<index>(rhs));
else {
if (std::get<index>(lhs)->neq(std::get<index>(rhs)))
return std::get<index>(lhs)->less(std::get<index>(rhs));
else
return __ref_less<Tuple, index + 1, address_based>(lhs, rhs);
}
}
template <class Tuple, std::size_t index = 0, bool address_based>
typename std::enable_if_t<
is_ref<std::tuple_element_t<index, Tuple>>::value and address_based, bool>
__ref_less(const Tuple &lhs, const Tuple &rhs) {
if constexpr (index >=
std::tuple_size<std::remove_reference_t<Tuple>>::value - 1)
return std::get<index>(lhs).get() < std::get<index>(rhs).get();
else {
if (std::get<index>(lhs).get() != std::get<index>(rhs).get())
return std::get<index>(lhs).get() < std::get<index>(rhs).get();
else
return __ref_less<Tuple, index + 1, address_based>(lhs, rhs);
}
}
template <class Tuple> bool ref_addr_less(const Tuple &lhs, const Tuple &rhs) {
return __ref_less<Tuple, 0, true>(lhs, rhs);
}
template <class Tuple> bool ref_value_less(const Tuple &lhs, const Tuple &rhs) {
return __ref_less<Tuple, 0, false>(lhs, rhs);
}
template <class Tuple> class RefAddrLess {
public:
bool operator()(const Tuple &a, const Tuple &b) const {
return ref_addr_less(a, b);
}
};
template <class Tuple> class RefValueLess {
public:
bool operator()(const Tuple &a, const Tuple &b) const {
return ref_value_less(a, b);
}
};
// make_ref_from_tuple
template <typename _Tp, typename _Tuple, size_t... _Idx>
constexpr Ref<_Tp> make_ref_from_tuple_impl(_Tuple &&__t,
std::index_sequence<_Idx...>) {
return make_ref<_Tp>(std::get<_Idx>(std::forward<_Tuple>(__t))...);
}
template <typename _Tp, typename _Tuple>
constexpr Ref<_Tp> make_ref_from_tuple(_Tuple &&__t) {
return make_ref_from_tuple_impl<_Tp>(
std::forward<_Tuple>(__t),
std::make_index_sequence<std::tuple_size_v<std::decay_t<_Tuple>>>{});
}
} // namespace nnet
// namespace std {
// template <class T> struct hash<ir::Ref<T>> {
// hash<T *> hash_;
// size_t operator()(const ir::Ref<T> &ref) const { return hash_(ref.get());
// }
// };
// } // namespace nnet

158
include/nnet/routine.h Normal file
View File

@ -0,0 +1,158 @@
#pragma once
#include "common.h"
#include "expr.h"
#include <iostream>
#include <sstream>
namespace nnet {
class RoutineNode;
class MatmulNode;
class ElementWiseNode;
using Routine = Ref<RoutineNode>;
using Matmul = Ref<MatmulNode>;
using ElementWise = Ref<ElementWiseNode>;
#define DEFINE_GETTYPE(CLASS) \
RoutineType getType() const override { return RoutineType::CLASS##Type; }
class RoutineNode {
protected:
Expr expr;
vector<Tensor> inputs;
public:
RoutineNode(Expr _expr, const vector<Tensor> &_inputs);
virtual string toReadable() const = 0;
const Expr &getExpr() const { return expr; }
const vector<Tensor> &getInputs() const { return inputs; }
virtual RoutineType getType() const = 0;
};
using MatmulArgs = tuple<int, // b
int, // m
int, // n
int, // k
bool, // transa
bool>; // transb
class MatmulNode : public RoutineNode {
int b, m, n, k;
bool transa, transb;
public:
MatmulNode(Expr _source, Tensor A, Tensor B, int _b, int _m, int _n, int _k,
bool _transa, bool _transb)
: RoutineNode(_source, {A, B}), b(_b), m(_m), n(_n), k(_k),
transa(_transa), transb(_transb) {}
DEFINE_GETTYPE(MatmulNode);
string toReadable() const override;
friend bool operator==(const MatmulNode &lhs, const MatmulNode &rhs);
MatmulArgs getArgs() { return tuple(b, m, n, k, transa, transb); }
};
using ConvArgs = tuple<int, // ph
int, // pw
int, // sh
int, // sw
int, // dh
int>; // dw
class ConvNode : public RoutineNode {
int ph, pw;
int sh, sw;
int dh, dw;
public:
ConvNode(Expr _source, Tensor A, Tensor K, int _ph, int _pw, int _sh = 1,
int _sw = 1, int _dh = 1, int _dw = 1)
: RoutineNode(_source, {A, K}), ph(_ph), pw(_pw), sh(_sh), sw(_sw),
dh(_dh), dw(_dw) {}
DEFINE_GETTYPE(ConvNode);
string toReadable() const override;
vector<int> getShape() const;
friend bool operator==(const ConvNode &lhs, const ConvNode &rhs);
ConvArgs getArgs() const;
};
class ElementWiseNode : public RoutineNode {
vector<int> outputShape;
public:
// _outputShape is redundent, but expr is still missing for DLT.
ElementWiseNode(Expr _source, vector<Tensor> _inputs,
vector<int> _outputShape)
: RoutineNode(_source, _inputs), outputShape(_outputShape) {}
DEFINE_GETTYPE(ElementWiseNode);
string toReadable() const override;
/**
* @brief Get the Estimated Time of mem bound OP.
*
* @return double Time in ms.
*/
double getEstimatedTime() const;
const vector<int> &getOutputShape() const { return outputShape; }
};
using G2bmmArgs = tuple<int, // b
int, // m
int, // w
int, // k
int>; // dilation
class G2bmmNode : public RoutineNode {
int b, m, w, k;
public:
G2bmmNode(Expr source, Tensor A, Tensor B, int b, int m, int w, int k,
int d = 1)
: RoutineNode(source, {A, B}), b(b), m(m), w(w), k(k) {
assert(d == 1);
}
DEFINE_GETTYPE(G2bmmNode);
vector<int> getShape() const;
string toReadable() const override;
G2bmmArgs getArgs() const;
};
using GbmmArgs = tuple<int, // b
int, // m
int, // w
int, // n
int>; // dilation
class GbmmNode : public RoutineNode {
int b, m, w, n;
public:
GbmmNode(Expr source, Tensor A, Tensor B, int b, int m, int w, int n,
int d = 1)
: RoutineNode(source, {A, B}), b(b), m(m), w(w), n(n) {
assert(d == 1);
}
DEFINE_GETTYPE(GbmmNode);
vector<int> getShape() const;
string toReadable() const override;
GbmmArgs getArgs() const;
};
// Pretty output for dbg with shared_ptr
template <typename T, typename std::enable_if_t<
std::is_base_of_v<RoutineNode, T>> *_ = nullptr>
std::ostream &operator<<(std::ostream &os, const shared_ptr<T> &a) {
os << ((!a) ? string("Null shared_ptr") : a->toReadable());
return os;
}
// Pretty output for dbg with shared_ptr
template <typename T, typename std::enable_if_t<
std::is_base_of_v<RoutineNode, T>> *_ = nullptr>
std::ostream &operator<<(std::ostream &os, const Ref<T> &a) {
os << ((!a) ? string("Null shared_ptr") : a->toReadable());
return os;
}
} // namespace nnet

28
include/nnet/test.h Normal file
View File

@ -0,0 +1,28 @@
#pragma once
#include "common.h"
#include "derivator.h"
// clang-format off
#define CAT(A, B) A##B
#define SELECT(NAME, NUM) CAT(NAME##_, NUM)
#define GET_COUNT( _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, COUNT, ... ) COUNT
#define VA_SIZE( ... ) GET_COUNT( __VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1 )
#define VA_SELECT( NAME, ... ) SELECT( NAME, VA_SIZE(__VA_ARGS__) )(__VA_ARGS__)
#define _DEFVAR_1(name) auto name = make_ref<VarNode>(#name);
#define _DEFVAR_2(name, ...) _DEFVAR_1(name); _DEFVAR_1(__VA_ARGS__)
#define _DEFVAR_3(name, ...) _DEFVAR_1(name); _DEFVAR_2(__VA_ARGS__)
#define _DEFVAR_4(name, ...) _DEFVAR_1(name); _DEFVAR_3(__VA_ARGS__)
#define _DEFVAR_5(name, ...) _DEFVAR_1(name); _DEFVAR_4(__VA_ARGS__)
#define _DEFVAR_6(name, ...) _DEFVAR_1(name); _DEFVAR_5(__VA_ARGS__)
#define _DEFVAR_7(name, ...) _DEFVAR_1(name); _DEFVAR_6(__VA_ARGS__)
#define _DEFVAR_8(name, ...) _DEFVAR_1(name); _DEFVAR_7(__VA_ARGS__)
#define _DEFVAR_9(name, ...) _DEFVAR_1(name); _DEFVAR_8(__VA_ARGS__)
#define DEFINE_VAR(...) VA_SELECT(_DEFVAR, __VA_ARGS__)
// clang-format on
namespace nnet {
int matchExprResult(Derivator &derivator, string fn);
bool checkExprLogSame(string fnPrefix, int start, int end);
bool checkExprsEquvivalence(VecExpr exprs);
} // namespace nnet

128
include/nnet/visitor.h Normal file
View File

@ -0,0 +1,128 @@
#pragma once
#include "common.h"
#include "derivator.h"
#include "expr.h"
#include "routine.h"
#include <iostream>
#include <unordered_map>
namespace nnet {
template <typename FType> class Functor;
template <typename R, typename... Args> class Functor<R(Args...)> {
protected:
int verbose;
// FIXME: scope should be protected
public:
Functor(int _verobse = 0) : verbose(_verobse) {}
virtual ~Functor() = default;
#define DISPATCH(CLASS) \
case NodeType::CLASS##Type: \
return this->visit_(as<CLASS>(c), std::forward<Args>(args)...); \
break
#define FUNCTOR_DEFAULT \
{ return visitDefault(c, std::forward<Args>(args)...); }
virtual R dispatch(const Expr &c, Args... args) {
switch (c->getType()) {
DISPATCH(ConstantNode);
DISPATCH(BinaryOpNode);
DISPATCH(RangeOpNode);
DISPATCH(SubscriptNode);
DISPATCH(TensorNode);
DISPATCH(VarNode);
DISPATCH(FuncNode);
default:
nnet_assert(0, "Unknown type");
return R();
}
}
virtual R visit_(const Constant &c, Args... args) FUNCTOR_DEFAULT;
virtual R visit_(const BinaryOp &c, Args... args) FUNCTOR_DEFAULT;
virtual R visit_(const RangeOp &c, Args... args) FUNCTOR_DEFAULT;
virtual R visit_(const Subscript &c, Args... args) FUNCTOR_DEFAULT;
virtual R visit_(const Var &c, Args... args) FUNCTOR_DEFAULT;
virtual R visit_(const Tensor &c, Args... args) FUNCTOR_DEFAULT;
virtual R visit_(const Func &c, Args... args) FUNCTOR_DEFAULT;
virtual R visitDefault(const Expr &c, [[maybe_unused]] Args... args) {
dbg(*c);
nnet_assert(0, "Reach unimplemented visit function.");
return R();
};
[[deprecated("Define explicit methods for public access.")]] R
operator()(const Expr &e, Args... args) {
return dispatch(e, std::forward<Args>(args)...);
}
#undef FUNCTOR_DEFAULT
#undef DISPATCH
};
class Mutator : public Functor<Expr()> {
public:
Mutator(int _verobse = 0) : Functor(_verobse) {}
Expr visit_(const Constant &c) override;
Expr visit_(const BinaryOp &c) override;
Expr visit_(const RangeOp &c) override;
Expr visit_(const Subscript &c) override;
Expr visit_(const Var &c) override;
Expr visit_(const Tensor &c) override;
Expr visit_(const Func &c) override;
};
// template <typename... Args>
// class SingleStageVisitor : public Functor<void, Args...> {
// public:
// SingleStageVisitor(int _verobse = 0) : Functor<R, Args...>(_verobse) {}
// // R visit(const Constant &c) override ;
// R visit_(const BinaryOp &c) override {
// if (verbose)
// dbg(*c);
// this->dispatch(c->getLhs());
// this->dispatch(c->getRhs());
// }
// R visit_(const RangeOp &c) override {
// if (verbose)
// dbg(*c);
// this->dispatch(ret->getSummand());
// // NOT visit iterators and its ranges
// }
// R visit_(const Subscript &c) override {
// if (verbose)
// dbg(*c);
// this->dispatch(ret->getObject());
// for (size_t i = 0; i < ret->getDims(); ++i)
// this->dispatch(ret->getIndex(i));
// }
// // R visit(const Var &c) override;
// // R visit(const Tensor &c) override;
// };
// } // namespace nnet
// #include "nnet/Visitor/ReplaceVariable.h"
// #include "nnet/Visitor/StrideVisitor.h"
// namespace nnet {
class ExprTreeVisitor : public Functor<void(void)> {
private:
bool inBinary, inRange, inSub, inTensor;
public:
ExprTreeVisitor(bool _inBinary = 1, bool _inRange = 1, bool _inSub = 1,
bool _inTensor = 1, int _verobse = 0)
: Functor(_verobse), inBinary(_inBinary), inRange(_inRange),
inSub(_inSub), inTensor(_inTensor) {}
void visit_(const Constant &c) override;
void visit_(const BinaryOp &c) override;
void visit_(const RangeOp &c) override;
void visit_(const Subscript &c) override;
void visit_(const Var &c) override;
void visit_(const Tensor &c) override;
void visit_(const Func &c) override;
};
} // namespace nnet

View File

@ -0,0 +1,47 @@
#pragma once
#include "core/operator.h"
namespace infini {
class MatmulNode : public OperatorNode {
private:
// InfiniTensor assume a row-major tensor layout. transA=false means default
// dims, true means A should be transposed before matmul. This is in
// oppsite to column-major BLAS.
bool transA, transB;
ActType act;
// Auxiliary attributes
int b, m, n, k;
public:
MatmulNode(Tensor A, Tensor B, Tensor C, bool transA = false,
bool transB = false, Tensor bias = nullptr,
ActType act = ActType::None);
std::string toString() const override;
vector<Shape> computeShape() const override;
int numInputs() const override { return 2; }
int numOutputs() const override { return 1; }
Tensor getBias() const { return inputs[2]; }
ActType getAct() const { return act; }
bool getTransA() const { return transA; }
bool getTransB() const { return transB; }
int getB() const { return b; }
int getM() const { return m; }
int getN() const { return n; }
int getK() const { return k; }
HashType hashWithShape() const override;
OpPerfKey getOpPerfKey() const override;
private:
// Q: whether to check the output? Since we can build an Op first and then
// assure output.
// Fix 1: make shape inference a static method. But OpPerfKey are required.
bool checkValid(const TensorVec &inputs) const;
};
} // namespace infini

3
include/test.h Normal file
View File

@ -0,0 +1,3 @@
#pragma once
#include "core/common.h"
#include "gtest/gtest.h"

14
src/core/common.cc Normal file
View File

@ -0,0 +1,14 @@
#include "core/common.h"
#include <chrono>
#include <functional>
namespace infini {
double timeit(const std::function<void()> &func) {
auto start = std::chrono::high_resolution_clock::now();
func();
auto end = std::chrono::high_resolution_clock::now();
return std::chrono::duration<double, std::milli>(end - start).count();
}
} // namespace infini

20
src/core/graph.cc Normal file
View File

@ -0,0 +1,20 @@
#include "core/graph.h"
namespace infini {
void GraphNode::updateConnection() { IT_TODO_HALT(); }
string GraphNode::toString() const {
std::ostringstream oss;
oss << "GraphNode operators:\n";
for (const auto &op : ops)
oss << op << "\n";
return oss.str();
}
void GraphNode::dataMalloc() {
for (auto &tensor : tensors)
tensor->dataMalloc();
}
} // namespace infini

32
src/core/operator.cc Normal file
View File

@ -0,0 +1,32 @@
#include "core/operator.h"
namespace infini {
bool OperatorNode::isLinearOp() const {
return enum_to_underlying(type) >= 100 && enum_to_underlying(type) < 200;
}
bool OperatorNode::isElementWiseOp() const {
return enum_to_underlying(type) >= 200 && enum_to_underlying(type) < 300;
}
bool OperatorNode::isSplitOp() const { return type == OpType::Split; }
bool OperatorNode::isConcatOp() const { return type == OpType::Concat; }
bool OperatorNode::isComputeOp() const {
return type == OpType::Conv || type == OpType::Matmul ||
type == OpType::ConvTrans || type == OpType::G2BMM ||
type == OpType::GBMML;
}
bool OperatorNode::isTransposeOp() const { return type == OpType::Transpose; }
bool OperatorNode::isReshapeOp() const { return type == OpType::Reshape; }
bool OperatorNode::isMemBoundOp() const {
return type == OpType::MemBound || type == OpType::Activation ||
type == OpType::Transpose;
}
} // namespace infini

105
src/core/run_engine.cc Normal file
View File

@ -0,0 +1,105 @@
#include "core/run_enigne.h"
#include <chrono>
namespace infini {
void RunEngine::run(const Graph &graph, bool tune, bool profiling) const {
if (!tune && profiling)
IT_TODO_HALT();
const auto &kernelRegistry = KernelRegistry::getInstance();
auto perfEngine = PerfEngine::getInstance();
// Statistics
double totalTime = 0;
std::map<OpType, double> opTime;
std::map<OpType, int> opCnt;
std::chrono::system_clock::time_point begin, end;
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs =
KernelAttrs{device, op->getOpType(), DataType::Int32};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
// If no record and disable tuning, run with the default argument
if (!perfData && !tune) {
kernel->compute(op);
continue;
}
// TODO: The copy of record should be eliminated
PerfRecord record;
// Tune the kernel if there is no record
if (!perfData) {
record = kernel->tune(op);
perfEngine.setPerfData(perfKey, record);
} else
record = *perfData;
if (!profiling) {
kernel->compute(op, *perfData);
continue;
} else {
double t = timeit([&]() { kernel->compute(op, *perfData); });
op->print();
printf(" op_time %lf\n", t);
totalTime += t;
opTime[op->getOpType()] += t;
opCnt[op->getOpType()]++;
}
}
if (profiling)
printProfilingData(totalTime, opTime, opCnt);
}
double RunEngine::getPerfTime(const Graph &graph, bool profiling) const {
const auto &kernelRegistry = KernelRegistry::getInstance();
auto perfEngine = PerfEngine::getInstance();
// Statistics
double totalTime = 0;
std::map<OpType, double> opTime;
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs =
KernelAttrs{device, op->getOpType(), DataType::Int32};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
PerfRecord record;
// Tune the kernel if there is no record
if (!perfData) {
record = kernel->tune(op);
perfEngine.setPerfData(perfKey, record);
} else
record = *perfData;
double t = record.time;
totalTime += t;
if (profiling) {
op->print();
printf(" op_time %lf\n", t);
opTime[op->getOpType()] += t;
opCnt[op->getOpType()]++;
}
}
if (profiling)
printProfilingData(totalTime, opTime, opCnt);
return totalTime;
}
void RunEngine::printProfilingData(double totalTime,
const std::map<OpType, double> &opTime,
const std::map<OpType, int> &opCnt) const {
printf("%11s %3s %7s %7s %7s\n", "Op", "Cnt", "T_tot", "Percent", "T_mean");
for (const auto &[type, t] : opTime) {
printf("%11s %3d %7.3f %7.1f %7.3f\n",
OpRegistry::getOpName(type).data(), opCnt.at(type), t,
t / totalTime * 100, t / opCnt.at(type));
}
}
} // namespace infini

90
src/core/tensor.cc Normal file
View File

@ -0,0 +1,90 @@
#include <core/tensor.h>
namespace infini {
TensorNode::TensorNode(const Shape &shape, DataType dtype)
: TensorBaseNode(shape.size(), dtype), shape(shape) {}
void TensorNode::dataMalloc() {
IT_ASSERT(data == nullptr);
// initialized to zero
data.reset(reinterpret_cast<VType *>(calloc(size(), sizeof(VType))));
}
VType TensorNode::getData(const Shape &pos) const {
return getData(getOffset(pos));
}
string TensorNode::toString() const {
return "TensorNode " + std::to_string(guid);
}
size_t TensorNode::getOffset(const Shape &pos) const {
auto nDim = pos.size();
IT_ASSERT(shape.size() == nDim);
if (pos.empty())
return 0;
for (size_t i = 0; i < nDim; ++i)
IT_ASSERT(pos[i] < 0 || pos[i] >= shape[i]);
size_t idx = pos[0];
size_t dm = 0;
while (++dm < nDim)
idx = idx * shape[dm] + pos[dm];
return idx;
}
size_t TensorNode::size() const {
size_t ret = 1;
for (const auto &d : shape)
ret *= d;
return ret;
}
void TensorNode::copyData(VType *dptr) {
IT_ASSERT(data != nullptr);
size_t sz = size();
#pragma omp parallel for
for (size_t i = 0; i < sz; ++i) {
data[i] = dptr[i];
}
}
void TensorNode::printData() const {
IT_ASSERT(data != nullptr);
std::cout << "Tensor: " << guid << std::endl;
auto numDims = shape.size();
auto dimSzVec = std::vector<int>(numDims, 1);
dimSzVec[numDims - 1] = shape[numDims - 1];
for (int i = numDims - 1; i != 0; --i)
dimSzVec[i - 1] = dimSzVec[i] * shape[i - 1];
for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
for (size_t j = 0; j < numDims; ++j) {
if (i % dimSzVec[j] == 0) {
std::cout << "[";
}
}
std::cout << data[i];
for (size_t j = 0; j < numDims; ++j) {
if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) {
std::cout << "]";
}
}
if (i != size() - 1)
std::cout << ", ";
if ((int)i % dimSzVec[numDims - 1] == dimSzVec[numDims - 1] - 1)
std::cout << std::endl;
}
}
bool TensorNode::equalData(const Tensor &rhs) const {
IT_ASSERT(data != nullptr);
IT_ASSERT(rhs->data != nullptr);
if (shape != rhs->getDims())
return false;
size_t sz = size();
for (size_t i = 0; i < sz; ++i)
if (data[i] != rhs->data[i])
return false;
return true;
}
}; // namespace infini

9
src/core/tensor_base.cc Normal file
View File

@ -0,0 +1,9 @@
#include <core/tensor_base.h>
namespace infini {
TensorBaseNode::TensorBaseNode(int dim, DataType dtype)
: dim(dim), dtype(dtype) {}
VType TensorBaseNode::getData(size_t offset) const { return data[offset]; }
}; // namespace infini

38
src/kerels/cpu/matmul.cc Normal file
View File

@ -0,0 +1,38 @@
#include "operators/matmul.h"
#include "core/kernel.h"
namespace infini {
template <typename T> class NaiveMatmul : public Kernel {
void compute(const Operator &_op, const PerfRecord &record) const override {
auto op = as<MatmulNode>(_op);
T *A = reinterpret_cast<T *>(op->getInputs(0)->getDataPtr().get());
T *B = reinterpret_cast<T *>(op->getInputs(1)->getDataPtr().get());
T *C = reinterpret_cast<T *>(op->getOutput()->getDataPtr().get());
IT_ASSERT(op->getTransA() == false && op->getTransB() == false);
IT_ASSERT(op->getAct() == ActType::None);
IT_ASSERT(op->getB() == 1);
const int M = op->getM(), N = op->getN(), K = op->getK();
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
C[i * N + j] = 0;
for (int k = 0; k < K; k++) {
C[i * N + j] += A[i * K + k] * B[k * N + j];
}
}
}
}
void compute(const Operator &op) const override { compute(op, {}); }
PerfRecord tune(const Operator &op) const override {
return PerfRecord{.time = timeit([this, &op]() { compute(op); })};
}
};
REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Int32,
NaiveMatmul<uint32_t>, "MatmulNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Float32,
NaiveMatmul<float>, "MatmulNaive_CPU_float32");
} // namespace infini

View File

@ -0,0 +1,25 @@
#include "nnet/Pass/MatchComputationKernel.h"
#include "nnet/Visitor/PatternMatcher.h"
namespace nnet {
// RE: is this duplicate with Rule6KenerlMatching?
void MatchComputationKernel::transform(Formula &origin, int depth, Expr &rCur) {
nnet_assert(derivator.getSearchState() == 2, __LINE__);
auto cur = as<RangeOpNode>(rCur);
// Build wrapper stages for enforce axis starts from 0
PatternMatcher patternMatcher(derivator, cur);
cur = patternMatcher.getOffsetCur();
auto matches = patternMatcher.matchWithPattern(
cur, getPattern(derivator.getTargetOp()));
matches = patternMatcher.applyWrapper(matches);
for (auto newCur : matches) {
derivator.setSearchState(3);
nextStep(origin, depth, rCur, newCur);
derivator.setSearchState(2);
}
}
} // namespace nnet

View File

@ -0,0 +1,23 @@
#include "nnet/Pass/MatchMemBoundKernel.h"
#include "nnet/Visitor/InputVisitor.h"
namespace nnet {
void MatchMemBoundKernel::transform(Formula &origin, int depth, Expr &rCur) {
// FIXME: Whether the Formula is a Membound OP should be checked.
nnet_assert(derivator.getSearchState() == 3, __LINE__);
nnet_assert(origin.root.get() == rCur.get(),
"Only match the entire formula as a Membound Op");
auto rangeOp = as<RangeOpNode>(origin.root);
const auto &inputs = InputVisitor().getInputs(rangeOp);
auto source =
make_ref<ElementWiseNode>(rangeOp, inputs, rangeOp->getOutputShape());
auto tensor =
makeTensor(newTensorName(), rangeOp->getOutputShape(), {}, source);
// The original code directly appends candidate. But it seems should be done
// by the search.
// appendCanddiate(as<TensorNode>(tensor), depth);
nextStep(origin, depth, rCur, tensor);
}
} // namespace nnet

58
src/nnet/Pass/Pass.cc Normal file
View File

@ -0,0 +1,58 @@
#include "nnet/Pass/Pass.h"
#include "nnet/Visitor/CloneMutator.h"
namespace nnet {
Pass::Pass(Derivator &derivator, const string &passName)
: derivator(derivator), passName(passName),
enableLogging(derivator.getLogMode() != Derivator::LogMode::NoLog),
enableDebug(false) {}
Pass::~Pass() = default;
void Pass::setEnableLogging(bool value) { enableLogging = value; }
void Pass::setEnableDebug(bool value) { enableDebug = value; }
void Pass::run(Formula &origin, int dfsDepth, Expr &rCur) {
initialize(origin, rCur);
transform(origin, dfsDepth, rCur);
finalize();
}
void Pass::initialize(Formula &origin, const Expr &rCur) {}
void Pass::finalize() {}
Var Pass::getNewVar() { return derivator.getNewVar(); }
string Pass::newTensorName() { return derivator.newTensorName(); }
void Pass::nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur,
const string &ruleMsg) {
// push rule action description
if (enableLogging) {
rCur.swap(newCur);
derivator.pushIntermediateState(origin.root);
rCur.swap(newCur);
derivator.pushRuleState(passName);
derivator.pushRuleMsg(ruleMsg);
}
if (enableDebug) {
// In debug mode, do not recur but save the transformed state
transformations.emplace_back(CloneMutator().clone(newCur));
} else
derivator.nextStep(origin, depth, rCur, newCur);
// pop rule action description
if (enableLogging) {
derivator.popIntermediateState();
derivator.popRuleState();
derivator.popRuleMsg();
}
}
const VecExpr &Pass::getTransformations() { return transformations; }
} // namespace nnet

View File

@ -0,0 +1,134 @@
#include "nnet/Pass/Rule1VariableSplit.h"
#include "nnet/Visitor/ReplaceVariable.h"
namespace nnet {
void Rule1VariableSplit::transform(Formula &origin, int depth, Expr &rCur) {
auto cur = as<RangeOpNode>(rCur);
vector<Replace> replaces = getSplitableVar(cur);
// for (const auto &replace : replaces)
// dbg(replace.oldIters, replace.newIters, replace.psis,
// replace.newVarRanges);
for (const auto &replace : replaces) {
auto replacedSummand = replaceIters(cur->getSummand(), replace);
if (!replacedSummand) {
// TODO: if a real getMergableExprs is implemented, this case should
// be an error. Since the expr should appear in the AST.
dbg("Warning: No replacment happens.");
continue;
}
auto inner =
ReplaceKit::replaceRangeOpIterator(cur, replace, replacedSummand);
// build the outerRange{innerRange}[indexForInner] to do DLT
Expr nextCur = nullptr;
if (replace.iteratorType == IterationType::Loop) {
auto subscriptedInner =
ReplaceKit::buildSubscirptForLoopVarReplace(inner, replace);
nextCur = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
} else
nextCur = inner;
string msg = "====== END rule1 VariableSplit: ";
dbg(msg, replace.oldIters, replace.newIters, replace.phis,
replace.psis);
msg = replace.toReadable();
nextStep(origin, depth, rCur, nextCur, msg);
}
}
vector<Replace> Rule1VariableSplit::getSplitableVar(const RangeOp &rangeOp) {
vector<Replace> ret;
// Split strategy
vector<int> SumFactors, LoopFactors;
if (derivator.getPassMode() == Derivator::PassMode::Debug) {
SumFactors = {3};
LoopFactors = {4};
} else if (derivator.getPassMode() == Derivator::PassMode::Full) {
SumFactors = {2, 3};
// LoopFactors = {3, 4};
LoopFactors = {4};
} else
nnet_unimplemented_halt();
// Split Sum variable
for (const int k : SumFactors) {
for (const auto &[var, range] : rangeOp->getSumVarRanges()) {
int len = range.second - range.first;
auto p1 = getNewVar(); // p1=i/k
auto p2 = getNewVar(); // p2=i%k
if (len > 10 || len <= k || len % k != 0)
continue;
Range range1, range2;
if (range.first < 0) {
nnet_unimplemented_halt();
// FIXME: this must be ERROR
range1.first = range.first / k;
range1.second = range1.first + len / k;
range2.first = -k / 2;
range2.second = range2.first + k;
} else if (range.first == 0) {
range1.first = 0;
range1.second = len / k;
range2.first = 0;
range2.second = k;
} else {
nnet_unimplemented_continue();
continue;
}
Replace replace{.iteratorType = IterationType::Sum,
.oldIters = {var},
.newIters = {p1, p2},
.phis = {},
.psis = {make_ref<ConstantNode>(k) * p1 + p2},
.newVarRanges = {{p1, range1}, {p2, range2}}};
ret.emplace_back(replace);
}
}
for (const int k : LoopFactors) {
// Split Loop variable
for (const auto &[var, range] : rangeOp->getLoopVarRanges()) {
const int len = range.second - range.first;
// Debug HACK for dilated SG2BMM
if (derivator.getPassMode() == Derivator::PassMode::Debug &&
!(var->getName() == "m" && len % k == 0))
continue;
// Illeagel conditions
if (range.second - range.first <= k ||
(range.second - range.first) % k != 0)
continue;
// Unsupport conditions
if (range.first != 0)
continue;
auto p1 = getNewVar(); // p1=i/k
auto p2 = getNewVar(); // p2=i%k
Range range1(0, len / k);
Range range2(0, k);
nnet_assert(range1.second > 0 && range2.second > 0,
"Empty loop dim");
Replace replace{.iteratorType = IterationType::Loop,
.oldIters = {var},
.newIters = {p1, p2},
.phis = {var / 4, var % 4},
.psis = {make_ref<ConstantNode>(k) * p1 + p2},
.newVarRanges = {{p1, range1}, {p2, range2}}};
ret.emplace_back(replace);
}
}
return ret;
}
Expr Rule1VariableSplit::replaceIters(Expr cur, const Replace &replace) {
// TODO [feature]: support multiple replacements in one mutator
if (replace.oldIters.size() != 1) {
nnet_unimplemented_continue();
return nullptr;
}
auto replaceMutator =
ReplaceVariable(replace.oldIters.at(0), replace.psis.at(0));
auto ret = replaceMutator(cur);
return ret;
}
} // namespace nnet

View File

@ -0,0 +1,186 @@
#include "nnet/Pass/Rule2VariableMerging.h"
#include "nnet/Visitor/CheckOOBVisitor.h"
namespace nnet {
void Rule2VariableMerging::transform(Formula &origin, int depth, Expr &rCur) {
// Extract r and s
auto cur = as<RangeOpNode>(rCur);
vector<Replace> replaces = getMergableReplaces(cur, depth);
// dbg("Start rule2VariableMerging", depth, mergableExprs, *cur);
for (const auto &replace : replaces) {
if (replace.iteratorType != IterationType::Loop) {
nnet_unimplemented_continue();
continue;
}
// replace vars in summand
auto replacedSummand = ReplaceKit::replaceMultipleExprs(
cur->getSummand(), replace.oldIters, replace.psis, true);
// replace var in rangeOp
auto inner =
ReplaceKit::replaceRangeOpIterator(cur, replace, replacedSummand);
// If OOB happens, this transformation is skipped
if (CheckOOBVisitor().checkRangeOp(inner))
continue;
// build the outerRange{innerRange}[indexForInner] to do DLT
auto subscriptedInner =
ReplaceKit::buildSubscirptForLoopVarReplace(inner, replace);
auto outer = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
// next searching step
string msg = replace.toReadable();
nextStep(origin, depth, rCur, outer, msg);
}
}
vector<Replace> Rule2VariableMerging::getMergableReplaces(RangeOp rangeOp,
int depth) {
vector<Replace> ret;
IteratorTable exprIT;
if (!exprIT.analyzeExpr(rangeOp)) {
nnet_unimplemented_continue();
return ret;
}
exprIT.buildTableWithDefaultMap();
const auto &strideInAllDim = exprIT.getStrideInDim();
set<pair<Iterator, Iterator>, RefValueLess<pair<Iterator, Iterator>>>
checkedIterPairs{};
// strideInAllDim: [tensorID][dimOfTensor][Iterator]=stride
for (size_t tensorID = 0; tensorID < strideInAllDim.size(); ++tensorID) {
const auto &strideInDimsOfATensor = strideInAllDim[tensorID];
for (const PtrMap<Iterator, int> &strideInADim :
strideInDimsOfATensor) {
for (const auto &it1 : strideInADim) {
for (const auto &it2 : strideInADim) {
// Backdoor for rule-based search
if (substituteRules.count(depth)) {
if (substituteRules[depth].at(0)->neq(it1.first))
continue;
if (substituteRules[depth].at(1)->neq(it2.first))
continue;
}
if (!(it1.first->equal(it2.first) &&
it1.second == it2.second) &&
rangeOp->hasLoopVar(it1.first) &&
rangeOp->hasLoopVar(it2.first)) {
// 2 iters -> 2 iters
if (auto opt = getReplaceMappingTwoLoopIters(rangeOp,
it1, it2))
ret.emplace_back(*opt);
// 2 iters -> 1 iter
const auto iterPair = pair(it1.first, it2.first);
if (!checkedIterPairs.count(iterPair)) {
checkedIterPairs.insert(iterPair);
if (auto opt = getReplaceMergingTwoLoopIters(
rangeOp, it1, it2, exprIT, tensorID))
ret.emplace_back(*opt);
}
}
}
}
}
}
return ret;
}
optional<Replace> Rule2VariableMerging::getReplaceMergingTwoLoopIters(
const RangeOp &rangeOp, pair<Iterator, int> pairA,
pair<Iterator, int> pairB, const IteratorTable &exprIT, int tensorID) {
// 1*A + sb*B -> C
// A=C%sb, B=C/sb
// ax+by->z, a=1 or -1
// For a>0 and b>0 : x=z%b, y=z/b
auto x = pairA.first, y = pairB.first;
int a = pairA.second, b = pairB.second;
if (abs(a) != 1 || abs(a) * abs(b) <= 0)
return {};
if (a < 0 && b > 0) { // The only unhandled case
nnet_unimplemented_continue();
return {};
}
// negative substitution happens only if can be totally merged. So if the
// variable appears in another index, skip it.
if (a < 0 || b < 0) {
if (exprIT.getNumInputs() > 1) {
if (exprIT.getStridesInTensor(x, 1 - tensorID) != 0)
return {};
if (exprIT.getStridesInTensor(y, 1 - tensorID) != 0)
return {};
}
}
Range rangeX = rangeOp->getVarRange(x).second,
rangeY = rangeOp->getVarRange(y).second;
if (rangeX.first != 0 || rangeY.first != 0)
return {};
int lenX = rangeX.second - rangeX.first;
if (abs(b) != lenX)
return {};
auto z = getNewVar();
Range rangeExpr{0, 1}; // 1 is the open interval compensation
auto calcRangeExpr = [&rangeExpr](int stride, const Range &r) {
if (stride > 0) {
rangeExpr.first += stride * r.first;
rangeExpr.second += stride * (r.second - 1);
} else {
rangeExpr.first += stride * (r.second - 1);
rangeExpr.second += stride * r.first;
}
};
calcRangeExpr(a, rangeX);
calcRangeExpr(b, rangeY);
// build the phi/psi for index transformation
// phi: j_x=(i_x...), psi: i_x=(j_x...)
auto ret = optional<Replace>();
ret.emplace();
ret->iteratorType = IterationType::Loop;
ret->newIters = {z};
ret->oldIters = {x, y};
ret->phis = {a * x + b * y - rangeExpr.first};
// For b < 0, the psis are not an equavalent replace. Since it must be
// simplified (z/b and z%b will be merged), the only important thing is
// their strides should be mergable. To merge the strides, an extra minus
// are introduced if their stride is negative.
ret->psis = {a * (z % b) + a * rangeExpr.first, (b > 0 ? 1 : -1) * (z / b)};
ret->newVarRanges = {{z, {0, rangeExpr.second - rangeExpr.first}}};
return ret;
}
optional<Replace>
Rule2VariableMerging::getReplaceMappingTwoLoopIters(const RangeOp &rangeOp,
pair<Iterator, int> pairA,
pair<Iterator, int> pairB) {
// the first iterator is replaced, the second remains
auto i1 = pairA.first, i2 = pairB.first;
int sa = pairA.second, sb = pairB.second;
// TODO: can be relaxed to sb|sb
if (sa != 1 || sb == 0)
return {};
if (sb < 0) {
nnet_unimplemented_continue();
return {};
}
Range rangeA = rangeOp->getVarRange(i1).second;
Range rangeB = rangeOp->getVarRange(i2).second;
auto j1 = getNewVar(), j2 = getNewVar();
Range rangeJ1, rangeJ2 = rangeB;
assert(pairA.second == 1);
rangeJ1.first = rangeA.first + rangeB.first * sb;
rangeJ1.second = rangeA.second + (rangeB.second - 1) * sb;
// build the phi/psi for index transformation
// phi: j_x=(i_x...), psi: i_x=(j_x...)
auto ret = optional<Replace>();
ret.emplace();
ret->iteratorType = IterationType::Loop;
ret->newIters = {j1, j2};
ret->oldIters = {i1, i2};
ret->newVarRanges = {{j1, rangeJ1}, {j2, rangeJ2}};
ret->phis = {sa * i1 + sb * i2, i2};
ret->psis = {j1 - (sb / sa) * j2, j2};
return ret;
}
} // namespace nnet

View File

@ -0,0 +1,82 @@
#include "nnet/Pass/Rule3StageSplit.h"
#include "nnet/permutation.h"
namespace nnet {
void Rule3StageSplit::transform(Formula &origin, int depth, Expr &rCur) {
auto cur = as<RangeOpNode>(rCur);
vector<vector<Iterator>> splitSchemes = getSplitSummationIters(cur);
for (const auto &varSplit : splitSchemes) {
bool isSplittable = false;
for (const auto &splitted : varSplit)
if (cur->hasSumVar(splitted))
isSplittable = true;
assert(isSplittable);
const vector<VarRangePair> loopVars = cur->getLoopVarRanges(),
sumVars = cur->getSumVarRanges();
// move iterators from Sigma to Loop
vector<VarRangePair> innerLoopVars, innerSumVars, outerSumVars;
VecExpr indexForInner;
for (const auto &kv : sumVars) {
bool isSplitted = false;
for (const auto &iter : varSplit)
if (iter == kv.first->getName())
isSplitted = true;
if (isSplitted) {
innerLoopVars.emplace_back(kv);
outerSumVars.emplace_back(kv);
} else
innerSumVars.emplace_back(kv);
}
innerLoopVars.insert(innerLoopVars.end(), loopVars.begin(),
loopVars.end());
for (const auto &[var, _] : innerLoopVars)
indexForInner.emplace_back(var);
// if no sum iterator, the stage is redundant
assert(!innerSumVars.empty());
auto inner =
makeRangeOperator(innerLoopVars, innerSumVars, cur->getSummand());
auto subscriptedInner = make_ref<SubscriptNode>(inner, indexForInner);
auto outer = makeRangeOperator(cur->getLoopVarRanges(), outerSumVars,
subscriptedInner);
outer->setPaddings(cur->getPaddings());
// next searching step
string msg = "Separate sum iters: " + serializeVec(varSplit);
nextStep(origin, depth, rCur, outer, msg);
}
}
vector<vector<Iterator>>
Rule3StageSplit::getSplitSummationIters(RangeOp rangeOp) {
// set<string> varSplit = {"r", "s", "i3", "i13"};
vector<vector<Iterator>> ret;
// Rule-based Hint
// vector<vector<Iterator>> heuristics = {{"r", "s"}, {"i3", "i13"}};
// for (const auto &iterSet : heuristics) {
// bool notExist = false;
// for (const auto &iter : iterSet)
// if (!rangeOp->hasSumVar(iter))
// notExist = true;
// if (!notExist)
// ret.emplace_back(iterSet);
// }
// if (!rulesOverall.empty())
// return ret;
vector<Iterator> sumIters;
for (const auto &[iter, range] : rangeOp->getSumVarRanges())
sumIters.emplace_back(iter);
if (sumIters.size() <= 1)
return ret;
SubsetGenerator gen(sumIters);
do {
ret.emplace_back(gen.get());
} while (gen.next());
return ret;
}
} // namespace nnet

View File

@ -0,0 +1,91 @@
#include "nnet/Pass/Rule4StageMerging.h"
#include "nnet/Visitor/ReplaceNodeMutator.h"
#include "nnet/Visitor/ReplaceVariable.h"
namespace nnet {
void Rule4StageMerging::transform(Formula &origin, int depth, Expr &rCur) {
success = rule4StageMerging(origin, depth, rCur, mergeStageWithCalc);
}
bool Rule4StageMerging::isSuccessful() { return success; }
void Rule4StageMerging::setMergeStageWithCalc(bool value) {
mergeStageWithCalc = value;
}
bool Rule4StageMerging::rule4StageMerging(Formula &origin, int depth,
Expr &rCur, bool mergeStageWithCalc) {
auto rangeOp0 = as<RangeOpNode>(rCur);
const Subscript &sub0 = as<SubscriptNode>(rangeOp0->getSummand());
if (!sub0)
return false;
const auto &rangeOp1 = as<RangeOpNode>(sub0->getObject());
if (!rangeOp1)
return false;
const auto &sub1 = as<SubscriptNode>(rangeOp1->getSummand());
if (!sub1)
return false;
// merge stage with calculation only when mergeStageWithCalc=true
if (!mergeStageWithCalc && !rangeOp1->getSumVarRanges().empty())
return false;
// Only propogate paddings in perfect nested dimension
if (rangeOp1->hasPaddings()) {
auto oldTensor = as<TensorNode>(sub1->getObject());
if (!oldTensor) {
nnet_unimplemented_continue();
return 0;
}
}
// repalce variables: iters of rangeOp1 repalced by indexes of sub0
map<string, pair<Expr, Expr>> varMapping;
assert(sub0->getDims() == rangeOp1->getLoopVarRanges().size());
for (size_t i = 0; i < sub0->getDims(); ++i) {
varMapping[rangeOp1->getLoopVar(i)->getName()] =
pair(rangeOp1->getLoopVar(i), sub0->getIndex(i));
}
ReplaceVariable replaceVariable{varMapping};
auto merged = make_ref<RangeOpNode>(*rangeOp0);
merged->setSummand(replaceVariable(sub1));
// a naive approach to propogate paddings
if (rangeOp1->hasPaddings()) {
auto oldTensor = as<TensorNode>(sub1->getObject());
auto newTensor = make_ref<TensorNode>(*oldTensor);
for (int i = 0; i < rangeOp1->getNumOutputDims(); ++i) {
if (rangeOp1->getPaddings(i) == 0)
continue;
auto loopVar = rangeOp1->getLoopVar(i);
// FIXME: in fact this var should not appear in other index as well,
// which may result in OOB
bool findSingleVarAsIndex = false;
for (size_t subIndexID = 0; subIndexID < sub1->getDims();
++subIndexID) {
auto index = sub1->getIndex(subIndexID);
if (auto indexVar = as<VarNode>(index);
indexVar && (indexVar->equal(loopVar))) {
newTensor->setPadding(subIndexID,
newTensor->getPadding(subIndexID) +
rangeOp1->getPaddings(i));
findSingleVarAsIndex = true;
}
}
if (!findSingleVarAsIndex) {
nnet_unimplemented_continue();
return false;
}
}
merged = as<RangeOpNode>(
ReplaceNodeMutator().replace(merged, oldTensor.get(), newTensor));
assert(merged != nullptr);
}
// Merge inner stage sums
if (!rangeOp1->getSumVarRanges().empty())
merged->setSumIterator(rangeOp1->getSumVarRanges());
// next searching step
// if mergeStageWithCalc, depth counts for invocation in rule-based search
nextStep(origin, (mergeStageWithCalc) ? depth : depth - 1, rCur, merged);
return true;
}
} // namespace nnet

View File

@ -0,0 +1,72 @@
#include "nnet/Pass/Rule5RangeRelaxation.h"
#include "nnet/Visitor/RangeRelaxFunctor.h"
namespace nnet {
void Rule5RangeRelaxation::transform(Formula &origin, int depth, Expr &rCur) {
rule5RangeRelaxation(origin, depth, rCur);
}
Expr Rule5RangeRelaxation::rule5RangeRelaxation(Formula &origin, int depth,
Expr &rCur) {
auto cur = as<RangeOpNode>(rCur);
if (cur->hasPaddings()) {
// string msg = "====== END rule5RangeRelaxation: Paddings exist \n";
// dbg(msg);
return nullptr;
}
// Infer meaningful calculation range
RangeRelaxFunctor rangeRexlaxtionFunctor{cur};
RangeMap rangeMap = rangeRexlaxtionFunctor(cur);
auto relaxedCur = make_ref<RangeOpNode>(*cur);
bool isRelaxed = false;
vector<int> paddings;
// check whether narrow the calculation range
for (size_t i = 0; i < cur->getLoopVarRanges().size(); ++i) {
const auto &[iter, iterRange] =
cur->getVarRange(IterationType::Loop, i);
if (auto it = rangeMap.find(iter); it != rangeMap.end()) {
// intersection of validRange and iterRange is necessary computation
// TODO: it is redundant with RangeRelaxFunctor::intersectRangeMaps.
// An independent Range class might be necessary.
const Range &validRange = it->second;
Range relaxedRange{max(iterRange.first, validRange.first),
min(iterRange.second, validRange.second)};
if (relaxedRange != iterRange) {
isRelaxed = true;
relaxedCur->setVarRange(IterationType::Loop, i,
{iter, relaxedRange});
paddings.emplace_back(
max(relaxedRange.first - iterRange.first,
iterRange.second - relaxedRange.second));
} else
paddings.emplace_back(0);
} else
paddings.emplace_back(0);
}
relaxedCur->setPaddings(paddings);
if (!isRelaxed) {
// string msg = "====== END rule5RangeRelaxation: Relaxation not
// found\n"; dbg(msg);
return nullptr;
}
// next searching step
string detailedMsg;
for (size_t i = 0; i < cur->getLoopVarRanges().size(); ++i) {
const auto &[v, a] = cur->getVarRange(IterationType::Loop, i);
const auto &[_, b] = relaxedCur->getVarRange(IterationType::Loop, i);
if (a != b) {
detailedMsg += v->getName();
detailedMsg +=
" (" + to_string(a.first) + "," + to_string(a.second) + ") to";
detailedMsg +=
" (" + to_string(b.first) + "," + to_string(b.second) + "),";
}
}
nextStep(origin, depth, rCur, relaxedCur, detailedMsg);
return relaxedCur;
}
} // namespace nnet

View File

@ -0,0 +1,57 @@
#include "nnet/Pass/Rule6KenerlMatching.h"
#include "nnet/Visitor/InputVisitor.h"
#include "nnet/Visitor/PatternMatcher.h"
namespace nnet {
void Rule6KenerlMatching::transform(Formula &origin, int depth, Expr &rCur) {
auto cur = as<RangeOpNode>(rCur);
// Build wrapper stages for enforce axis starts from 0
PatternMatcher patternMatcher(derivator, cur);
cur = patternMatcher.getOffsetCur();
// Match matchable routines
for (int i = 0; i < MatchableRoutineTypeCnt; ++i) {
auto targetOp = idToRoutineType(i);
// During guided search, only check the target OP
if (derivator.getTargetOp() != RoutineType::NoneType &&
derivator.getTargetOp() != targetOp)
continue;
auto replaces =
patternMatcher.matchWithPattern(cur, getPattern(targetOp));
replaces = patternMatcher.applyWrapper(replaces);
for (auto newCur : replaces)
nextStep(origin, depth, rCur, newCur);
}
{ // Match element-wise OP
auto replaces = matchElementWise(cur);
if (!replaces.empty())
dbg(rCur);
for (auto newCur : replaces)
nextStep(origin, depth, rCur, newCur);
}
}
VecExpr Rule6KenerlMatching::matchElementWise(const RangeOp &rangeOp) {
// If the stage is compute bound, then do not convert it.
int64_t flops = rangeOp->getFlops(), outputSize = rangeOp->getOutputSize();
int64_t inputSize = rangeOp->getInputSize(rangeOp);
if (double(flops) / (inputSize + outputSize) > 3)
return {};
vector<int> newShape;
for (const auto &[var, range] : rangeOp->getLoopVarRanges()) {
if (range.first != 0) {
nnet_unimplemented_continue();
return {};
}
newShape.emplace_back(range.second - range.first);
}
const auto &inputs = InputVisitor().getInputs(rangeOp);
auto source =
make_ref<ElementWiseNode>(rangeOp, inputs, rangeOp->getOutputShape());
auto newTensor = makeTensor(newTensorName(), newShape, {}, source);
return {newTensor};
}
} // namespace nnet

78
src/nnet/Pass/Rule7DLT.cc Normal file
View File

@ -0,0 +1,78 @@
#include "nnet/Pass/Rule7DLT.h"
#include "nnet/Visitor/ReplaceNodeMutator.h"
#include "nnet/dlt.h"
namespace nnet {
void Rule7DLT::transform(Formula &origin, int depth, Expr &rCur) {
auto cur = as<RangeOpNode>(rCur);
if (!cur)
return;
auto op = as<BinaryOpNode>(cur->getSummand());
if (!op)
return;
auto subs = {op->getLhs(), op->getRhs()};
for (auto subExpr : subs) {
auto sub = as<SubscriptNode>(subExpr);
if (!sub)
continue;
auto tensor = as<TensorNode>(sub->getObject());
if (!tensor)
continue;
// // HACK for G2BMM
// if (tensor->getDims() != 3)
// continue;
for (const auto factor : getFactors()) {
for (int targetDim = 0; targetDim < tensor->getDims();
++targetDim) {
if (tensor->getShape(targetDim) % factor)
continue;
// Debug hint for G2BMM
if (derivator.getPassMode() == Derivator::PassMode::Debug) {
if (tensor->getShape(targetDim) != 10000)
continue;
assert(targetDim == 1);
}
DLT dlt;
dlt.split(targetDim, factor);
vector<int> newOrder(tensor->getDims() + 1);
for (int i = 0; i < tensor->getDims() + 1; ++i)
newOrder[i] = i;
newOrder[targetDim]++;
newOrder[targetDim + 1]--;
dlt.reorder(newOrder);
dlt.merge(targetDim, targetDim + 1);
if (auto opt = dlt.apply(cur, sub, newTensorName())) {
Expr newSummand = ReplaceNodeMutator().replace(
cur->getSummand(), sub.get(), *opt);
auto newCur = buildDLTSingleRangeOp(cur, newSummand);
// next searching step
string msg = "====== END rule7DLT\n";
dbg(msg);
nextStep(origin, depth, rCur, newCur);
}
}
}
}
}
Expr Rule7DLT::buildDLTSingleRangeOp(const RangeOp &original,
const Expr &newSummand) {
auto rangeOp = make_ref<RangeOpNode>(*original);
rangeOp->setSummand(newSummand);
return rangeOp;
}
vector<int> Rule7DLT::getFactors() {
if (derivator.getPassMode() == Derivator::PassMode::Debug) {
return {4};
} else if (derivator.getPassMode() == Derivator::PassMode::Full) {
return {3, 4};
} else {
nnet_unimplemented_halt();
return {};
}
}
} // namespace nnet

View File

@ -0,0 +1,317 @@
#include "nnet/Pass/Rule8GuidedDLT.h"
#include "nnet/Visitor/ReplaceNodeMutator.h"
namespace nnet {
static int bitCount(unsigned int n) {
int count = 0;
while (n != 0) {
n = n & (n - 1);
count++;
}
return count;
}
static int bitPosition(unsigned int n) {
assert(bitCount(n) == 1);
int ret = 0;
for (n >>= 1; n; n >>= 1)
++ret;
return ret;
}
void Rule8GuidedDLT::transform(Formula &origin, int depth, Expr &rCur) {
guidedDLT(origin, depth, rCur);
}
VecExpr Rule8GuidedDLT::guidedDLT(Formula &origin, int depth, Expr &rCur,
bool debug) {
string detailedMsg;
VecExpr ret;
auto cur = as<RangeOpNode>(rCur);
// check cur satisfies T1[A]*T2[B]
if (!statisfyGuidedDLT(cur))
return ret;
IteratorTable exprIT;
if (!exprIT.analyzeExpr(cur))
return ret;
exprIT.buildTableWithDefaultMap();
bool setTargetOpHere = false;
for (int i = 0; i < MatchableRoutineTypeCnt; ++i) {
// if not correctly unset this variable
assert(setTargetOpHere == false);
// If the guide direction is set
if (derivator.getTargetOp() != RoutineType::NoneType &&
idToRoutineType(i) != derivator.getTargetOp())
continue;
// Warning: no continue befor unset the targetOp
if (derivator.getTargetOp() == RoutineType::NoneType) {
setTargetOpHere = true;
derivator.setTargetOp(idToRoutineType(i));
}
const Pattern &pattern = getPattern(derivator.getTargetOp());
auto mismatches = exprIT.matchPatternIT(pattern);
// Pruning less possible results
// std::cout << "mismatches= " << mismatches.size()
// << "; setTargetOpHere: " << setTargetOpHere << "; ";
// std::cout << "TargetOp = " <<
// static_cast<int>(derivator.getTargetOp())
// << "; mismatches : ";
// for (const auto i : mismatches)
// std::cout << static_cast<int>(i.type) << " ";
// std::cout << endl;
if (mismatches.size() == 0) {
derivator.setSearchState(2);
nextStep(origin, depth, rCur, rCur);
derivator.setSearchState(1);
}
if (mismatches.size() > 0 && mismatches.size() <= 2) {
for (const auto &mismatch : mismatches) {
Expr newCur;
if (mismatch.type == MismatchType::MoreVar) {
newCur = guidedDLTMoreVar2(cur, mismatch, exprIT, pattern);
detailedMsg += "guidedDLTMoreVar2 ";
} else if (mismatch.type == MismatchType::DLMismatch ||
mismatch.type == MismatchType::OutputDLMismatch) {
if (mismatches.size() > 1) {
nnet_unimplemented_continue();
break;
}
newCur =
guidedDLTDLMismatch(cur, mismatch, exprIT, pattern);
detailedMsg += "guidedDLTDLMismatch ";
}
// std::cout << "newCur= "
// << ((newCur == nullptr) ? "Nullptr"
// : newCur->toReadable())
// << endl;
if (!newCur)
continue;
if (debug)
ret.emplace_back(newCur);
// next searching step
detailedMsg = "Toward " +
getPatternName(derivator.getTargetOp()) + ". " +
detailedMsg;
nextStep(origin, depth, rCur, newCur, detailedMsg);
}
}
// Unset targetOp
if (setTargetOpHere) {
derivator.setTargetOp(RoutineType::NoneType);
setTargetOpHere = false;
}
}
return ret;
}
Expr Rule8GuidedDLT::guidedDLTDLMismatch(
const RangeOp &cur, const Mismatch &mismatch,
[[maybe_unused]] const IteratorTable &exprIT, const Pattern &pattern) {
assert(mismatch.type == MismatchType::DLMismatch ||
mismatch.type == MismatchType::OutputDLMismatch);
// Currently only deal with ouput DLT
if (mismatch.bitmap != pattern.getNumInputs()) {
nnet_unimplemented_continue();
return nullptr;
}
vector<VarRangePair> newVarRanges;
for (const auto &[var, _] : pattern.getRangeOp()->getLoopVarRanges()) {
const auto &iterInExpr = mismatch.mappingIter_r.at(var);
newVarRanges.emplace_back(cur->getVarRange(iterInExpr));
}
auto inner = make_ref<RangeOpNode>(*cur);
inner->setLoopIterator(newVarRanges);
auto subscriptedInner =
ReplaceKit::buildSubscirptForLoopVarReplace(inner, {});
auto outer = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
return outer;
}
bool Rule8GuidedDLT::statisfyGuidedDLT(RangeOp cur) const {
auto mul = as<BinaryOpNode>(cur->getSummand());
if (!mul)
return false;
if (mul->getOpType() != OpType::Mul)
return false;
return as<SubscriptNode>(mul->getLhs()) && as<SubscriptNode>(mul->getRhs());
}
Expr Rule8GuidedDLT::guidedDLTMoreVar2(const RangeOp &cur,
const Mismatch &mismatch,
const IteratorTable &exprIT,
const Pattern &pattern) {
int bitmap = mismatch.bitmap;
const auto &mergedItersDefaultOrder = exprIT.getPosTable(bitmap);
// Assure vars only appear in one input tensor
int bitmapOfInputs = bitmap & ((1 << exprIT.getNumInputs()) - 1);
if (bitCount(bitmapOfInputs) > 1)
return nullptr;
if (pattern.getPosTable(bitmap).size() != 1) {
nnet_unimplemented_continue();
return nullptr;
}
if (mergedItersDefaultOrder.size() < 1)
return nullptr;
int tensorID = bitPosition(bitmapOfInputs);
if (!checkElementsHaveOnlyOneAccessIteratorSet(exprIT, tensorID))
return nullptr;
vector<Var> oldVars; // i_1, ...
vector<Var> newVars; // j_1, ...
VecExpr psis; // i_1=\psi_1(j_1, ...)
VecExpr phis; // j_1=\phi_1(i_1, ...), not necessary for Sum iter
vector<VarRangePair> newVarRanges;
auto originalTensor = exprIT.getTensor(tensorID);
auto originalSub = exprIT.getSubscript(tensorID);
vector<bool> mergedDims(originalTensor->getDims());
// Heuristic: merge iters according to their appearance positions
std::multimap<int, Var> sortedMergedIters;
for (const auto &iter : mergedItersDefaultOrder) {
vector<int> dims = exprIT.getIterDimInTensor(tensorID, iter);
assert(dims.size() == 1);
sortedMergedIters.emplace(dims[0], iter);
}
vector<Var> mergedIters; // decides the order of fused dims
for (const auto &[_, v] : sortedMergedIters)
mergedIters.emplace_back(v);
// Add the merged iterators
const auto newVar = getNewVar();
newVars.emplace_back(newVar);
int newRange = 1;
for (const auto &iter : mergedIters) {
oldVars.emplace_back(iter);
auto range = cur->getRange(iter);
newRange *= (range.second - range.first);
// if (range.first == 0)
// nnet_unimplemented_halt();
}
newVarRanges.emplace_back(newVar, Range{0, newRange});
// Add psis for each old iterator
int remainingRange = newRange;
Expr phi = nullptr;
for (const auto &iter : mergedIters) {
auto oldVar = iter;
auto range = cur->getRange(iter);
int len = (range.second - range.first);
remainingRange /= len;
Expr psi = newVar;
if (remainingRange > 1)
psi = psi / remainingRange;
if (newRange > remainingRange * len)
psi = psi % len;
int start = cur->getRange(iter).first;
if (start != 0)
psi = psi + start;
psis.emplace_back(psi);
phi = phi + remainingRange * (oldVar - start);
}
Replace replace{.iteratorType = IterationType::Loop,
.oldIters = oldVars,
.newIters = newVars,
.phis = VecExpr{phi},
.psis = psis,
.newVarRanges = newVarRanges};
// HACK: decide the rebuild data shape order
// TODO: get a partial iter mapping and permutate them?
vector<Var> tensorDimAxes{newVars};
vector<int> newShape;
for (const auto &[var, range] : newVarRanges)
newShape.emplace_back(range.second - range.first);
for (int row = 0; row < exprIT.getNumRows(); ++row) {
// Deal with other dimensions of the current tensor
if (row == bitmap || ((row & (1 << tensorID)) == 0))
continue;
using StrideIter = tuple<int, int, Iterator>;
vector<StrideIter> strideIters;
for (size_t i = 0; i < exprIT.getPosTable(row).size(); ++i) {
const auto &iter = exprIT.getPosTable(row)[i];
const Range range = cur->getRange(iter);
const int len = range.second - range.first;
// HACK Sort according to original stride. (keep original order)
strideIters.emplace_back(-exprIT.getStridesInTensor(iter, tensorID),
len, iter);
// // HACK for conv
// if (iter == "n")
// strideIters.emplace_back(2, len, iter);
// else if (iter == "c")
// strideIters.emplace_back(1, len, iter);
// else
// strideIters.emplace_back(0, len, iter);
}
// HACK: Assure the order of iterators
std::sort(strideIters.begin(), strideIters.end(),
ref_value_less<StrideIter>);
for (const auto &[_, len, oldIter] : strideIters) {
const auto &oldVar = oldIter;
tensorDimAxes.emplace_back(oldVar);
newShape.emplace_back(len);
}
}
// build DLT source
const auto sourceExpr =
buildGuidedDLTSource(originalSub, replace, tensorDimAxes, newShape);
const auto sourceRoutine = make_ref<ElementWiseNode>(
sourceExpr, vector<Tensor>{originalTensor}, newShape);
// build stage connections
const auto newTensor =
makeTensor(newTensorName(), newShape, {}, sourceRoutine);
const auto &newSub = makeSubscript(
newTensor, VecExpr(tensorDimAxes.begin(), tensorDimAxes.end()));
// TODO [1124]: get variable mapping and reorder L according to it
// dbg(cur, originalSub, newSub, newVarRanges, replace.toReadable(),
// tensorDimAxes, newShape);
// Replace the entire subscript(A[xxxxx,xxx]) in the summand
Expr newSummand = ReplaceNodeMutator().replace(cur->getSummand(),
originalSub.get(), newSub);
auto inner = ReplaceKit::replaceRangeOpIterator(cur, replace, newSummand);
auto subscriptedInner =
ReplaceKit::buildSubscirptForLoopVarReplace(inner, replace);
auto outer = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
return outer;
}
bool Rule8GuidedDLT::checkElementsHaveOnlyOneAccessIteratorSet(
const IteratorTable &exprIT, int tensorID) {
const auto &strideInDim = exprIT.getStrideInDim();
for (const auto &strideForOneDim : strideInDim[tensorID]) {
vector<pair<int, int>> strideLengthPairs;
for (const auto &[iter, s] : strideForOneDim) {
const auto &range = exprIT.getRangeOp()->getRange(iter);
strideLengthPairs.emplace_back(s, range.second - range.first);
}
std::sort(strideLengthPairs.begin(), strideLengthPairs.end());
for (size_t i = 0; i < strideLengthPairs.size() - 1; ++i) {
const auto &[stride, length] = strideLengthPairs[i];
if (stride * length > strideLengthPairs[i + 1].first)
return false;
}
}
return true;
}
Expr Rule8GuidedDLT::buildGuidedDLTSource(const Subscript &originalSub,
Replace replace,
vector<Var> tensorDimAxes,
vector<int> newShape) {
Expr newSub = ReplaceKit::replaceMultipleExprs(
originalSub, replace.oldIters, replace.psis, true);
vector<VarRangePair> loopVarRangePairs;
for (size_t i = 0; i < tensorDimAxes.size(); ++i)
loopVarRangePairs.emplace_back(tensorDimAxes[i], pair(0, newShape[i]));
return makeRangeOperator(loopVarRangePairs, {}, newSub);
}
} // namespace nnet

View File

@ -0,0 +1,54 @@
#include "nnet/Pass/Rule90TwoStageElementWise.h"
#include "nnet/Visitor/InputVisitor.h"
namespace nnet {
void Rule90TwoStageElementWise::transform(Formula &origin, int depth,
Expr &rCur) {
auto cur = as<RangeOpNode>(rCur);
{ // Match element-wise OP
auto replaces = matchTwoStageElementWise(cur);
// if (!replaces.empty())
// dbg(rCur);
// dbg(replaces);
for (auto newCur : replaces)
nextStep(origin, depth, rCur, newCur);
}
}
VecExpr
Rule90TwoStageElementWise::matchTwoStageElementWise(const RangeOp &rangeOp) {
// If the stage is compute bound, then do not convert it.
int64_t flops = rangeOp->getFlops(), outputSize = rangeOp->getOutputSize();
int64_t inputSize = rangeOp->getInputSize(rangeOp);
if (double(flops) / (inputSize + outputSize) > 3)
return {};
auto outerSub = as<SubscriptNode>(rangeOp->getSummand());
if (!outerSub)
return {};
auto innerRangeOp = as<RangeOpNode>(outerSub->getObject());
if (!innerRangeOp)
return {};
auto innerSub = as<SubscriptNode>(innerRangeOp->getSummand());
if (!innerSub)
return {};
auto innerTensor = as<TensorNode>(innerSub->getObject());
if (!innerTensor)
return {};
vector<int> newShape;
for (const auto &[var, range] : rangeOp->getLoopVarRanges()) {
if (range.first != 0) {
nnet_unimplemented_continue();
return {};
}
newShape.emplace_back(range.second - range.first);
}
const auto &inputs = InputVisitor().getInputs(rangeOp);
auto source =
make_ref<ElementWiseNode>(rangeOp, inputs, rangeOp->getOutputShape());
auto newTensor = makeTensor(newTensorName(), newShape, {}, source);
return {newTensor};
}
} // namespace nnet

View File

@ -0,0 +1,11 @@
#include "nnet/Pass/Rule91MergeStagesWithSum.h"
#include "nnet/Pass/Rule4StageMerging.h"
namespace nnet {
void Rule91MergeStagesWithSum::transform(Formula &origin, int depth,
Expr &rCur) {
Rule4StageMerging(derivator).rule4StageMerging(origin, depth, rCur, true);
}
} // namespace nnet

View File

@ -0,0 +1,45 @@
#include "nnet/Pass/Rule9RangeMagnify.h"
#include "nnet/Visitor/RangeMagnifyVisitor.h"
namespace nnet {
void Rule9RangeMagnify::transform(Formula &origin, int depth, Expr &rCur) {
auto cur = as<RangeOpNode>(rCur);
if (cur->hasPaddings()) {
// string msg = "====== END rule9RangeMagnify: Paddings exist \n";
// dbg(msg);
return;
}
// HACK for conv5x5
vector<VarRangePair> newSumVarRanges;
for (const auto &[var, range] : cur->getSumVarRanges()) {
if (range.first == 0 && range.second == 5) {
newSumVarRanges.emplace_back(
var, Range{range.first, (range.second + 2) / 3 * 3});
} else
newSumVarRanges.emplace_back(var, range);
}
if (newSumVarRanges.empty())
return;
auto magnifiedCur = RangeMagnifyVisitor().magnify(cur, newSumVarRanges);
if (!magnifiedCur)
return;
// next searching step
string msg = "relax iterating ranges ";
for (size_t i = 0; i < cur->getSumVarRanges().size(); ++i) {
const auto &[v1, a] = cur->getVarRange(IterationType::Sum, i);
const auto &[v2, b] = magnifiedCur->getVarRange(IterationType::Sum, i);
assert(v1->getName() == v2->getName());
if (a != b) {
msg += v1->getName();
msg +=
" (" + to_string(a.first) + "," + to_string(a.second) + ") to";
msg += " (" + to_string(b.first) + "," + to_string(b.second) + "),";
}
}
nextStep(origin, depth, rCur, magnifiedCur, msg);
return;
}
} // namespace nnet

View File

@ -0,0 +1,165 @@
#include "nnet/Visitor/AsTVMVisitor.h"
namespace nnet {
std::string AsTVMVisitor::visit_(const Constant &c) {
return std::to_string(c->getValue());
}
std::string AsTVMVisitor::visit_(const BinaryOp &c) {
switch (c->getOpType()) {
case OpType::Add:
return "(" + dispatch(c->getLhs()) + " + " + dispatch(c->getRhs()) +
")";
case OpType::Sub:
return "(" + dispatch(c->getLhs()) + " - " + dispatch(c->getRhs()) +
")";
case OpType::Mul:
return "(" + dispatch(c->getLhs()) + " * " + dispatch(c->getRhs()) +
")";
case OpType::Div:
return "(" + dispatch(c->getLhs()) + " // " + dispatch(c->getRhs()) +
")";
case OpType::Mod:
return "(" + dispatch(c->getLhs()) + " % " + dispatch(c->getRhs()) +
")";
default:
assert(false);
}
}
std::string AsTVMVisitor::visit_(const Func &c) {
switch (c->getFuncType()) {
case FuncType::Relu:
// TODO: Deduce the dtype
return "te.max(" + dispatch(c->getObject()) +
", tvm.tir.const(0, 'float32'))";
case FuncType::Tanh:
return "te.tanh(" + dispatch(c->getObject()) + ")";
default:
assert(false);
}
}
std::string AsTVMVisitor::visit_(const RangeOp &c) {
auto outerStage = curStage;
curStage = nStage++;
std::string stmt;
std::string stageName = "s" + std::to_string(curStage);
std::vector<std::string> reduceVars;
for (auto &&[var, range] : c->getSumVarRanges()) {
std::string varName = stageName + "_" + var->getName();
stmt += varName + " = " + "te.reduce_axis((" +
std::to_string(range.first) + ", " +
std::to_string(range.second) + "), name=\"" + varName + "\")\n";
reduceVars.emplace_back(varName);
pythonVars.emplace_back(varName);
}
std::vector<int> shape;
stmt += stageName + " = te.compute((";
for (size_t i = 0, n = c->getLoopVarRanges().size(); i < n; i++) {
auto &&[var, range] = c->getLoopVarRanges()[i];
std::string varName = stageName + "_" + var->getName();
offset[varName] = -range.first + c->getPaddings(i);
auto len = range.second - range.first + 2 * c->getPaddings(i);
stmt += std::to_string(len) + ", ";
shape.emplace_back(len);
}
stmt += "), lambda ";
bool first = true;
for (auto &&[var, range] : c->getLoopVarRanges()) {
std::string varName = stageName + "_" + var->getName();
stmt += (first ? "" : ", ") + varName;
first = false;
}
std::string summand = dispatch(c->getSummand());
if (!reduceVars.empty()) {
summand = "te.sum(" + summand + ", axis=(";
for (auto &&var : reduceVars) {
summand += var + ", ";
}
summand += "))";
}
if (c->hasPaddings()) {
std::string guard = "tir.if_then_else(tir.all(";
bool first = true;
for (size_t i = 0, n = c->getLoopVarRanges().size(); i < n; i++) {
auto &&[var, range] = c->getLoopVarRanges()[i];
std::string varName = stageName + "_" + var->getName();
if (auto pad = c->getPaddings(i); pad > 0) {
guard += (first ? "" : ", ") + varName +
" >= " + std::to_string(range.first) + ", " + varName +
" < " + std::to_string(range.second);
first = false;
}
}
// TODO: Deduce the dtype
guard += "), " + summand + ", tvm.tir.const(0.0, \"float32\"))";
summand = guard;
}
stmt += ": " + summand + ")";
stmts += stmt + "\n";
pythonVars.emplace_back(stageName);
output = stageName;
outputShape = std::move(shape);
curStage = outerStage;
return stageName;
}
std::string AsTVMVisitor::visit_(const Subscript &c) {
std::string str = dispatch(c->getObject()) + "[";
for (size_t i = 0, n = c->getIndex().size(); i < n; i++) {
const auto &idx = c->getIndex()[i];
str += (i == 0 ? "" : ", ") + dispatch(idx);
if (c->getObject()->getType() == NodeType::RangeOpNodeType) {
auto rangeOp = as<RangeOpNode>(c->getObject());
str += " - " +
std::to_string(rangeOp->getLoopVarRanges()[i].second.first -
rangeOp->getPaddings(i));
}
}
str += "]";
return str;
}
std::string AsTVMVisitor::visit_(const Var &c) {
std::string stageName = "s" + std::to_string(curStage);
std::string varName = stageName + "_" + c->getName();
if (offset.count(varName)) {
return "(" + varName + " - " + std::to_string(offset.at(varName)) + ")";
} else {
return varName;
}
}
std::string AsTVMVisitor::visit_(const Tensor &c) {
pythonVars.emplace_back(c->getName());
inputs.emplace_back(c->getName());
inputShapes.emplace_back(c->getShape());
std::string stmt = c->getName() + " = te.placeholder((";
for (auto &&dim : c->getShape()) {
stmt += std::to_string(dim) + ", ";
}
stmt += "), name='" + c->getName() + "')";
stmts += stmt + "\n";
return c->getName();
}
std::string AsTVMVisitor::getStmts() const {
std::string ret;
// Workaround because closure capturing does not work in an `exec`
// https://stackoverflow.com/questions/2749655/why-are-closures-broken-within-exec
ret += "global ";
bool first = true;
for (auto &&var : pythonVars) {
ret += (first ? "" : ", ") + var;
first = false;
}
ret += "\n";
ret += stmts;
ret += "ret = [" + output;
for (auto &&input : inputs) {
ret += ", " + input;
}
ret += "]\n";
return ret;
}
} // namespace nnet

View File

@ -0,0 +1,35 @@
#include "nnet/Visitor/CheckOOBVisitor.h"
#include "nnet/Visitor/SimplifyExprVisitor.h"
namespace nnet {
void CheckOOBVisitor::visit_(const Subscript &c) {
const auto &objectRanges = c->getObjectRangesWithPaddings();
for (size_t dim = 0; dim < c->getDims(); ++dim) {
SimplifyExprVisitor simplifier;
auto optional = simplifier.getExprRange(c->getIndex(dim), rangeOp);
if (!optional.has_value())
continue;
const Range &exprRange = *optional;
if (exprRange.first < objectRanges[dim].first ||
exprRange.second > objectRanges[dim].second) {
// dbg("OOB detected!", c, dim, exprRange, objectRanges[dim]);
// std::cout << "OOB detected! " << c->toReadable() << ", dim=" <<
// dim
// << ", Range=(" << exprRange.first << ", "
// << exprRange.second << "), objRange=("
// << objectRanges[dim].first << ", "
// << objectRanges[dim].second << ")." << std::endl;
detect = true;
}
}
}
bool CheckOOBVisitor::checkRangeOp(const RangeOp &_rangeOp) {
detect = false;
rangeOp = _rangeOp;
dispatch(rangeOp);
return detect;
}
} // namespace nnet

View File

@ -0,0 +1,9 @@
#include "nnet/Visitor/CloneMutator.h"
namespace nnet {
Expr CloneMutator::visit_(const Constant &c) { return c; }
Expr CloneMutator::visit_(const Var &c) { return c; }
Expr CloneMutator::visit_(const Tensor &c) { return c; }
} // namespace nnet

View File

@ -0,0 +1,34 @@
#include "nnet/Visitor/CompareMultiFormulasVisitor.h"
namespace nnet {
bool CompareMultiFormulasVisitor::compare(const VecExpr &roots) {
if (roots.empty())
return false;
vector<RangeOp> rangeOps;
for (const auto &root : roots) {
if (auto rangeOp = as<RangeOpNode>(root))
rangeOps.emplace_back(rangeOp);
else
return false;
}
const auto pattern = rangeOps[0];
for (auto rangeOp : rangeOps) {
if (pattern->getNumOutputDims() != rangeOp->getNumOutputDims()) {
return false;
}
for (int i = 0; i < pattern->getNumOutputDims(); ++i)
if (pattern->getVarRange(0, i).second !=
rangeOp->getVarRange(0, i).second) {
return false;
}
for (size_t i = 0; i < pattern->getSumVarRanges().size(); ++i)
if (pattern->getVarRange(1, i).second !=
rangeOp->getVarRange(1, i).second) {
return false;
}
}
return true;
}
} // namespace nnet

View File

@ -0,0 +1,38 @@
#include "nnet/Visitor/CountRoutineVisitor.h"
namespace nnet {
void CountRoutineVisitor::visit_(const Tensor &c) {
if (auto routine = c->getSource(); routine) {
cnts[routineTypeToId(routine->getType())]++;
}
ExprTreeVisitor::visit_(c);
}
vector<int> CountRoutineVisitor::count(const Expr &root) {
cnts = vector<int>(RoutineTypeCnt, 0);
dispatch(root);
return cnts;
}
bool CountRoutineVisitor::match(const Expr &root, int nMatmul, int nConv,
int nElement, int nSg2bmm,
int nLongformerGBMM) {
auto opCount = count(root);
bool ret = true;
if (opCount[routineTypeToId(RoutineType::MatmulNodeType)] != nMatmul)
ret = false;
if (opCount[routineTypeToId(RoutineType::ConvNodeType)] != nConv)
ret = false;
if (opCount[routineTypeToId(RoutineType::ElementWiseNodeType)] != nElement)
ret = false;
if (opCount.at(routineTypeToId(RoutineType::G2bmmNodeType)) != nSg2bmm)
ret = false;
if (!ret) {
auto target =
vector<int>{nMatmul, nConv, nSg2bmm, nLongformerGBMM, nElement};
}
return ret;
}
} // namespace nnet

View File

@ -0,0 +1,58 @@
#include "nnet/Visitor/FullPrinterVisitor.h"
namespace nnet {
void FullPrinterVisitor::visit_(const Tensor &c) {
q.emplace_back(c->getName(), c->getSource(), c);
}
string FullPrinterVisitor::print(const Expr &root) {
q.clear();
std::ostringstream oss;
dispatch(root);
oss << "==> ROOT\n" << root->toReadable() << "\n";
for (size_t i = 0; i < q.size(); ++i) {
const auto &[name, routine, tensor] = q[i];
oss << "==> " << name << " : ";
if (routine) {
oss << routine->toReadable() << "\n";
if (routine->getExpr()) {
oss << routine->getExpr()->toReadable() << "\n";
} else
oss << "[INFO] Source is nullptr \n";
if (!routine->getInputs().empty()) {
for (const auto &tensor : routine->getInputs())
q.emplace_back(tensor->getName(), tensor->getSource(),
tensor);
} else if (routine->getExpr())
dispatch(routine->getExpr());
} else
oss << "Input Tensor " << tensor->toOutputShape() << "\n";
}
return oss.str();
}
const vector<tuple<string, Routine, Tensor>> &
FullPrinterVisitor::traverse(const Expr &root) {
q.clear();
dispatch(root);
for (size_t i = 0; i < q.size(); ++i) {
const auto &[name, routine, tensor] = q[i];
if (routine) {
// Matmul after DLT do not modify expression, so inputs has a higher
// priority. Some OPs such as DLT have not implement source. Then
// use inputs
if (!routine->getInputs().empty()) {
for (const auto &tensor : routine->getInputs())
dispatch(tensor);
} else if (routine->getExpr()) {
dispatch(routine->getExpr());
} else {
assert(false);
}
}
}
return q;
}
} // namespace nnet

View File

@ -0,0 +1,9 @@
#include "nnet/Visitor/GetTensorsVisitor.h"
namespace nnet {
void GetTensorsVisitor::visit_(const Tensor &c) {
tensors.try_emplace(c->getName(), c);
}
} // namespace nnet

Some files were not shown because too many files have changed in this diff Show More