diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..d8a7dfe7 --- /dev/null +++ b/.clang-format @@ -0,0 +1,90 @@ +--- +Language: Cpp +# BasedOnStyle: LLVM +AccessModifierOffset: -2 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: false +AlignOperands: true +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: false +BinPackArguments: true +BinPackParameters: true +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +ConstructorInitializerAllOnOneLineOrOnePerLine: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] +IncludeCategories: + - Regex: '^"(llvm|llvm-c|clang|clang-c)/' + Priority: 2 + - Regex: '^(<|"(gtest|isl|json)/)' + Priority: 3 + - Regex: '.*' + Priority: 1 +IndentCaseLabels: false +IndentWidth: 4 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: true +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 60 +PointerAlignment: Right +ReflowComments: true +SortIncludes: true +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never +... + diff --git a/.cmake-format.json b/.cmake-format.json new file mode 100644 index 00000000..ba3498b9 --- /dev/null +++ b/.cmake-format.json @@ -0,0 +1,311 @@ +{ + "_help_parse": "Options affecting listfile parsing", + "parse": { + "_help_additional_commands": [ + "Specify structure for custom cmake functions" + ], + "additional_commands": { + "foo": { + "flags": [ + "BAR", + "BAZ" + ], + "kwargs": { + "HEADERS": "*", + "SOURCES": "*", + "DEPENDS": "*" + } + } + }, + "_help_override_spec": [ + "Override configurations per-command where available" + ], + "override_spec": {}, + "_help_vartags": [ + "Specify variable tags." + ], + "vartags": [], + "_help_proptags": [ + "Specify property tags." + ], + "proptags": [] + }, + "_help_format": "Options affecting formatting.", + "format": { + "_help_disable": [ + "Disable formatting entirely, making cmake-format a no-op" + ], + "disable": false, + "_help_line_width": [ + "How wide to allow formatted cmake files" + ], + "line_width": 1500, + "_help_tab_size": [ + "How many spaces to tab for indent" + ], + "tab_size": 2, + "_help_use_tabchars": [ + "If true, lines are indented using tab characters (utf-8", + "0x09) instead of space characters (utf-8 0x20).", + "In cases where the layout would require a fractional tab", + "character, the behavior of the fractional indentation is", + "governed by " + ], + "use_tabchars": false, + "_help_fractional_tab_policy": [ + "If is True, then the value of this variable", + "indicates how fractional indentions are handled during", + "whitespace replacement. If set to 'use-space', fractional", + "indentation is left as spaces (utf-8 0x20). If set to", + "`round-up` fractional indentation is replaced with a single", + "tab character (utf-8 0x09) effectively shifting the column", + "to the next tabstop" + ], + "fractional_tab_policy": "use-space", + "_help_max_subgroups_hwrap": [ + "If an argument group contains more than this many sub-groups", + "(parg or kwarg groups) then force it to a vertical layout." + ], + "max_subgroups_hwrap": 2, + "_help_max_pargs_hwrap": [ + "If a positional argument group contains more than this many", + "arguments, then force it to a vertical layout." + ], + "max_pargs_hwrap": 6, + "_help_max_rows_cmdline": [ + "If a cmdline positional group consumes more than this many", + "lines without nesting, then invalidate the layout (and nest)" + ], + "max_rows_cmdline": 2, + "_help_separate_ctrl_name_with_space": [ + "If true, separate flow control names from their parentheses", + "with a space" + ], + "separate_ctrl_name_with_space": false, + "_help_separate_fn_name_with_space": [ + "If true, separate function names from parentheses with a", + "space" + ], + "separate_fn_name_with_space": false, + "_help_dangle_parens": [ + "If a statement is wrapped to more than one line, than dangle", + "the closing parenthesis on its own line." + ], + "dangle_parens": false, + "_help_dangle_align": [ + "If the trailing parenthesis must be 'dangled' on its on", + "line, then align it to this reference: `prefix`: the start", + "of the statement, `prefix-indent`: the start of the", + "statement, plus one indentation level, `child`: align to", + "the column of the arguments" + ], + "dangle_align": "prefix", + "_help_min_prefix_chars": [ + "If the statement spelling length (including space and", + "parenthesis) is smaller than this amount, then force reject", + "nested layouts." + ], + "min_prefix_chars": 4, + "_help_max_prefix_chars": [ + "If the statement spelling length (including space and", + "parenthesis) is larger than the tab width by more than this", + "amount, then force reject un-nested layouts." + ], + "max_prefix_chars": 10, + "_help_max_lines_hwrap": [ + "If a candidate layout is wrapped horizontally but it exceeds", + "this many lines, then reject the layout." + ], + "max_lines_hwrap": 2, + "_help_line_ending": [ + "What style line endings to use in the output." + ], + "line_ending": "unix", + "_help_command_case": [ + "Format command names consistently as 'lower' or 'upper' case" + ], + "command_case": "canonical", + "_help_keyword_case": [ + "Format keywords consistently as 'lower' or 'upper' case" + ], + "keyword_case": "unchanged", + "_help_always_wrap": [ + "A list of command names which should always be wrapped" + ], + "always_wrap": [], + "_help_enable_sort": [ + "If true, the argument lists which are known to be sortable", + "will be sorted lexicographicall" + ], + "enable_sort": true, + "_help_autosort": [ + "If true, the parsers may infer whether or not an argument", + "list is sortable (without annotation)." + ], + "autosort": false, + "_help_require_valid_layout": [ + "By default, if cmake-format cannot successfully fit", + "everything into the desired linewidth it will apply the", + "last, most aggressive attempt that it made. If this flag is", + "True, however, cmake-format will print error, exit with non-", + "zero status code, and write-out nothing" + ], + "require_valid_layout": false, + "_help_layout_passes": [ + "A dictionary mapping layout nodes to a list of wrap", + "decisions. See the documentation for more information." + ], + "layout_passes": {} + }, + "_help_markup": "Options affecting comment reflow and formatting.", + "markup": { + "_help_bullet_char": [ + "What character to use for bulleted lists" + ], + "bullet_char": "*", + "_help_enum_char": [ + "What character to use as punctuation after numerals in an", + "enumerated list" + ], + "enum_char": ".", + "_help_first_comment_is_literal": [ + "If comment markup is enabled, don't reflow the first comment", + "block in each listfile. Use this to preserve formatting of", + "your copyright/license statements." + ], + "first_comment_is_literal": false, + "_help_literal_comment_pattern": [ + "If comment markup is enabled, don't reflow any comment block", + "which matches this (regex) pattern. Default is `None`", + "(disabled)." + ], + "literal_comment_pattern": ".*", + "_help_fence_pattern": [ + "Regular expression to match preformat fences in comments", + "default= ``r'^\\s*([`~]{3}[`~]*)(.*)$'``" + ], + "fence_pattern": "^\\s*([`~]{3}[`~]*)(.*)$", + "_help_ruler_pattern": [ + "Regular expression to match rulers in comments default=", + "``r'^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$'``" + ], + "ruler_pattern": "^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$", + "_help_explicit_trailing_pattern": [ + "If a comment line matches starts with this pattern then it", + "is explicitly a trailing comment for the preceding argument.", + "Default is '#<'" + ], + "explicit_trailing_pattern": "#<", + "_help_hashruler_min_length": [ + "If a comment line starts with at least this many consecutive", + "hash characters, then don't lstrip() them off. This allows", + "for lazy hash rulers where the first hash char is not", + "separated by space" + ], + "hashruler_min_length": 10, + "_help_canonicalize_hashrulers": [ + "If true, then insert a space between the first hash char and", + "remaining hash chars in a hash ruler, and normalize its", + "length to fill the column" + ], + "canonicalize_hashrulers": true, + "_help_enable_markup": [ + "enable comment markup parsing and reflow" + ], + "enable_markup": true + }, + "_help_lint": "Options affecting the linter", + "lint": { + "_help_disabled_codes": [ + "a list of lint codes to disable" + ], + "disabled_codes": [], + "_help_function_pattern": [ + "regular expression pattern describing valid function names" + ], + "function_pattern": "[0-9a-z_]+", + "_help_macro_pattern": [ + "regular expression pattern describing valid macro names" + ], + "macro_pattern": "[0-9A-Z_]+", + "_help_global_var_pattern": [ + "regular expression pattern describing valid names for", + "variables with global (cache) scope" + ], + "global_var_pattern": "[A-Z][0-9A-Z_]+", + "_help_internal_var_pattern": [ + "regular expression pattern describing valid names for", + "variables with global scope (but internal semantic)" + ], + "internal_var_pattern": "_[A-Z][0-9A-Z_]+", + "_help_local_var_pattern": [ + "regular expression pattern describing valid names for", + "variables with local scope" + ], + "local_var_pattern": "[a-z][a-z0-9_]+", + "_help_private_var_pattern": [ + "regular expression pattern describing valid names for", + "privatedirectory variables" + ], + "private_var_pattern": "_[0-9a-z_]+", + "_help_public_var_pattern": [ + "regular expression pattern describing valid names for public", + "directory variables" + ], + "public_var_pattern": "[A-Z][0-9A-Z_]+", + "_help_argument_var_pattern": [ + "regular expression pattern describing valid names for", + "function/macro arguments and loop variables." + ], + "argument_var_pattern": "[a-z][a-z0-9_]+", + "_help_keyword_pattern": [ + "regular expression pattern describing valid names for", + "keywords used in functions or macros" + ], + "keyword_pattern": "[A-Z][0-9A-Z_]+", + "_help_max_conditionals_custom_parser": [ + "In the heuristic for C0201, how many conditionals to match", + "within a loop in before considering the loop a parser." + ], + "max_conditionals_custom_parser": 2, + "_help_min_statement_spacing": [ + "Require at least this many newlines between statements" + ], + "min_statement_spacing": 1, + "_help_max_statement_spacing": [ + "Require no more than this many newlines between statements" + ], + "max_statement_spacing": 2, + "max_returns": 6, + "max_branches": 12, + "max_arguments": 5, + "max_localvars": 15, + "max_statements": 50 + }, + "_help_encode": "Options affecting file encoding", + "encode": { + "_help_emit_byteorder_mark": [ + "If true, emit the unicode byte-order mark (BOM) at the start", + "of the file" + ], + "emit_byteorder_mark": false, + "_help_input_encoding": [ + "Specify the encoding of the input file. Defaults to utf-8" + ], + "input_encoding": "utf-8", + "_help_output_encoding": [ + "Specify the encoding of the output file. Defaults to utf-8.", + "Note that cmake only claims to support utf-8 so be careful", + "when using anything else" + ], + "output_encoding": "utf-8" + }, + "_help_misc": "Miscellaneous configurations options.", + "misc": { + "_help_per_command": [ + "A dictionary containing any per-command configuration", + "overrides. Currently only `command_case` is supported." + ], + "per_command": {} + } +} diff --git a/.github/workflows/clang-format-check.yml b/.github/workflows/clang-format-check.yml new file mode 100644 index 00000000..24ba1a27 --- /dev/null +++ b/.github/workflows/clang-format-check.yml @@ -0,0 +1,19 @@ +name: clang-format Check +on: [pull_request] +jobs: + formatting-check: + name: Formatting Check + runs-on: ubuntu-latest + strategy: + matrix: + path: + - 'include' + - 'src' + - 'test' + steps: + - uses: actions/checkout@v2 + - name: Run clang-format style check for C/C++/Protobuf programs. + uses: jidicula/clang-format-action@v4.8.0 + with: + clang-format-version: '14' + check-path: ${{ matrix.path }} diff --git a/.gitignore b/.gitignore index 259148fa..6d512811 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,8 @@ *.exe *.out *.app + +build/ +build_debug/ + +.vscode/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..2f905b6a --- /dev/null +++ b/.gitmodules @@ -0,0 +1,9 @@ +[submodule "3rd-party/pybind11"] + path = 3rd-party/pybind11 + url = git@github.com:pybind/pybind11.git +[submodule "3rd-party/nlohmann_json_cmake_fetchcontent"] + path = 3rd-party/nlohmann_json_cmake_fetchcontent + url = git@github.com:ArthurSonzogni/nlohmann_json_cmake_fetchcontent.git +[submodule "3rd-party/googletest"] + path = 3rd-party/googletest + url = git@github.com:google/googletest.git diff --git a/3rd-party/googletest b/3rd-party/googletest new file mode 160000 index 00000000..e2239ee6 --- /dev/null +++ b/3rd-party/googletest @@ -0,0 +1 @@ +Subproject commit e2239ee6043f73722e7aa812a459f54a28552929 diff --git a/3rd-party/nlohmann_json_cmake_fetchcontent b/3rd-party/nlohmann_json_cmake_fetchcontent new file mode 160000 index 00000000..6aebf092 --- /dev/null +++ b/3rd-party/nlohmann_json_cmake_fetchcontent @@ -0,0 +1 @@ +Subproject commit 6aebf09233951e4ce30a63919186a70b2b195756 diff --git a/3rd-party/pybind11 b/3rd-party/pybind11 new file mode 160000 index 00000000..1e3400b6 --- /dev/null +++ b/3rd-party/pybind11 @@ -0,0 +1 @@ +Subproject commit 1e3400b6742288429f2069aaf5febf92d0662dae diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..f1079f65 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,91 @@ +# TODO: check the minimum cmake version +cmake_minimum_required(VERSION 3.9) # Required by find_package(OpenMP) +include(CMakeDependentOption) +project(InfiniTensor C CXX) + +# Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them. +option(BUILD_TEST "Build tests" ON) +cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF) +cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF) +cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST OFF) + +set(DEFAULT_BUILD_TYPE "RelWithDebInfo") + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wno-error=deprecated-declarations") +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion + +find_package( + Python + COMPONENTS Interpreter Development + REQUIRED) +find_package(CUDA REQUIRED) +# OpenMP +find_package(OpenMP) +if(OpenMP_C_FOUND) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") +endif() +if(OpenMP_CXX_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") +endif() + +include_directories(include) + +# # Pybind11 +# add_subdirectory(3rd-party/pybind11) +# include_directories(3rd-party/pybind11/include) + +# nlohmann_json +add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent) +include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include) + +if(BUILD_TEST) + set(BUILD_GMOCK + OFF + CACHE BOOL "Do not build gmock" FORCE) + set(INSTALL_GTEST + OFF + CACHE BOOL "Do not install gtest" FORCE) + add_subdirectory(3rd-party/googletest) + include_directories(3rd-party/googletest/googletest/include) +endif() + +file(GLOB_RECURSE SRC src/*.cc src/*.cu) +# file(GLOB_RECURSE FFI src/ffi/ffi_pet.cc) +# list(REMOVE_ITEM SRC ${TEST} ${FFI}) + +add_library(InfiniTensor SHARED ${SRC}) +# Target +# cuda_add_library(it SHARED ${SRC}) +# cuda_add_cublas_to_target(it) # cublas +# # target_link_libraries(infini_cpp cudnn curand nlohmann_json::nlohmann_json pybind11::embed) + +# # Python bindings +# pybind11_add_module(infini MODULE ${FFI}) +# target_link_libraries(infini PRIVATE infini_cpp) + +function(build_test files) + # Non-recursive glob for skip failed tests + file(GLOB TEST_SOURCES ${files}) + foreach(testsourcefile ${TEST_SOURCES}) + get_filename_component(testname ${testsourcefile} NAME_WE) + add_executable(${testname} ${testsourcefile}) + target_link_libraries(${testname} InfiniTensor GTest::gtest_main) + add_test(NAME ${testname} COMMAND ${testname}) + endforeach(testsourcefile ${TEST_SOURCES}) +endfunction() + +if(BUILD_TEST) + enable_testing() + if(BUILD_TEST_CORE) + build_test(test/core/*.cc) + endif() + if(BUILD_TEST_PET) + build_test(test/pet/*.cc) + endif() + if(BUILD_TEST_EINNET) + build_test(test/nnet/*.cc) + endif() +endif() diff --git a/include/core/common.h b/include/core/common.h new file mode 100644 index 00000000..0fe7344e --- /dev/null +++ b/include/core/common.h @@ -0,0 +1,61 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace infini { +using std::list; +using std::map; +using std::pair; +using std::set; +using std::string; +using std::tie; +using std::to_string; +using std::tuple; +using std::unordered_map; +using std::vector; + +// Aliases +using dtype = float; +using HashType = size_t; // compatible with std::hash + +// Metaprogramming utilities +#define _CAT(A, B) A##B +#define _SELECT(NAME, NUM) _CAT(NAME##_, NUM) +#define _GET_COUNT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, COUNT, ...) COUNT +#define _VA_SIZE(...) _GET_COUNT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1) +#define _VA_SELECT(NAME, ...) _SELECT(NAME, _VA_SIZE(__VA_ARGS__))(__VA_ARGS__) + +// Assert: conditions should have no side effect +#define _IT_ASSERT_2(name, info) \ + (static_cast(name) \ + ? void(0) \ + : throw std::runtime_error( \ + std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \ + "] Assertion failed (" + #name + "): " + #info)) +#define _IT_ASSERT_1(name) _IT_ASSERT_2(name, ""); + +#define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__) +#define IT_TODO_HALT() IT_ASSERT(false, "Unimplemented") +#define IT_TODO_SKIP() puts("Unimplemented " __FILE__ ":" __LINE__) + +// Other utilities + +// std::to_underlying is avaiable since C++23 +template auto enum_to_underlying(T e) { + return static_cast>(e); +} + +double timeit(const std::function &func); + +} // namespace infini diff --git a/include/core/graph.h b/include/core/graph.h new file mode 100644 index 00000000..9c87310a --- /dev/null +++ b/include/core/graph.h @@ -0,0 +1,47 @@ +#pragma once +#include "core/operator.h" +#include "core/tensor.h" + +namespace infini { + +// TODO: graph should be attached to a context +class GraphNode : public Object { + protected: + TensorVec tensors; + TensorVec inputs; + TensorVec outputs; + OpVec ops; + + public: + // Graph(OpVec oplist); + string toString() const override; + + void addOp(Operator op) { ops.push_back(op); }; + const TensorVec &getTensors() const { return tensors; } + const TensorVec &getInputs() const { return inputs; } + const TensorVec &getOutputs() const { return outputs; } + const OpVec &getOperators() const { return ops; } + // TensorVec &getInputs(); + // TensorVec &getOutputs(); + + Tensor addTensor(Shape dim, DataType dtype = DataType::Int32) { + Tensor tensor = make_ref(dim, dtype); + tensors.emplace_back(tensor); + return tensor; + } + + void dataMalloc(); + + private: + // TODO: updateConnection + /** + * @brief Add reverse connections and Op relationship in ctor. + */ + void updateConnection(); + + // TODO: move to another class + // bool exportOnnx(const char *path); + // bool importOnnx(const char *net); +}; + +} // namespace infini diff --git a/include/core/kernel.h b/include/core/kernel.h new file mode 100644 index 00000000..4ecb2065 --- /dev/null +++ b/include/core/kernel.h @@ -0,0 +1,76 @@ +#pragma once +#include "core/common.h" +#include "core/operator.h" +#include "core/tensor.h" + +namespace infini { + +struct PerfRecord { + double time; // in milliseconds +}; + +class Kernel { + public: + Kernel() {} + virtual ~Kernel() {} + + /** + * @param op The operator to be executed. + * @param record The parameters for kernel execution. If extra parameters + * are required, inherit from PerfRecord and add extra parameters. + * Otherwire, use PerfRecord directly. + */ + virtual void compute(const Operator &op, + const PerfRecord &record) const = 0; + /** + * @brief Executes an op with a default parameter. + */ + virtual void compute(const Operator &op) const = 0; + // Premise: op is idempotent since it is called multiple times. + virtual PerfRecord tune(const Operator &op) const = 0; +}; + +class KernelRegistry { + public: + using KernelRecord = + tuple; // Kernel, name, ID + + private: + std::map kernels; + int nKernels = 0; + + public: + ~KernelRegistry() { + for (auto &[k, v] : kernels) + delete std::get<0>(v); + } + static KernelRegistry &getInstance() { + static KernelRegistry instance; + return instance; + } + bool registerKernel(const KernelAttrs &key, Kernel *kernel, string name) { + // TODO: mutliple kernels support: priority and check name + IT_ASSERT(kernels.find(key) == kernels.end(), + "Kernel already registered"); + kernels.emplace(key, KernelRecord{kernel, name, ++nKernels}); + return true; + } + Kernel *getKernel(const KernelAttrs &kernelAttrs) const { + return std::get<0>(kernels.at(kernelAttrs)); + } + const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const { + return kernels.at(kernelAttrs); + } +}; + +} // namespace infini + +#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, cnt) \ + namespace infini { \ + static const bool _CAT(_register_kernel_, cnt) = \ + KernelRegistry::getInstance().registerKernel( \ + KernelAttrs{device, opType, dataType}, new kernel(), name); \ + } + +#define REGISTER_KERNEL(device, opType, dataType, kernel, name) \ + _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, __COUNTER__) diff --git a/include/core/mutator.h b/include/core/mutator.h new file mode 100644 index 00000000..42402151 --- /dev/null +++ b/include/core/mutator.h @@ -0,0 +1,19 @@ +#pragma once +#include "core/graph.h" + +namespace infini { + +class Mutator { + private: + int candidatesLimit; + // // Statistical data + // int numTotalCandidates; + + public: + Mutator(int candidatesLimit) : candidatesLimit(candidatesLimit){}; + virtual ~Mutator(){}; + + virtual vector run(const Graph &in_graph) = 0; +}; + +} // namespace infini diff --git a/include/core/object.h b/include/core/object.h new file mode 100644 index 00000000..0faec9f5 --- /dev/null +++ b/include/core/object.h @@ -0,0 +1,54 @@ +#pragma once +#include "core/common.h" +#include "ref.h" + +namespace infini { + +using GuidBaseType = int; + +class Guid { + private: + GuidBaseType guid; + + private: + GuidBaseType generateGuid() { + static GuidBaseType guidCnt = 0; + return ++guidCnt; + } + + public: + Guid() { guid = generateGuid(); } + Guid(const Guid &rhs) { guid = generateGuid(); } + Guid &operator=(const Guid &rhs) { + guid = generateGuid(); + return *this; + } + + operator GuidBaseType() const { return guid; } +}; + +class Object { + protected: + Guid guid; + + public: + virtual ~Object(){}; + virtual string toString() const = 0; + void print() { std::cout << toString() << std::endl; } + Guid getGuid() const { return guid; } +}; + +inline std::ostream &operator<<(std::ostream &os, const Object &obj) { + os << obj.toString(); + return os; +} + +// Overload for Ref-wrapped Object +template > * = nullptr> +inline std::ostream &operator<<(std::ostream &os, const Ref &obj) { + os << obj->toString(); + return os; +} + +} // namespace infini \ No newline at end of file diff --git a/include/core/operator.h b/include/core/operator.h new file mode 100644 index 00000000..a24f240d --- /dev/null +++ b/include/core/operator.h @@ -0,0 +1,180 @@ +#pragma once +#include "core/tensor.h" + +namespace infini { + +enum class OpType { + Unknown = 0, + // linear + Conv = 100, + Matmul, + ConvTrans, + G2BMM, + GBMML, + Pad, + Slice, + Concat, + Split, + Transpose, + Extend, + MaxPool, + AvgPool, + Add, + Sub, + Mul, + Div, + Pow, + Gather, + ReduceMean, + Reshape, + Identity, + // element wise + BatchNorm = 200, + Softmax, + Activation, + Resize, + // + MemBound = 300, +}; + +enum class Device { CPU = 1, CUDA }; + +using KernelAttrs = std::tuple; + +class OpRegistry { + public: + static std::string getOpName(OpType opType) { +#define FOP(op) \ + case OpType::op: \ + return #op + + switch (opType) { + FOP(Unknown); + // linear + FOP(Conv); + FOP(Matmul); + FOP(ConvTrans); + FOP(G2BMM); + FOP(GBMML); + FOP(Pad); + FOP(Slice); + FOP(Concat); + FOP(Split); + FOP(Transpose); + FOP(Extend); + FOP(MaxPool); + FOP(AvgPool); + FOP(Add); + FOP(Sub); + FOP(Mul); + FOP(Div); + FOP(Pow); + FOP(Gather); + FOP(ReduceMean); + FOP(Reshape); + FOP(Identity); + // element wise + FOP(BatchNorm); + FOP(Softmax); + FOP(Activation); + // + FOP(MemBound); + default: + IT_ASSERT(false); + break; + } +#undef FOP + } +}; + +enum class ActType { + None, + Relu, + Sigmoid, + Tanh, +}; + +struct OpPerfKey { + HashType hash; + OpType opType; + vector attrs; + + public: + OpPerfKey(HashType hash, OpType opType, vector attrs = {}) + : hash(hash), opType(opType), attrs(attrs) {} + bool operator==(const OpPerfKey &rhs) const { + if (hash != rhs.hash) + return false; + if (opType != rhs.opType) + return false; + if (attrs != rhs.attrs) + return false; + return true; + } + + // TODO: remove this function after we use unordered_map in PerfEngine + bool operator<(const OpPerfKey &rhs) const { + if (hash != rhs.hash) + return hash < rhs.hash; + if (opType != rhs.opType) + return opType < rhs.opType; + if (attrs.size() != rhs.attrs.size()) + return attrs.size() < rhs.attrs.size(); + for (size_t i = 0; i < attrs.size(); ++i) + if (attrs[i] != rhs.attrs[i]) + return attrs[i] < rhs.attrs[i]; + return false; + } +}; + +class OperatorNode : public Object { + friend class Kernel; + + protected: + OpType type; + TensorVec inputs; + TensorVec outputs; + // vector> predecessors; + // vector> successors; + + public: + OperatorNode(OpType opType, TensorVec inputs, TensorVec outputs) + : type(opType), inputs(inputs), outputs(outputs) {} + virtual vector computeShape() const = 0; + virtual OpPerfKey getOpPerfKey() const = 0; + + public: // check Op type + bool isLinearOp() const; + bool isElementWiseOp() const; + bool isSplitOp() const; + bool isConcatOp() const; + bool isComputeOp() const; + bool isTransposeOp() const; + bool isReshapeOp() const; + bool isMemBoundOp() const; + + public: // getter and setter + // TensorVec getInputs() { return inputs; } + const TensorVec &getInputs() const { return inputs; } + // TensorVec getOutputs() { return outputs; } + const TensorVec &getOutputs() const { return outputs; } + Tensor getInputs(size_t i) { return inputs.at(i); } + Tensor getOutput() const { + IT_ASSERT(outputs.size() == 1, "Unimplemented"); + return outputs[0]; + } + OpType getOpType() const { return type; } + + virtual int numInputs() const = 0; + virtual int numOutputs() const = 0; + virtual HashType hash() const { IT_TODO_HALT(); } + virtual HashType hashWithShape() const { IT_TODO_HALT(); } +}; + +} // namespace infini + +namespace std { +template <> struct hash { + size_t operator()(const infini::OpPerfKey &key) const { return key.hash; } +}; +} // namespace std \ No newline at end of file diff --git a/include/core/perf_engine.h b/include/core/perf_engine.h new file mode 100644 index 00000000..563ad704 --- /dev/null +++ b/include/core/perf_engine.h @@ -0,0 +1,36 @@ +#pragma once +#include "core/graph.h" +#include "core/kernel.h" + +namespace infini { + +class PerfEngine { + public: + // TODO: Key should be OpPerfKey + Context(maybe implicat) to support + // multiple candiate kernels. + using Key = std::pair; + + private: + map data; + + public: + static PerfEngine &getInstance() { + static PerfEngine instance; + return instance; + } + + std::optional getPerfData(const Key &key) { + auto it = data.find(key); + if (it != data.end()) // find previous evaluating results + return data.at(key); + else + return std::nullopt; + } + + void setPerfData(const Key &key, const PerfRecord &record) { + IT_ASSERT(data.find(key) == data.end(), "Perf data already exist"); + data.emplace(key, record); + } +}; + +} // namespace infini \ No newline at end of file diff --git a/include/core/ref.h b/include/core/ref.h new file mode 100644 index 00000000..f5ba4e89 --- /dev/null +++ b/include/core/ref.h @@ -0,0 +1,35 @@ +#pragma once +#include // hash +#include +#include + +namespace infini { + +template using Ref = std::shared_ptr; +template using WRef = std::weak_ptr; + +template struct is_ref : std::false_type {}; +template struct is_ref> : std::true_type {}; +template struct is_ref> : std::true_type {}; + +template Ref make_ref(Params &&...params) { + static_assert(is_ref::value == false, "Ref should not be nested"); + return std::make_shared(std::forward(params)...); +} + +template > * = nullptr> +Ref as(const Ref &ref) { + return std::dynamic_pointer_cast(ref); +} + +template +std::vector> get_wref_vec(const std::vector> &vec) { + std::vector> wref_vec; + wref_vec.reserve(vec.size()); + for (const auto &ref : vec) + wref_vec.emplace_back(ref); + return wref_vec; +} + +} // namespace infini \ No newline at end of file diff --git a/include/core/run_enigne.h b/include/core/run_enigne.h new file mode 100644 index 00000000..1b4877a1 --- /dev/null +++ b/include/core/run_enigne.h @@ -0,0 +1,26 @@ +#pragma once +#include "core/graph.h" +#include "core/kernel.h" +#include "core/perf_engine.h" + +namespace infini { + +class RunEngine { + private: + Device device; + + public: + RunEngine(Device device) : device(device) {} + ~RunEngine() {} + + void run(const Graph &graph, bool tune = false, + bool profiling = false) const; + double getPerfTime(const Graph &graph, bool profiling = false) const; + + private: + void printProfilingData(double totTime, + const std::map &opTime, + const std::map &opCnt) const; +}; + +} // namespace infini \ No newline at end of file diff --git a/include/core/tensor.h b/include/core/tensor.h new file mode 100644 index 00000000..67544753 --- /dev/null +++ b/include/core/tensor.h @@ -0,0 +1,180 @@ +#pragma once +#include "core/tensor_base.h" + +namespace infini { + +// TODO: how to deal with this +using ShapeElem = int; +using Shape = vector; +class TensorNode : public TensorBaseNode { + private: + Shape shape; + + public: + TensorNode(const Shape &shape, DataType dtype); + virtual ~TensorNode() {} + string toString() const override; + + size_t size() const; + void dataMalloc(); + + Shape getDims() const { return shape; } + + size_t getOffset(const Shape &ds) const; + using TensorBaseNode::getData; + VType getData(const Shape &pos) const; + void copyData(VType *dptr); + void printData() const; + bool equalData(const Tensor &rhs) const; + // void setDims(const Dim &dms) { dims = dms; } + + // bool dataRand(int seed = 0) { + // if (data == nullptr) + // data = new VType[size()]; + // if (!random_inited) + // initFastrand(); + // // srand(seed); + // // faster rand generator; parallel + // size_t iEnd = size(); + // // std::cerr << "Init beginned " << std::endl; + // #pragma omp parallel for + // for (size_t i = 0; i < iEnd; ++i) + // data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) % + // 10000; + // // std::cerr << "Init finished" << std::endl; + // computed = ComputedFull; + // return true; + // } + + // bool setScalar(VType val) { + // if (data == nullptr || !dims.empty()) + // return false; + // data[0] = val; + // return true; + // } + + // bool setData(const Dim &ds, VType val) { + // if (data == nullptr || ds.size() != dims.size()) + // return false; + // data[getOffset(ds)] = val; + // return true; + // } + + // bool setData(size_t pos, VType val) { + // if (data == nullptr || pos >= size()) + // return false; + // data[pos] = val; + // return true; + // } + + // VType getScalar() { return data == nullptr ? 0 : data[0]; } + + // VType getBroadcastData(const Dim &ds) { + // assert(data != nullptr); + // auto offset = getBroadcastOffset(ds); + // return offset == (size_t)-1 ? 0 : data[getOffset(ds)]; + // } + + // VType getBroadcastData(size_t pos) { + // assert(data != nullptr); + // return data[pos % size()]; + // } + + // size_t getBroadcastOffset(const Dim &ds) { + // assert(ds.size() >= dims.size()); + // auto nDim = dims.size(); + // auto nBroadcastDim = ds.size() - nDim; + // for (size_t i = 0; i < nDim; ++i) + // if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >= + // dims[i]) + // return (size_t)-1; + // size_t idx = 0; + // for (size_t i = 0; i < nDim; ++i) + // idx = idx * dims[i] + ds[nBroadcastDim + i]; + // return idx; + // } + + // void itInit() { it = Dim(dims.size(), 0); } + + // void itReset() { + // itInit(); + // for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i) + // it[i] = 0; + // } + + // bool itValid() { + // if (it.size() != dims.size()) + // return false; + // for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i) + // if (it[i] >= dims[i]) + // return false; + // return true; + // } + + // const Dim &itGet() { return it; } + + // void itNext() { + // auto p = it.size() - 1; + // it[p] += 1; + // while (p >= 1) { + // if (it[p] == dims[p]) { + // it[p] = 0; + // it[--p] += 1; + // } else + // break; + // } + // } + + // TensorType getType() const { return type; } + // void setType(TensorType ty) { type = ty; } + + // static inline void initFastrand() { + // assert(omp_get_max_threads() <= 256); + // // srand(0); // constant seed for test + // // align random_seed to avoid false sharing + // for (int i = 0; i < 256 * 16; ++i) { + // // random_seed[i] = rand(); + // // constant random seed for test + // random_seed[i] = i; + // } + // random_inited = true; + // } + + // static inline int fastrand(int &g_seed) { + // g_seed = (214013 * g_seed + 2531011); + // return (g_seed >> 16) & 0x7FFF; + // } + + // std::vector> const *getSplittingPoints() const { + // assert(!splittingPoints.empty()); + // return &splittingPoints; + // } + + // bool setSplittingPoints(std::vector> value) { + // assert(!value.empty()); + // splittingPoints = value; + // return true; + // } + + // void printSplittingPoints() { + // if (splittingPoints.empty()) + // printf("Empty SplittingPoints"); + // else { + // printf("["); + // for (auto &vs : splittingPoints) { + // printf("["); + // for (auto v : vs) + // printf("%2d,", v); + // printf("],"); + // } + // printf("]"); + // } + // } + + // void initSplittingPoints() { + // splittingPoints.resize(getDims().size()); } + + // void printShape(); +}; + +} // namespace infini \ No newline at end of file diff --git a/include/core/tensor_base.h b/include/core/tensor_base.h new file mode 100644 index 00000000..cafea062 --- /dev/null +++ b/include/core/tensor_base.h @@ -0,0 +1,261 @@ +#pragma once +#include "core/object.h" +#include "core/ref.h" + +namespace infini { + +// class Tensor; +class TensorBaseNode; +class TensorNode; +class OperatorNode; +class GraphNode; + +using TensorBase = Ref; +using Tensor = Ref; +using Operator = Ref; +using Graph = Ref; + +using TensorVec = vector; +using OpVec = vector; + +using VType = uint32_t; + +enum class DataType { + Float32, + Int32, +}; + +class TensorBaseNode : public Object { + public: + // enum TensorType { + // Input, + // Weight, + // Invalid, + // NotCounted, + // }; + + protected: + int dim; + + DataType dtype; + vector> inputOf; + WRef outputOf; + // TODO: Ref -> Ref + Ref data; + // ComputeState computed; + // static int random_seed[256 * 16]; + // static bool random_inited; + + public: + TensorBaseNode(int dim, DataType dtype); + virtual ~TensorBaseNode() {} + + Ref getDataPtr() const { return data; } + VType getData(size_t offset) const; + + DataType getDType() const { return dtype; } + + // uint64_t getHash() const { return hash; } + + // void setInputOf(const OpVec &ops) { + // inputOf.clear(); + // for (const auto &op : ops) + // inputOf.emplace_back(op); + // } + // void addInputOf(Operator op) { inputOf.emplace_back(op); } + // void setOutputOf(Operator op) { outputOf = op; } + + // const OpVec &getInputOf() { return inputOf; } + // Operator *getOutputOf() { return outputOf; } + // std::pair getOutputOfWithIndex(); + + // const Dim &getDims() const { return dims; } + // void setDims(const Dim &dms) { dims = dms; } + + // bool dataRand(int seed = 0) { + // if (data == nullptr) + // data = new VType[size()]; + // if (!random_inited) + // initFastrand(); + // // srand(seed); + // // faster rand generator; parallel + // size_t iEnd = size(); + // // std::cerr << "Init beginned " << std::endl; + // #pragma omp parallel for + // for (size_t i = 0; i < iEnd; ++i) + // data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) % + // 10000; + // // std::cerr << "Init finished" << std::endl; + // computed = ComputedFull; + // return true; + // } + + // bool setScalar(VType val) { + // if (data == nullptr || !dims.empty()) + // return false; + // data[0] = val; + // return true; + // } + + // bool setData(const Dim &ds, VType val) { + // if (data == nullptr || ds.size() != dims.size()) + // return false; + // data[getOffset(ds)] = val; + // return true; + // } + + // bool setData(size_t pos, VType val) { + // if (data == nullptr || pos >= size()) + // return false; + // data[pos] = val; + // return true; + // } + + // VType getScalar() { return data == nullptr ? 0 : data[0]; } + + // VType getData(const Dim &ds) { + // assert(data != nullptr); + // auto offset = getOffset(ds); + // return offset == (size_t)-1 ? 0 : data[getOffset(ds)]; + // } + + // VType getData(size_t pos) { + // assert(data != nullptr); + // assert(pos < size()); + // return data[pos]; + // } + + // VType *getDataPtr() const { return data; } + + // size_t getOffset(const Dim &ds) { + // auto nDim = ds.size(); + // assert(dims.size() == nDim); + // if (ds.empty()) + // return 0; + // for (size_t i = 0; i < nDim; ++i) + // if (ds[i] < 0 || ds[i] >= dims[i]) + // return (size_t)-1; + // size_t idx = ds[0]; + // size_t dm = 0; + // while (++dm < nDim) + // idx = idx * dims[dm] + ds[dm]; + // return idx; + // } + + // VType getBroadcastData(const Dim &ds) { + // assert(data != nullptr); + // auto offset = getBroadcastOffset(ds); + // return offset == (size_t)-1 ? 0 : data[getOffset(ds)]; + // } + + // VType getBroadcastData(size_t pos) { + // assert(data != nullptr); + // return data[pos % size()]; + // } + + // size_t getBroadcastOffset(const Dim &ds) { + // assert(ds.size() >= dims.size()); + // auto nDim = dims.size(); + // auto nBroadcastDim = ds.size() - nDim; + // for (size_t i = 0; i < nDim; ++i) + // if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >= + // dims[i]) + // return (size_t)-1; + // size_t idx = 0; + // for (size_t i = 0; i < nDim; ++i) + // idx = idx * dims[i] + ds[nBroadcastDim + i]; + // return idx; + // } + + // void itInit() { it = Dim(dims.size(), 0); } + + // void itReset() { + // itInit(); + // for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i) + // it[i] = 0; + // } + + // bool itValid() { + // if (it.size() != dims.size()) + // return false; + // for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i) + // if (it[i] >= dims[i]) + // return false; + // return true; + // } + + // const Dim &itGet() { return it; } + + // void itNext() { + // auto p = it.size() - 1; + // it[p] += 1; + // while (p >= 1) { + // if (it[p] == dims[p]) { + // it[p] = 0; + // it[--p] += 1; + // } else + // break; + // } + // } + + // size_t size() const { + // size_t sz = 1; + // auto dm = dims.size(); + // while (dm > 0) + // sz *= dims[--dm]; + // return sz; + // } + + // TensorType getType() const { return type; } + // void setType(TensorType ty) { type = ty; } + + // static inline void initFastrand() { + // assert(omp_get_max_threads() <= 256); + // // srand(0); // constant seed for test + // // align random_seed to avoid false sharing + // for (int i = 0; i < 256 * 16; ++i) { + // // random_seed[i] = rand(); + // // constant random seed for test + // random_seed[i] = i; + // } + // random_inited = true; + // } + + // static inline int fastrand(int &g_seed) { + // g_seed = (214013 * g_seed + 2531011); + // return (g_seed >> 16) & 0x7FFF; + // } + + // std::vector> const *getSplittingPoints() const { + // assert(!splittingPoints.empty()); + // return &splittingPoints; + // } + + // bool setSplittingPoints(std::vector> value) { + // assert(!value.empty()); + // splittingPoints = value; + // return true; + // } + + // void printSplittingPoints() { + // if (splittingPoints.empty()) + // printf("Empty SplittingPoints"); + // else { + // printf("["); + // for (auto &vs : splittingPoints) { + // printf("["); + // for (auto v : vs) + // printf("%2d,", v); + // printf("],"); + // } + // printf("]"); + // } + // } + + // void initSplittingPoints() { + // splittingPoints.resize(getDims().size()); } + + // void printShape(); +}; + +} // namespace infini diff --git a/include/nnet/Pass/MatchComputationKernel.h b/include/nnet/Pass/MatchComputationKernel.h new file mode 100644 index 00000000..4b182a89 --- /dev/null +++ b/include/nnet/Pass/MatchComputationKernel.h @@ -0,0 +1,15 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class MatchComputationKernel : public Pass { + public: + MatchComputationKernel(Derivator &derivator) + : Pass(derivator, "MatchComputationKernel") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/MatchMemBoundKernel.h b/include/nnet/Pass/MatchMemBoundKernel.h new file mode 100644 index 00000000..6b0a4bec --- /dev/null +++ b/include/nnet/Pass/MatchMemBoundKernel.h @@ -0,0 +1,15 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class MatchMemBoundKernel : public Pass { + public: + MatchMemBoundKernel(Derivator &derivator) + : Pass(derivator, "MatchMemBoundKernel") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Pass.h b/include/nnet/Pass/Pass.h new file mode 100644 index 00000000..a8d517fa --- /dev/null +++ b/include/nnet/Pass/Pass.h @@ -0,0 +1,41 @@ +#pragma once +#include "nnet/derivator.h" + +namespace nnet { + +class Pass { + private: + VecExpr transformations; + + protected: + Derivator &derivator; + string passName; + /** + * @brief // False if does not add log in Derivator. It should be false for + * single Pass test to avoid mismatch of passInfos and passMsgs due to + * different number of "run" and "nextStep". + */ + bool enableLogging, enableDebug; + + virtual void transform(Formula &origin, int depth, Expr &rCur) = 0; + void nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur, + const string &ruleInfo = ""); + + Var getNewVar(); + string newTensorName(); + + private: + void initialize(Formula &origin, const Expr &rCur); + void finalize(); + + public: + Pass(Derivator &derivator, const string &passName); + virtual ~Pass(); + + void run(Formula &origin, int dfsDepth, Expr &rCur); + void setEnableLogging(bool value); + void setEnableDebug(bool value); + const VecExpr &getTransformations(); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule1VariableSplit.h b/include/nnet/Pass/Rule1VariableSplit.h new file mode 100644 index 00000000..06f2bb4e --- /dev/null +++ b/include/nnet/Pass/Rule1VariableSplit.h @@ -0,0 +1,18 @@ +#pragma once +#include "nnet/Pass/Pass.h" +#include "nnet/ReplaceKit.h" + +namespace nnet { + +class Rule1VariableSplit : public Pass { + public: + Rule1VariableSplit(Derivator &derivator) + : Pass(derivator, "Rule1VariableSplit") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; + vector getSplitableVar(const RangeOp &rangeOp); + Expr replaceIters(Expr cur, const Replace &replace); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule2VariableMerging.h b/include/nnet/Pass/Rule2VariableMerging.h new file mode 100644 index 00000000..1f277f46 --- /dev/null +++ b/include/nnet/Pass/Rule2VariableMerging.h @@ -0,0 +1,29 @@ +#pragma once +#include "nnet/Pass/Pass.h" +#include "nnet/ReplaceKit.h" + +namespace nnet { + +class Rule2VariableMerging : public Pass { + private: + map> substituteRules; + + public: + Rule2VariableMerging(Derivator &derivator) + : Pass(derivator, "Rule2VariableMerging") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; + + vector getMergableReplaces(RangeOp rangeOp, int depth); + optional getReplaceMergingTwoLoopIters(const RangeOp &rangeOp, + pair pairA, + pair pairB, + const IteratorTable &exprIT, + int tensorID); + optional getReplaceMappingTwoLoopIters(const RangeOp &rangeOp, + pair pa, + pair pb); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule3StageSplit.h b/include/nnet/Pass/Rule3StageSplit.h new file mode 100644 index 00000000..99e172cf --- /dev/null +++ b/include/nnet/Pass/Rule3StageSplit.h @@ -0,0 +1,19 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule3StageSplit : public Pass { + private: + map> substituteRules; + + public: + Rule3StageSplit(Derivator &derivator) + : Pass(derivator, "Rule3StageSplit") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; + vector> getSplitSummationIters(RangeOp rangeOp); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule4StageMerging.h b/include/nnet/Pass/Rule4StageMerging.h new file mode 100644 index 00000000..13f11074 --- /dev/null +++ b/include/nnet/Pass/Rule4StageMerging.h @@ -0,0 +1,22 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule4StageMerging : public Pass { + bool success, mergeStageWithCalc; + + public: + Rule4StageMerging(Derivator &derivator) + : Pass(derivator, "Rule4StageMerging"), success(false), + mergeStageWithCalc(false) {} + bool rule4StageMerging(Formula &origin, int depth, Expr &rCur, + bool mergeStageWithCalc = false); + bool isSuccessful(); + void setMergeStageWithCalc(bool value); + + private: + virtual void transform(Formula &origin, int depth, Expr &rCur) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule5RangeRelaxation.h b/include/nnet/Pass/Rule5RangeRelaxation.h new file mode 100644 index 00000000..c1719dfc --- /dev/null +++ b/include/nnet/Pass/Rule5RangeRelaxation.h @@ -0,0 +1,16 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule5RangeRelaxation : public Pass { + public: + Rule5RangeRelaxation(Derivator &derivator) + : Pass(derivator, "Rule5RangeRelaxation") {} + Expr rule5RangeRelaxation(Formula &origin, int depth, Expr &rCur); + + private: + virtual void transform(Formula &origin, int depth, Expr &rCur) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule6KenerlMatching.h b/include/nnet/Pass/Rule6KenerlMatching.h new file mode 100644 index 00000000..19648eaf --- /dev/null +++ b/include/nnet/Pass/Rule6KenerlMatching.h @@ -0,0 +1,17 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule6KenerlMatching : public Pass { + public: + Rule6KenerlMatching(Derivator &derivator) + : Pass(derivator, "Rule6KenerlMatching") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; + // RE: seperating this func is a choice. + VecExpr matchElementWise(const RangeOp &rangeOp); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule7DLT.h b/include/nnet/Pass/Rule7DLT.h new file mode 100644 index 00000000..d2ce0ecc --- /dev/null +++ b/include/nnet/Pass/Rule7DLT.h @@ -0,0 +1,16 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule7DLT : public Pass { + public: + Rule7DLT(Derivator &derivator) : Pass(derivator, "Rule7DLT") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; + Expr buildDLTSingleRangeOp(const RangeOp &original, const Expr &newSummand); + vector getFactors(); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule8GuidedDLT.h b/include/nnet/Pass/Rule8GuidedDLT.h new file mode 100644 index 00000000..e6536df4 --- /dev/null +++ b/include/nnet/Pass/Rule8GuidedDLT.h @@ -0,0 +1,48 @@ +#pragma once +#include "nnet/Pass/Pass.h" +#include "nnet/ReplaceKit.h" + +namespace nnet { + +class Rule8GuidedDLT : public Pass { + public: + Rule8GuidedDLT(Derivator &derivator) : Pass(derivator, "Rule8GuidedDLT") {} + VecExpr guidedDLT(Formula &origin, int depth, Expr &rCur, + bool debug = false); + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; + /** + * @brief If only one row miss match (more iterators mismatch), directly do + * data layout construction according to the IT. + * + * @return Expr Return nullptr if failed. + */ + Expr guidedDLTMoreVar2(const RangeOp &cur, const Mismatch &mismatch, + const IteratorTable &exprIT, const Pattern &pattern); + /** + * @brief Check whether two iterators overlap each other. If overlapping, we + * cannot simply reconstruct the tensor into a new one by seperate all + * iterators into different dimensions. + */ + bool checkElementsHaveOnlyOneAccessIteratorSet(const IteratorTable &exprIT, + int tensorID); + /** + * @brief Only product of two tensors can be guided DLTed. + * + * @param cur + * @return true + * @return false + */ + bool statisfyGuidedDLT(RangeOp cur) const; + /** + * @brief Deal with output DLT mismatch only. + */ + Expr guidedDLTDLMismatch(const RangeOp &cur, const Mismatch &mismatch, + const IteratorTable &exprIT, + const Pattern &pattern); + Expr buildGuidedDLTSource(const Subscript &originalSub, Replace replace, + vector tensorDimAxes, vector newShape); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule90TwoStageElementWise.h b/include/nnet/Pass/Rule90TwoStageElementWise.h new file mode 100644 index 00000000..ab37cf99 --- /dev/null +++ b/include/nnet/Pass/Rule90TwoStageElementWise.h @@ -0,0 +1,16 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule90TwoStageElementWise : public Pass { + public: + Rule90TwoStageElementWise(Derivator &derivator) + : Pass(derivator, "Rule90TwoStageElementWise") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; + VecExpr matchTwoStageElementWise(const RangeOp &rangeOp); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule91MergeStagesWithSum.h b/include/nnet/Pass/Rule91MergeStagesWithSum.h new file mode 100644 index 00000000..221e1772 --- /dev/null +++ b/include/nnet/Pass/Rule91MergeStagesWithSum.h @@ -0,0 +1,15 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule91MergeStagesWithSum : public Pass { + public: + Rule91MergeStagesWithSum(Derivator &derivator) + : Pass(derivator, "Rule91MergeStagesWithSum") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Pass/Rule9RangeMagnify.h b/include/nnet/Pass/Rule9RangeMagnify.h new file mode 100644 index 00000000..613de158 --- /dev/null +++ b/include/nnet/Pass/Rule9RangeMagnify.h @@ -0,0 +1,15 @@ +#pragma once +#include "nnet/Pass/Pass.h" + +namespace nnet { + +class Rule9RangeMagnify : public Pass { + public: + Rule9RangeMagnify(Derivator &derivator) + : Pass(derivator, "Rule9RangeMagnify") {} + + private: + virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/ReplaceKit.h b/include/nnet/ReplaceKit.h new file mode 100644 index 00000000..19f2ec0e --- /dev/null +++ b/include/nnet/ReplaceKit.h @@ -0,0 +1,46 @@ +#pragma once +#include "nnet/expr.h" + +namespace nnet { + +struct Replace { + int iteratorType; + vector oldIters; // i_1, ... + vector newIters; // j_1, ... + VecExpr phis; // j_1=\phi_1(i_1, ...), not necessary for Sum iter + VecExpr psis; // i_1=\psi_1(j_1, ...) + vector newVarRanges; + + bool isReplaced(Var var) const { + for (const auto &iter : oldIters) + if (iter->equal(var)) + return true; + return false; + } + + string toReadable() const { + string ret = "Old iters: " + serializeVec(oldIters) + + ", new iters: " + serializeVec(newIters); + ret += " phis: " + serializeVec(phis) + " psis: " + serializeVec(psis); + return ret; + } +}; + +class ReplaceKit { + public: + static RangeOp replaceRangeOpIterator(const RangeOp &rangeOp, + const Replace &replace, + const Expr &replacedSummand); + static Subscript buildSubscirptForLoopVarReplace(const RangeOp &inner, + const Replace &replace); + static RangeOp buildDLTOuterRangeOp(const RangeOp &original, + const Subscript &subscriptedNewRangeOp); + static Expr replaceMultipleExprs(const Expr &cur, + const vector &patterns, + const VecExpr &replacements, + bool simplify = true); + static Expr replaceExpr(const Expr &cur, const Expr &pattern, + const Expr &replacement); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/AsTVMVisitor.h b/include/nnet/Visitor/AsTVMVisitor.h new file mode 100644 index 00000000..f193231a --- /dev/null +++ b/include/nnet/Visitor/AsTVMVisitor.h @@ -0,0 +1,38 @@ +#pragma once +#include "nnet/Visitor/StrideVisitor.h" +#include "nnet/visitor.h" + +namespace nnet { + +class AsTVMVisitor : public Functor { + private: + int nStage = 0, curStage = -1; + std::unordered_map offset; + std::vector inputs; + std::string output; + std::vector pythonVars; + std::vector> inputShapes; + std::vector outputShape; + std::string stmts; + + public: + std::string getStmts() const; + + const std::vector &getInputs() const { return inputs; } + const std::string &getOutput() const { return output; } + + const std::vector> &getInputShapes() const { + return inputShapes; + } + const std::vector &getOutputShape() const { return outputShape; } + + std::string visit_(const Constant &c) override; + std::string visit_(const BinaryOp &c) override; + std::string visit_(const Func &c) override; + std::string visit_(const RangeOp &c) override; + std::string visit_(const Subscript &c) override; + std::string visit_(const Var &c) override; + std::string visit_(const Tensor &c) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/CheckOOBVisitor.h b/include/nnet/Visitor/CheckOOBVisitor.h new file mode 100644 index 00000000..017a7ae9 --- /dev/null +++ b/include/nnet/Visitor/CheckOOBVisitor.h @@ -0,0 +1,22 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class CheckOOBVisitor : public ExprTreeVisitor { + RangeOp rangeOp; + bool detect = false; + + public: + CheckOOBVisitor(int _verobse = 0) : ExprTreeVisitor(1, 1, 0, 0, _verobse) {} + void visit_(const Subscript &c) override; + + /** + * @brief + * @return true If there is OOB + * @return false If there is no OOB + */ + bool checkRangeOp(const RangeOp &_rangeOp); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/CloneMutator.h b/include/nnet/Visitor/CloneMutator.h new file mode 100644 index 00000000..6812cb80 --- /dev/null +++ b/include/nnet/Visitor/CloneMutator.h @@ -0,0 +1,16 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +// Clone ExprNodes in a stage except Tensor, Var, and Constant nodes. +class CloneMutator : public Mutator { + public: + CloneMutator() : Mutator(false) {} + Expr visit_(const Constant &c) override; + Expr visit_(const Var &c) override; + Expr visit_(const Tensor &c) override; + Expr clone(const Expr &c) { return dispatch(c); } +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/CompareMultiFormulasVisitor.h b/include/nnet/Visitor/CompareMultiFormulasVisitor.h new file mode 100644 index 00000000..a120d48e --- /dev/null +++ b/include/nnet/Visitor/CompareMultiFormulasVisitor.h @@ -0,0 +1,15 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class CompareMultiFormulasVisitor : public ExprTreeVisitor { + vector newSumVarRanges; + RangeOp newRangeOp; + + public: + CompareMultiFormulasVisitor() : ExprTreeVisitor() {} + bool compare(const VecExpr &roots); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/CountRoutineVisitor.h b/include/nnet/Visitor/CountRoutineVisitor.h new file mode 100644 index 00000000..d9536495 --- /dev/null +++ b/include/nnet/Visitor/CountRoutineVisitor.h @@ -0,0 +1,18 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class CountRoutineVisitor : public ExprTreeVisitor { + private: + vector cnts; + + public: + CountRoutineVisitor(int _verobse = 0) + : ExprTreeVisitor(1, 1, 1, 1, _verobse) {} + void visit_(const Tensor &c) override; + vector count(const Expr &root); + bool match(const Expr &root, int nMatmul = 0, int nConv = 0, + int nElement = 0, int nSg2bmm = 0, int nLongformerGBMM = 0); +}; +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/FullPrinterVisitor.h b/include/nnet/Visitor/FullPrinterVisitor.h new file mode 100644 index 00000000..a9caeccf --- /dev/null +++ b/include/nnet/Visitor/FullPrinterVisitor.h @@ -0,0 +1,25 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class FullPrinterVisitor : public ExprTreeVisitor { + private: + vector> q; + + public: + FullPrinterVisitor(int _verobse = 0) + : ExprTreeVisitor(1, 1, 1, 0, _verobse) {} + void visit_(const Tensor &c) override; + + string print(const Expr &root); + /** + * @brief Get all tensors & OPs in a reversed order + * + * @param root + * @return vector<> + */ + const vector> &traverse(const Expr &root); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/GetTensorsVisitor.h b/include/nnet/Visitor/GetTensorsVisitor.h new file mode 100644 index 00000000..66a8d5dc --- /dev/null +++ b/include/nnet/Visitor/GetTensorsVisitor.h @@ -0,0 +1,22 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +// Get all tensors in the stage +class GetTensorsVisitor : public ExprTreeVisitor { + private: + unordered_map tensors; + + void visit_(const Tensor &c) override; + + public: + GetTensorsVisitor(int _verobse = 0) + : ExprTreeVisitor(1, 1, 1, 0, _verobse) {} + auto get(const Expr &c) { + dispatch(c); + return tensors; + } +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/HashVisitor.h b/include/nnet/Visitor/HashVisitor.h new file mode 100644 index 00000000..0d20f49d --- /dev/null +++ b/include/nnet/Visitor/HashVisitor.h @@ -0,0 +1,31 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +// Calculate hash for a normal form, starting at a RangeOp +class HashVisitor : public Functor { + inline const static HashType BKDR_SEED[] = {131, 313, 10007, 65599}; + + PtrUmap varHash; + int nLoopVars = 0; + PtrUmap name2id; + vector rootId; + vector haveAlias; + int nVars = 0; + vector power; + + private: + HashType visit_(const Constant &c) override; + HashType visit_(const BinaryOp &c) override; + HashType visit_(const RangeOp &c) override; + HashType visit_(const Subscript &c) override; + HashType visit_(const Tensor &c) override; + HashType visit_(const Var &c) override; + + public: + HashVisitor(int _verobse = 0) : Functor(_verobse) {} + HashType getHash(const Expr &c); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/InputVisitor.h b/include/nnet/Visitor/InputVisitor.h new file mode 100644 index 00000000..1bc3c89a --- /dev/null +++ b/include/nnet/Visitor/InputVisitor.h @@ -0,0 +1,23 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class InputVisitor : public ExprTreeVisitor { + vector inputs; + + public: + int nInputs = 0; + InputVisitor(int _verobse = 0) : ExprTreeVisitor(1, 1, 1, 0, _verobse) {} + void visit_(const Tensor &c) override; + + /** + * @brief Get the all inputs in the netsed stages + */ + vector getInputs(const RangeOp &_rangeOp) { + dispatch(_rangeOp); + return inputs; + } +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/Interpreter.h b/include/nnet/Visitor/Interpreter.h new file mode 100644 index 00000000..b4479b79 --- /dev/null +++ b/include/nnet/Visitor/Interpreter.h @@ -0,0 +1,55 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class Interpreter : public Functor { + public: + using ttype = int; // Test data type + using rtype = int; // Return data type + using Position = vector; + using Inputs = unordered_map>>; + using Iteration = PtrUmap; + + private: + // cache the input value + Inputs inputs; + vector iterations; + vector positions; + + rtype visit_(const Constant &c) override; + rtype visit_(const BinaryOp &c) override; + rtype visit_(const RangeOp &c) override; + rtype visit_(const Subscript &c) override; + rtype visit_(const Var &c) override; + rtype visit_(const Tensor &c) override; + // int visit_(const Func &c); // Future work + + static Inputs genInputStartingFromZero(const RangeOp &range); + + public: + Interpreter(Inputs _inputs, int _verbose = 0) + : Functor(_verbose), inputs(_inputs) {} + Interpreter(RangeOp range, int _verbose = 0); + + /** + * @brief Calculate the output at specified poistions + * + * @param expr The expression to be calculated. + * @param poses Positions of output. + * @return vector Value of output. + */ + vector interpret(const Expr &expr, const vector &poses); + /** + * @brief Calculate the output at equally spaced positions + * + * @param expr The expression to be calculated. + * @param nPoses The number of calculated output positions. + * @return vector Value of output. + */ + vector interpretUniformSample(const RangeOp &range, + int nPoses = 100); + vector interpretAllOutput(const RangeOp &range); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/MatchReshapeVisitor.h b/include/nnet/Visitor/MatchReshapeVisitor.h new file mode 100644 index 00000000..caf0130e --- /dev/null +++ b/include/nnet/Visitor/MatchReshapeVisitor.h @@ -0,0 +1,14 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class MatchReshapeVisitor : public Functor { + private: + PtrMap _coefficient; + + public: + bool visit_(const RangeOp &c) override; +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/MatchTableVisitor.h b/include/nnet/Visitor/MatchTableVisitor.h new file mode 100644 index 00000000..53836883 --- /dev/null +++ b/include/nnet/Visitor/MatchTableVisitor.h @@ -0,0 +1,60 @@ +#pragma once +#include "nnet/Visitor/StrideVisitor.h" +#include "nnet/visitor.h" + +namespace nnet { + +class MatchTableVisitor + : public Functor stride)> { + private: + // Var -> {(tensor, dim)} + Appearance appearance; + vector tensors; + vector subscripts; + StrideTable strideTable; + PtrMap>> + strideInDim; // [Iterator][tensorID][dim]=stride + + // Intermediate variable + // product of a sub-exprtree: Stride has to be done in two DFS + SubexprSride subexprStride; + bool hasUnsupportedOp = false; + + public: + MatchTableVisitor(int _verobse = 0) : Functor(_verobse) {} + void visit_(const BinaryOp &c, const Tensor &tensor, int dim, + optional stride) override; + void visit_(const Subscript &c, const Tensor &tensor, int dim, + optional stride) override; + void visit_(const Var &c, const Tensor &tensor, int dim, + optional stride) override; + void visit_(const Constant &c, const Tensor &tensor, int dim, + optional stride) override; + // void visit_(const Tensor &c, const Tensor &tensor) override; + + [[nodiscard]] bool operator()(const RangeOp &e) { + hasUnsupportedOp = false; + // get the location and stride of each iterator + auto mulOp = as(e->getSummand()); + // TODO [feature]: support complex index exprs + if (!mulOp || mulOp->getOpType() != OpType::Mul) { + nnet_unimplemented_continue(); + return false; + } + StrideVisitor strideVisitor(0); + subexprStride = strideVisitor.getFormulaStride(e); + dispatch(mulOp->getLhs(), nullptr, 0, 0); + dispatch(mulOp->getRhs(), nullptr, 0, 0); + subscripts.emplace_back(as(mulOp->getLhs())); + subscripts.emplace_back(as(mulOp->getRhs())); + assert(tensors.size() == subscripts.size()); + assert(tensors.size() < 5); + return !hasUnsupportedOp; + } + + auto getResult() const { + return tuple(appearance, tensors, strideTable, subscripts); + } +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/MatmulTransposeMutator.h b/include/nnet/Visitor/MatmulTransposeMutator.h new file mode 100644 index 00000000..000467d2 --- /dev/null +++ b/include/nnet/Visitor/MatmulTransposeMutator.h @@ -0,0 +1,18 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class MatmulTransposeMutator : public Mutator { + Derivator &derivator; + + public: + MatmulTransposeMutator(Derivator &derivator) + : Mutator(1), derivator(derivator) {} + VecExpr transpose(const Tensor &tensor); + + private: + Tensor transposeInput(const Tensor &tensor); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/MergeMemboundMutator.h b/include/nnet/Visitor/MergeMemboundMutator.h new file mode 100644 index 00000000..786103e1 --- /dev/null +++ b/include/nnet/Visitor/MergeMemboundMutator.h @@ -0,0 +1,20 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class MergeMemboundMutator : public Mutator { + VecExpr kernels; + int curDepth; // from the last one to the first one + Expr visit_(const Tensor &c) override; + // FIXME: duplicate code + Expr rule4StageMerging(Expr &rCur, bool mergeStageWithCalc); + bool checkEmpty(); + + public: + MergeMemboundMutator(const VecExpr &kernels) + : Mutator(), kernels(kernels), curDepth(kernels.size() - 1) {} + Expr merge(bool allowEmptyMembound = false); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/PatternMatcher.h b/include/nnet/Visitor/PatternMatcher.h new file mode 100644 index 00000000..6765a4e0 --- /dev/null +++ b/include/nnet/Visitor/PatternMatcher.h @@ -0,0 +1,43 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +/** + * @brief Since the output positions of operators always start from 0, we have + * to offset them if the the boundary expression of is not 0. + */ +class PatternMatcher : public Functor { + private: + Derivator &derivator; + bool hasNonZeroRange; + const RangeOp originalCur; + + public: + PatternMatcher(Derivator &derivator, const RangeOp &cur); + /** + * @brief Get the Cur whose loop vars are all offset to [0, x). Since + * operator outputs start from 0, RangeOp has to be aligned. + */ + RangeOp getOffsetCur(); + /** + * @brief Add outer RangeOp to map the original positions to the new + * positions staring from 0. + * + * @param exprs Tensors from matched exprs + */ + VecExpr applyWrapper(const VecExpr &exprs); + + VecExpr matchWithPattern(const RangeOp &rangeOp, const Pattern &pattern); + + private: + VecExpr matchKernel(const Pattern &pattern, const RangeOp &rangeOp, + IteratorTable &exprIT); + // get reverse tensor and iterator map ([pattern tensor/iter ID] -> + // real) + Expr matchKernelWithTensorMap(const Pattern &pattern, + const RangeOp &rangeOp, + IteratorTable &exprIT); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/RangeMagnifyVisitor.h b/include/nnet/Visitor/RangeMagnifyVisitor.h new file mode 100644 index 00000000..f6fa03c1 --- /dev/null +++ b/include/nnet/Visitor/RangeMagnifyVisitor.h @@ -0,0 +1,25 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class RangeMagnifyVisitor : public Mutator { + vector newSumVarRanges; + RangeOp newRangeOp; + + public: + RangeMagnifyVisitor() : Mutator(0) {} + Expr visit_(const RangeOp &c) override; + Expr visit_(const Subscript &c) override; + /** + * @brief + * + * @param root + * @param _newSumVarRanges + * @return RangeOp nullptr if failed to magnify + */ + RangeOp magnify(const RangeOp &root, + const vector &_newSumVarRanges); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/RangeRelaxFunctor.h b/include/nnet/Visitor/RangeRelaxFunctor.h new file mode 100644 index 00000000..2dd80c2f --- /dev/null +++ b/include/nnet/Visitor/RangeRelaxFunctor.h @@ -0,0 +1,18 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +using RangeMap = PtrMap; +class RangeRelaxFunctor : public Functor { + RangeOp rangeOp; + + public: + RangeRelaxFunctor(RangeOp _rangeOp) : Functor(false), rangeOp(_rangeOp) {} + RangeMap visit_(const BinaryOp &c) override; + RangeMap visit_(const RangeOp &c) override; + RangeMap visit_(const Subscript &c) override; + RangeMap intersectRangeMaps(const RangeMap &a, const RangeMap &b); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/ReplaceNodeMutator.h b/include/nnet/Visitor/ReplaceNodeMutator.h new file mode 100644 index 00000000..bdb7fb63 --- /dev/null +++ b/include/nnet/Visitor/ReplaceNodeMutator.h @@ -0,0 +1,20 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +// Replace node according to its address the summand +// Only subscript and tensor are supported now. +class ReplaceNodeMutator : public Mutator { + int nSubscripts = 0; + ExprNode *target; + Expr replacement; + + public: + ReplaceNodeMutator() : Mutator(0) {} + Expr visit_(const Subscript &c) override; + Expr visit_(const Tensor &c) override; + Expr replace(const Expr &root, ExprNode *_target, const Expr &_replace); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/ReplaceVariable.h b/include/nnet/Visitor/ReplaceVariable.h new file mode 100644 index 00000000..20efbc4f --- /dev/null +++ b/include/nnet/Visitor/ReplaceVariable.h @@ -0,0 +1,33 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +class ReplaceVariable : public Mutator { + VecExpr patterns, replacements; + map patternHash; + + public: + ReplaceVariable(Expr _pattern, Expr _replacement) : Mutator(false) { + set({_pattern}, {_replacement}); + } + ReplaceVariable(const map> &mapping) + : Mutator(false) { + VecExpr _patterns, _replacements; + for (const auto &[_, v] : mapping) { + _patterns.emplace_back(v.first); + _replacements.emplace_back(v.second); + } + set(_patterns, _replacements); + } + Expr visit_(const BinaryOp &c) override; + // NOT recur to the next stage + Expr visit_(const RangeOp &c) override; + Expr visit_(const Var &c) override; + + private: + void set(VecExpr _pattern, VecExpr _replacement); + Expr match(const Expr &c); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/Serializer.h b/include/nnet/Visitor/Serializer.h new file mode 100644 index 00000000..3568b152 --- /dev/null +++ b/include/nnet/Visitor/Serializer.h @@ -0,0 +1,52 @@ +#pragma once +#include "nlohmann/json_fwd.hpp" +#include "nnet/visitor.h" +#include + +namespace nnet { + +class Serializer : public Functor { + using json = nlohmann::ordered_json; + + private: + static constexpr int VERSION{1}; + std::unique_ptr jPtr; + json &j; + static int id; + + string visit_(const Constant &c) override; + string visit_(const BinaryOp &c) override; + string visit_(const RangeOp &c) override; + string visit_(const Subscript &c) override; + string visit_(const Var &c) override; + string visit_(const Tensor &c) override; + string dispatchRoutine(const Routine &c); + + Expr buildExprTree(string key); + Routine buildRoutine(string key); + + public: + Serializer(int _verobse = 0); + virtual ~Serializer(); + + /** + * @brief Serialize the given expression to json file + * + * @param expr The expression to be serialized + * @param filePath The path of json file to be output + * @param msg Message of derivation + * @return bool Whether the serialization succeed + */ + bool serialize(const Expr &expr, const string &filePath, + const string &msg = ""); + + /** + * @brief Deserialize the given json file to expression + * + * @param filePath The path to file to be deserialized + * @return Expression deserialized from the given json file + */ + Expr deserialize(const string &filePath); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/SimplifyExprVisitor.h b/include/nnet/Visitor/SimplifyExprVisitor.h new file mode 100644 index 00000000..e8ed459c --- /dev/null +++ b/include/nnet/Visitor/SimplifyExprVisitor.h @@ -0,0 +1,38 @@ +#pragma once +#include "nnet/Visitor/StrideVisitor.h" +#include "nnet/visitor.h" + +namespace nnet { + +// Simplify a index expression tree +class SimplifyExprVisitor : public Functor stride)> { + private: + SubexprSride subexprStride; + int constant; + PtrMap strides; // [var]=strides + + map, int, RefValueLess>> divStrides, + modStrides; // 3*(i%8): []=3 + + // For divde and modulo with expr as dividend: 3*((i+1)%8): []=3 + map, int, RefAddrLess>> divExprStrides, + modExprStrides; + + public: + SimplifyExprVisitor() : Functor(0) {} + void visit_(const BinaryOp &c, optional stride) override; + void visit_(const Var &c, optional stride) override; + void visit_(const Constant &c, optional stride) override; + PtrMap getStrides(const Expr &expr); + // TODO [refactor]: move this to SimplifyFormulaMutator as a member func + // this class should be get coefficients in a expr + Expr simplify(const Expr &expr); + int getConstant(const Expr &expr); + pair, int> getStridesConstant(const Expr &expr); + optional getExprRange(const Expr &expr, const RangeOp &rangeOp); + PtrMap getStrides() { return strides; } + const auto &getDivStrides() { return divStrides; } + const auto &getModStrides() { return modStrides; } +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/SimplifyFormulaMutator.h b/include/nnet/Visitor/SimplifyFormulaMutator.h new file mode 100644 index 00000000..48f92870 --- /dev/null +++ b/include/nnet/Visitor/SimplifyFormulaMutator.h @@ -0,0 +1,18 @@ +#pragma once +#include "nnet/Visitor/StrideVisitor.h" +#include "nnet/visitor.h" + +namespace nnet { + +// Simplify all indexes in subscripts in an expression tree +class SimplifyFormulaMutator : public Mutator { + int nSubscripts = 0; + + public: + SimplifyFormulaMutator() : Mutator(0) {} + Expr visit_(const Subscript &c) override; + // Expr visit_(const BinaryOp &c) override; + Expr simplify(const Expr &expr); +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/Visitor/StrideVisitor.h b/include/nnet/Visitor/StrideVisitor.h new file mode 100644 index 00000000..dca17569 --- /dev/null +++ b/include/nnet/Visitor/StrideVisitor.h @@ -0,0 +1,38 @@ +#pragma once +#include "nnet/visitor.h" + +namespace nnet { + +using SubexprSride = map>; +class StrideVisitor : public Functor(void)> { + private: + SubexprSride subexprStride; + + public: + StrideVisitor(int _verobse = 0) : Functor(_verobse) {} + optional visit_(const BinaryOp &c) override; + optional visit_(const Subscript &c) override; + optional visit_(const Var &c) override; + optional visit_(const Constant &c) override; + // void visit_(const Tensor &c, const Tensor &tensor) override; + + auto getFormulaStride(const RangeOp &e) { + subexprStride.clear(); + // get the location and stride of each iterator + auto mulOp = as(e->getSummand()); + // TODO [feature]: support complex index exprs + if (!mulOp || mulOp->getOpType() != OpType::Mul) + nnet_unimplemented_continue(); + dispatch(mulOp->getLhs()); + dispatch(mulOp->getRhs()); + return subexprStride; + } + + [[nodiscard]] auto getExprStride(const Expr &e) { + subexprStride.clear(); + dispatch(e); + return subexprStride; + } +}; + +} // namespace nnet \ No newline at end of file diff --git a/include/nnet/common.h b/include/nnet/common.h new file mode 100644 index 00000000..567c6a27 --- /dev/null +++ b/include/nnet/common.h @@ -0,0 +1,77 @@ +#pragma once +#include "dbg.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nnet { +using std::dynamic_pointer_cast; +using std::endl; +using std::list; +using std::make_pair; +using std::make_shared; +using std::make_tuple; +using std::map; +using std::max; +using std::min; +using std::nullopt; +using std::optional; +using std::pair; +using std::set; +using std::shared_ptr; +using std::string; +using std::tie; +using std::to_string; +using std::tuple; +using std::unique_ptr; +using std::unordered_map; +template using uset = std::unordered_set; +using std::vector; +using std::weak_ptr; + +// Aliases +using dtype = float; +using HashType = int; + +template struct ptr_less { + bool operator()(const T &lhs, const T &rhs) const { return *lhs < *rhs; } +}; + +template struct ptr_hash { + size_t operator()(const T &lhs) const { + return std::hash()(*lhs); + } +}; + +template struct ptr_equal { + bool operator()(const T &lhs, const T &rhs) const { return *lhs == *rhs; } +}; + +static inline HashType genhash(HashType a, HashType b) { + return (a * 10007 + b + 12345) % 1000000007; +} + +static inline HashType genhash(string s) { + HashType ret = 0; + for (auto c : s) + ret = genhash(ret, c); + return ret; +} + +#define nnet_unimplemented_halt() \ + { assert(!"Unimplemented"); } + +#define nnet_unimplemented_continue() \ + { dbg("Unimplemented"); } + +#define nnet_assert(expr, msg) assert(((void)(msg), (expr))) + +std::string pointer_to_hex(void *i); +} // namespace nnet diff --git a/include/nnet/dbg.h b/include/nnet/dbg.h new file mode 100644 index 00000000..f5894e4b --- /dev/null +++ b/include/nnet/dbg.h @@ -0,0 +1,858 @@ +/***************************************************************************** + + dbg(...) macro + +License (MIT): + + Copyright (c) 2019 David Peter + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to + deal in the Software without restriction, including without limitation the + rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + sell copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + +*****************************************************************************/ + +#ifndef DBG_MACRO_DBG_H +#define DBG_MACRO_DBG_H + +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) +#define DBG_MACRO_UNIX +#elif defined(_MSC_VER) +#define DBG_MACRO_WINDOWS +#endif + +// #ifndef DBG_MACRO_NO_WARNING +// #pragma message("WARNING: the 'dbg.h' header is included in your code base") +// #endif // DBG_MACRO_NO_WARNING + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef DBG_MACRO_UNIX +#include +#endif + +#if __cplusplus >= 201703L +#define DBG_MACRO_CXX_STANDARD 17 +#elif __cplusplus >= 201402L +#define DBG_MACRO_CXX_STANDARD 14 +#else +#define DBG_MACRO_CXX_STANDARD 11 +#endif + +#if DBG_MACRO_CXX_STANDARD >= 17 +#include +#include +#endif + +namespace dbg { + +#ifdef DBG_MACRO_UNIX +inline bool isColorizedOutputEnabled() { return isatty(fileno(stderr)); } +#else +inline bool isColorizedOutputEnabled() { return true; } +#endif + +struct time {}; + +namespace pretty_function { + +// Compiler-agnostic version of __PRETTY_FUNCTION__ and constants to +// extract the template argument in `type_name_impl` + +#if defined(__clang__) +#define DBG_MACRO_PRETTY_FUNCTION __PRETTY_FUNCTION__ +static constexpr size_t PREFIX_LENGTH = + sizeof("const char *dbg::type_name_impl() [T = ") - 1; +static constexpr size_t SUFFIX_LENGTH = sizeof("]") - 1; +#elif defined(__GNUC__) && !defined(__clang__) +#define DBG_MACRO_PRETTY_FUNCTION __PRETTY_FUNCTION__ +static constexpr size_t PREFIX_LENGTH = + sizeof("const char* dbg::type_name_impl() [with T = ") - 1; +static constexpr size_t SUFFIX_LENGTH = sizeof("]") - 1; +#elif defined(_MSC_VER) +#define DBG_MACRO_PRETTY_FUNCTION __FUNCSIG__ +static constexpr size_t PREFIX_LENGTH = + sizeof("const char *__cdecl dbg::type_name_impl<") - 1; +static constexpr size_t SUFFIX_LENGTH = sizeof(">(void)") - 1; +#else +#error "This compiler is currently not supported by dbg_macro." +#endif + +} // namespace pretty_function + +// Formatting helpers + +template struct print_formatted { + static_assert(std::is_integral::value, + "Only integral types are supported."); + + print_formatted(T value, int numeric_base) + : inner(value), base(numeric_base) {} + + operator T() const { return inner; } + + const char *prefix() const { + switch (base) { + case 8: + return "0o"; + case 16: + return "0x"; + case 2: + return "0b"; + default: + return ""; + } + } + + T inner; + int base; +}; + +template print_formatted hex(T value) { + return print_formatted{value, 16}; +} + +template print_formatted oct(T value) { + return print_formatted{value, 8}; +} + +template print_formatted bin(T value) { + return print_formatted{value, 2}; +} + +// Implementation of 'type_name()' + +template const char *type_name_impl() { + return DBG_MACRO_PRETTY_FUNCTION; +} + +template struct type_tag {}; + +template +std::string get_type_name(type_tag) { + namespace pf = pretty_function; + + std::string type = type_name_impl(); + return type.substr(pf::PREFIX_LENGTH, + type.size() - pf::PREFIX_LENGTH - pf::SUFFIX_LENGTH); +} + +template std::string type_name() { + if (std::is_volatile::value) { + if (std::is_pointer::value) { + return type_name::type>() + + " volatile"; + } else { + return "volatile " + + type_name::type>(); + } + } + if (std::is_const::value) { + if (std::is_pointer::value) { + return type_name::type>() + " const"; + } else { + return "const " + type_name::type>(); + } + } + if (std::is_pointer::value) { + return type_name::type>() + "*"; + } + if (std::is_lvalue_reference::value) { + return type_name::type>() + "&"; + } + if (std::is_rvalue_reference::value) { + return type_name::type>() + "&&"; + } + return get_type_name(type_tag{}); +} + +inline std::string get_type_name(type_tag) { return "short"; } + +inline std::string get_type_name(type_tag) { + return "unsigned short"; +} + +inline std::string get_type_name(type_tag) { return "long"; } + +inline std::string get_type_name(type_tag) { + return "unsigned long"; +} + +inline std::string get_type_name(type_tag) { + return "std::string"; +} + +template +std::string get_type_name(type_tag>>) { + return "std::vector<" + type_name() + ">"; +} + +template +std::string get_type_name(type_tag>) { + return "std::pair<" + type_name() + ", " + type_name() + ">"; +} + +template std::string type_list_to_string() { + std::string result; + auto unused = {(result += type_name() + ", ", 0)..., 0}; + static_cast(unused); + +#if DBG_MACRO_CXX_STANDARD >= 17 + if constexpr (sizeof...(T) > 0) { +#else + if (sizeof...(T) > 0) { +#endif + result.pop_back(); + result.pop_back(); + } + return result; +} + +template std::string get_type_name(type_tag>) { + return "std::tuple<" + type_list_to_string() + ">"; +} + +template +inline std::string get_type_name(type_tag>) { + return type_name(); +} + +// Implementation of 'is_detected' to specialize for container-like types + +namespace detail_detector { + +struct nonesuch { + nonesuch() = delete; + ~nonesuch() = delete; + nonesuch(nonesuch const &) = delete; + void operator=(nonesuch const &) = delete; +}; + +template using void_t = void; + +template class Op, + class... Args> +struct detector { + using value_t = std::false_type; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> { + using value_t = std::true_type; + using type = Op; +}; + +} // namespace detail_detector + +template