From e6101b03366b721ef1ac55dead8c7fdc343ec451 Mon Sep 17 00:00:00 2001 From: Liyan Zheng Date: Sun, 31 Jul 2022 21:43:26 +0800 Subject: [PATCH] Add: graph, tensor, and operator --- .clang-format | 90 +++++ .cmake-format.json | 311 +++++++++++++++++ .gitignore | 5 + .gitmodules | 9 + 3rd-party/googletest | 1 + 3rd-party/nlohmann_json_cmake_fetchcontent | 1 + 3rd-party/pybind11 | 1 + CMakeLists.txt | 73 ++++ include/core/common.h | 32 ++ include/core/graph.h | 36 ++ include/core/object.h | 45 +++ include/core/operator.h | 21 ++ include/core/ref.h | 27 ++ include/core/tensor.h | 368 +++++++++++++++++++++ include/test.h | 3 + src/core/graph.cc | 15 + src/core/operator.cc | 11 + src/core/tensor.cc | 8 + test/core/test_graph.cc | 12 + 19 files changed, 1069 insertions(+) create mode 100644 .clang-format create mode 100644 .cmake-format.json create mode 100644 .gitmodules create mode 160000 3rd-party/googletest create mode 160000 3rd-party/nlohmann_json_cmake_fetchcontent create mode 160000 3rd-party/pybind11 create mode 100644 CMakeLists.txt create mode 100644 include/core/common.h create mode 100644 include/core/graph.h create mode 100644 include/core/object.h create mode 100644 include/core/operator.h create mode 100644 include/core/ref.h create mode 100644 include/core/tensor.h create mode 100644 include/test.h create mode 100644 src/core/graph.cc create mode 100644 src/core/operator.cc create mode 100644 src/core/tensor.cc create mode 100644 test/core/test_graph.cc diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..d8a7dfe7 --- /dev/null +++ b/.clang-format @@ -0,0 +1,90 @@ +--- +Language: Cpp +# BasedOnStyle: LLVM +AccessModifierOffset: -2 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: false +AlignOperands: true +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: false 
+AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: false +BinPackArguments: true +BinPackParameters: true +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +ConstructorInitializerAllOnOneLineOrOnePerLine: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] +IncludeCategories: + - Regex: '^"(llvm|llvm-c|clang|clang-c)/' + Priority: 2 + - Regex: '^(<|"(gtest|isl|json)/)' + Priority: 3 + - Regex: '.*' + Priority: 1 +IndentCaseLabels: false +IndentWidth: 4 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: true +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 60 +PointerAlignment: Right +ReflowComments: true +SortIncludes: true +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false 
+SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never +... + diff --git a/.cmake-format.json b/.cmake-format.json new file mode 100644 index 00000000..ba3498b9 --- /dev/null +++ b/.cmake-format.json @@ -0,0 +1,311 @@ +{ + "_help_parse": "Options affecting listfile parsing", + "parse": { + "_help_additional_commands": [ + "Specify structure for custom cmake functions" + ], + "additional_commands": { + "foo": { + "flags": [ + "BAR", + "BAZ" + ], + "kwargs": { + "HEADERS": "*", + "SOURCES": "*", + "DEPENDS": "*" + } + } + }, + "_help_override_spec": [ + "Override configurations per-command where available" + ], + "override_spec": {}, + "_help_vartags": [ + "Specify variable tags." + ], + "vartags": [], + "_help_proptags": [ + "Specify property tags." + ], + "proptags": [] + }, + "_help_format": "Options affecting formatting.", + "format": { + "_help_disable": [ + "Disable formatting entirely, making cmake-format a no-op" + ], + "disable": false, + "_help_line_width": [ + "How wide to allow formatted cmake files" + ], + "line_width": 1500, + "_help_tab_size": [ + "How many spaces to tab for indent" + ], + "tab_size": 2, + "_help_use_tabchars": [ + "If true, lines are indented using tab characters (utf-8", + "0x09) instead of <tab_size> space characters (utf-8 0x20).", + "In cases where the layout would require a fractional tab", + "character, the behavior of the fractional indentation is", + "governed by <fractional_tab_policy>" + ], + "use_tabchars": false, + "_help_fractional_tab_policy": [ + "If <use_tabchars> is True, then the value of this variable", + "indicates how fractional indentions are handled during", + "whitespace replacement. If set to 'use-space', fractional", + "indentation is left as spaces (utf-8 0x20). 
If set to", + "`round-up` fractional indentation is replaced with a single", + "tab character (utf-8 0x09) effectively shifting the column", + "to the next tabstop" + ], + "fractional_tab_policy": "use-space", + "_help_max_subgroups_hwrap": [ + "If an argument group contains more than this many sub-groups", + "(parg or kwarg groups) then force it to a vertical layout." + ], + "max_subgroups_hwrap": 2, + "_help_max_pargs_hwrap": [ + "If a positional argument group contains more than this many", + "arguments, then force it to a vertical layout." + ], + "max_pargs_hwrap": 6, + "_help_max_rows_cmdline": [ + "If a cmdline positional group consumes more than this many", + "lines without nesting, then invalidate the layout (and nest)" + ], + "max_rows_cmdline": 2, + "_help_separate_ctrl_name_with_space": [ + "If true, separate flow control names from their parentheses", + "with a space" + ], + "separate_ctrl_name_with_space": false, + "_help_separate_fn_name_with_space": [ + "If true, separate function names from parentheses with a", + "space" + ], + "separate_fn_name_with_space": false, + "_help_dangle_parens": [ + "If a statement is wrapped to more than one line, then dangle", + "the closing parenthesis on its own line." + ], + "dangle_parens": false, + "_help_dangle_align": [ + "If the trailing parenthesis must be 'dangled' on its own", + "line, then align it to this reference: `prefix`: the start", + "of the statement, `prefix-indent`: the start of the", + "statement, plus one indentation level, `child`: align to", + "the column of the arguments" + ], + "dangle_align": "prefix", + "_help_min_prefix_chars": [ + "If the statement spelling length (including space and", + "parenthesis) is smaller than this amount, then force reject", + "nested layouts." 
+ ], + "min_prefix_chars": 4, + "_help_max_prefix_chars": [ + "If the statement spelling length (including space and", + "parenthesis) is larger than the tab width by more than this", + "amount, then force reject un-nested layouts." + ], + "max_prefix_chars": 10, + "_help_max_lines_hwrap": [ + "If a candidate layout is wrapped horizontally but it exceeds", + "this many lines, then reject the layout." + ], + "max_lines_hwrap": 2, + "_help_line_ending": [ + "What style line endings to use in the output." + ], + "line_ending": "unix", + "_help_command_case": [ + "Format command names consistently as 'lower' or 'upper' case" + ], + "command_case": "canonical", + "_help_keyword_case": [ + "Format keywords consistently as 'lower' or 'upper' case" + ], + "keyword_case": "unchanged", + "_help_always_wrap": [ + "A list of command names which should always be wrapped" + ], + "always_wrap": [], + "_help_enable_sort": [ + "If true, the argument lists which are known to be sortable", + "will be sorted lexicographically" + ], + "enable_sort": true, + "_help_autosort": [ + "If true, the parsers may infer whether or not an argument", + "list is sortable (without annotation)." + ], + "autosort": false, + "_help_require_valid_layout": [ + "By default, if cmake-format cannot successfully fit", + "everything into the desired linewidth it will apply the", + "last, most aggressive attempt that it made. If this flag is", + "True, however, cmake-format will print error, exit with non-", + "zero status code, and write-out nothing" + ], + "require_valid_layout": false, + "_help_layout_passes": [ + "A dictionary mapping layout nodes to a list of wrap", + "decisions. See the documentation for more information." 
+ ], + "layout_passes": {} + }, + "_help_markup": "Options affecting comment reflow and formatting.", + "markup": { + "_help_bullet_char": [ + "What character to use for bulleted lists" + ], + "bullet_char": "*", + "_help_enum_char": [ + "What character to use as punctuation after numerals in an", + "enumerated list" + ], + "enum_char": ".", + "_help_first_comment_is_literal": [ + "If comment markup is enabled, don't reflow the first comment", + "block in each listfile. Use this to preserve formatting of", + "your copyright/license statements." + ], + "first_comment_is_literal": false, + "_help_literal_comment_pattern": [ + "If comment markup is enabled, don't reflow any comment block", + "which matches this (regex) pattern. Default is `None`", + "(disabled)." + ], + "literal_comment_pattern": ".*", + "_help_fence_pattern": [ + "Regular expression to match preformat fences in comments", + "default= ``r'^\\s*([`~]{3}[`~]*)(.*)$'``" + ], + "fence_pattern": "^\\s*([`~]{3}[`~]*)(.*)$", + "_help_ruler_pattern": [ + "Regular expression to match rulers in comments default=", + "``r'^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$'``" + ], + "ruler_pattern": "^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$", + "_help_explicit_trailing_pattern": [ + "If a comment line matches starts with this pattern then it", + "is explicitly a trailing comment for the preceding argument.", + "Default is '#<'" + ], + "explicit_trailing_pattern": "#<", + "_help_hashruler_min_length": [ + "If a comment line starts with at least this many consecutive", + "hash characters, then don't lstrip() them off. 
This allows", + "for lazy hash rulers where the first hash char is not", + "separated by space" + ], + "hashruler_min_length": 10, + "_help_canonicalize_hashrulers": [ + "If true, then insert a space between the first hash char and", + "remaining hash chars in a hash ruler, and normalize its", + "length to fill the column" + ], + "canonicalize_hashrulers": true, + "_help_enable_markup": [ + "enable comment markup parsing and reflow" + ], + "enable_markup": true + }, + "_help_lint": "Options affecting the linter", + "lint": { + "_help_disabled_codes": [ + "a list of lint codes to disable" + ], + "disabled_codes": [], + "_help_function_pattern": [ + "regular expression pattern describing valid function names" + ], + "function_pattern": "[0-9a-z_]+", + "_help_macro_pattern": [ + "regular expression pattern describing valid macro names" + ], + "macro_pattern": "[0-9A-Z_]+", + "_help_global_var_pattern": [ + "regular expression pattern describing valid names for", + "variables with global (cache) scope" + ], + "global_var_pattern": "[A-Z][0-9A-Z_]+", + "_help_internal_var_pattern": [ + "regular expression pattern describing valid names for", + "variables with global scope (but internal semantic)" + ], + "internal_var_pattern": "_[A-Z][0-9A-Z_]+", + "_help_local_var_pattern": [ + "regular expression pattern describing valid names for", + "variables with local scope" + ], + "local_var_pattern": "[a-z][a-z0-9_]+", + "_help_private_var_pattern": [ + "regular expression pattern describing valid names for", + "private directory variables" + ], + "private_var_pattern": "_[0-9a-z_]+", + "_help_public_var_pattern": [ + "regular expression pattern describing valid names for public", + "directory variables" + ], + "public_var_pattern": "[A-Z][0-9A-Z_]+", + "_help_argument_var_pattern": [ + "regular expression pattern describing valid names for", + "function/macro arguments and loop variables." 
+ ], + "argument_var_pattern": "[a-z][a-z0-9_]+", + "_help_keyword_pattern": [ + "regular expression pattern describing valid names for", + "keywords used in functions or macros" + ], + "keyword_pattern": "[A-Z][0-9A-Z_]+", + "_help_max_conditionals_custom_parser": [ + "In the heuristic for C0201, how many conditionals to match", + "within a loop in before considering the loop a parser." + ], + "max_conditionals_custom_parser": 2, + "_help_min_statement_spacing": [ + "Require at least this many newlines between statements" + ], + "min_statement_spacing": 1, + "_help_max_statement_spacing": [ + "Require no more than this many newlines between statements" + ], + "max_statement_spacing": 2, + "max_returns": 6, + "max_branches": 12, + "max_arguments": 5, + "max_localvars": 15, + "max_statements": 50 + }, + "_help_encode": "Options affecting file encoding", + "encode": { + "_help_emit_byteorder_mark": [ + "If true, emit the unicode byte-order mark (BOM) at the start", + "of the file" + ], + "emit_byteorder_mark": false, + "_help_input_encoding": [ + "Specify the encoding of the input file. Defaults to utf-8" + ], + "input_encoding": "utf-8", + "_help_output_encoding": [ + "Specify the encoding of the output file. Defaults to utf-8.", + "Note that cmake only claims to support utf-8 so be careful", + "when using anything else" + ], + "output_encoding": "utf-8" + }, + "_help_misc": "Miscellaneous configurations options.", + "misc": { + "_help_per_command": [ + "A dictionary containing any per-command configuration", + "overrides. Currently only `command_case` is supported." 
+ ], + "per_command": {} + } +} diff --git a/.gitignore b/.gitignore index 259148fa..6d512811 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,8 @@ *.exe *.out *.app + +build/ +build_debug/ + +.vscode/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..2f905b6a --- /dev/null +++ b/.gitmodules @@ -0,0 +1,9 @@ +[submodule "3rd-party/pybind11"] + path = 3rd-party/pybind11 + url = git@github.com:pybind/pybind11.git +[submodule "3rd-party/nlohmann_json_cmake_fetchcontent"] + path = 3rd-party/nlohmann_json_cmake_fetchcontent + url = git@github.com:ArthurSonzogni/nlohmann_json_cmake_fetchcontent.git +[submodule "3rd-party/googletest"] + path = 3rd-party/googletest + url = git@github.com:google/googletest.git diff --git a/3rd-party/googletest b/3rd-party/googletest new file mode 160000 index 00000000..e2239ee6 --- /dev/null +++ b/3rd-party/googletest @@ -0,0 +1 @@ +Subproject commit e2239ee6043f73722e7aa812a459f54a28552929 diff --git a/3rd-party/nlohmann_json_cmake_fetchcontent b/3rd-party/nlohmann_json_cmake_fetchcontent new file mode 160000 index 00000000..6aebf092 --- /dev/null +++ b/3rd-party/nlohmann_json_cmake_fetchcontent @@ -0,0 +1 @@ +Subproject commit 6aebf09233951e4ce30a63919186a70b2b195756 diff --git a/3rd-party/pybind11 b/3rd-party/pybind11 new file mode 160000 index 00000000..1e3400b6 --- /dev/null +++ b/3rd-party/pybind11 @@ -0,0 +1 @@ +Subproject commit 1e3400b6742288429f2069aaf5febf92d0662dae diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..dd37869a --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,73 @@ +# TODO: check the minimum cmake version +cmake_minimum_required(VERSION 3.9) # Required by find_package(OpenMP) +include(CMakeDependentOption) +project(InfiniTensor C CXX) + +# Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them. 
+option(BUILD_TEST "Build tests" OFF) +cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF) +cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" ON BUILD_TEST OFF) +cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" ON BUILD_TEST OFF) + +set(DEFAULT_BUILD_TYPE "RelWithDebInfo") + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++17 when on, -std=c++17 when off +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wno-error=deprecated-declarations") +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion + +find_package( + Python + COMPONENTS Interpreter Development + REQUIRED) +find_package(CUDA REQUIRED) +# OpenMP +find_package(OpenMP) +if(OpenMP_C_FOUND) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") +endif() +if(OpenMP_CXX_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") +endif() + +# add_subdirectory(3rd-party/pybind11) +# add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent) + +include_directories(include) +# include_directories(3rd-party/pybind11/include) +# include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include) + +if(BUILD_TEST) + set(BUILD_GMOCK off) + set(INSTALL_GTEST off) + add_subdirectory(3rd-party/googletest) + include_directories(3rd-party/googletest/googletest/include) +endif() + +file(GLOB_RECURSE SRC src/*.cc src/*.cu) +# file(GLOB_RECURSE TEST test/*.cc) +# file(GLOB_RECURSE FFI src/ffi/ffi_pet.cc) +# list(REMOVE_ITEM SRC ${TEST} ${FFI}) + +add_library(InfiniTensor SHARED ${SRC}) +# Target +# cuda_add_library(it SHARED ${SRC}) +# cuda_add_cublas_to_target(it) # cublas +# # target_link_libraries(infini_cpp cudnn curand nlohmann_json::nlohmann_json pybind11::embed) + +# # Python bindings +# pybind11_add_module(infini MODULE ${FFI}) +# target_link_libraries(infini PRIVATE 
infini_cpp) + +if(BUILD_TEST) + enable_testing() + # Build all tests file( GLOB TEST_SOURCES test/test_sg2bmm.cc ) + file(GLOB_RECURSE TEST_SOURCES test/*.cc) + foreach(testsourcefile ${TEST_SOURCES}) + get_filename_component(testname ${testsourcefile} NAME_WE) + add_executable(${testname} ${testsourcefile}) + target_link_libraries(${testname} InfiniTensor gtest_main gtest) + add_test(NAME ${testname} COMMAND ${testname}) + endforeach(testsourcefile ${TEST_SOURCES}) +endif() diff --git a/include/core/common.h b/include/core/common.h new file mode 100644 index 00000000..f473db9e --- /dev/null +++ b/include/core/common.h @@ -0,0 +1,32 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace it { +using std::list; +using std::map; +using std::nullopt; +using std::optional; +using std::pair; +using std::set; +using std::string; +using std::tie; +using std::to_string; +using std::tuple; +using std::unordered_map; +using std::vector; + +// Aliases +using dtype = float; + +} // namespace it diff --git a/include/core/graph.h b/include/core/graph.h new file mode 100644 index 00000000..6f2f6495 --- /dev/null +++ b/include/core/graph.h @@ -0,0 +1,36 @@ +#pragma once +#include "core/operator.h" +#include "core/tensor.h" + +namespace it { + +class GraphNode : public Object { + protected: + TensorVec tensors; + TensorVec inputs; + TensorVec outputs; + OpVec ops; + + public: + // Graph(OpVec oplist); + string toString() const override; + + void addOp(Operator op) { ops.push_back(op); }; + const TensorVec &getTensors() const { return tensors; } + const TensorVec &getInputs() const { return inputs; } + const TensorVec &getOutputs() const { return outputs; } + const OpVec &getOperators() const { return ops; } + // TensorVec &getInputs(); + // TensorVec &getOutputs(); + + void updateConnection(); + + // TODO + // bool compute(); + + // TODO: move to another class + // bool exportOnnx(const 
char *path); + // bool importOnnx(const char *net); +}; + +} // namespace it \ No newline at end of file diff --git a/include/core/object.h b/include/core/object.h new file mode 100644 index 00000000..2ce7db01 --- /dev/null +++ b/include/core/object.h @@ -0,0 +1,45 @@ +#pragma once +#include "core/common.h" + +namespace it { + +using GuidBaseType = int; + +class Guid { + private: + GuidBaseType guid; + + private: + GuidBaseType generateGuid() { + static GuidBaseType guidCnt = 0; + return ++guidCnt; + } + + public: + Guid() { guid = generateGuid(); } + Guid(const Guid &rhs) { guid = generateGuid(); } + Guid &operator=(const Guid &rhs) { + guid = generateGuid(); + return *this; + } + + operator GuidBaseType() const { return guid; } +}; + +class Object { + protected: + Guid guid; + + public: + virtual ~Object(){}; + virtual string toString() const = 0; + void print() { std::cout << toString() << std::endl; } + Guid getGuid() const { return guid; } +}; + +inline std::ostream &operator<<(std::ostream &os, const Object &obj) { + os << obj.toString(); + return os; +} + +} \ No newline at end of file diff --git a/include/core/operator.h b/include/core/operator.h new file mode 100644 index 00000000..467c5806 --- /dev/null +++ b/include/core/operator.h @@ -0,0 +1,21 @@ +#pragma once +#include "core/tensor.h" + +namespace it { + +class OperatorNode : public Object { + protected: + // OpType type; + TensorVec inputs; + TensorVec outputs; + // vector> predecessors; + // vector> successors; + public: + OperatorNode(TensorVec inputs, TensorVec outputs) + : inputs(inputs), outputs(outputs) {} + string toString() const override; + // Operator(TensorVec inputs) : inputs(inputs) {} + + virtual ~OperatorNode() {} +}; +} // namespace it \ No newline at end of file diff --git a/include/core/ref.h b/include/core/ref.h new file mode 100644 index 00000000..7799d02a --- /dev/null +++ b/include/core/ref.h @@ -0,0 +1,27 @@ +#pragma once +#include "common.h" +#include // hash +#include 
+#include + +namespace it { + +template using Ref = std::shared_ptr; +template using WRef = std::weak_ptr; + +template struct is_ref : std::false_type {}; +template struct is_ref> : std::true_type {}; +template struct is_ref> : std::true_type {}; + +template Ref make_ref(Params &&...params) { + static_assert(is_ref::value == false, "Ref should not be nested"); + return std::make_shared(std::forward(params)...); +} + +template > * = nullptr> +Ref as(const Ref &ref) { + return std::dynamic_pointer_cast(ref); +} + +} // namespace it \ No newline at end of file diff --git a/include/core/tensor.h b/include/core/tensor.h new file mode 100644 index 00000000..a7e49fcb --- /dev/null +++ b/include/core/tensor.h @@ -0,0 +1,368 @@ +#pragma once +#include "core/object.h" +#include "core/ref.h" + +namespace it { + +// class Tensor; +class TensorBaseNode; +class OperatorNode; +class GraphNode; + +using TensorBase = Ref; +using Operator = Ref; +using Graph = Ref; + +using TensorVec = vector; +using OpVec = vector; + +// using TensorMap = std::map; +// using OpMap = std::map; +using VType = uint32_t; +// using SplittingPoints = std::vector>; + +class TensorBaseNode : public Object { + public: + enum DataType { + Float32, + Int32, + }; + + // enum TensorType { + // Input, + // Weight, + // Invalid, + // NotCounted, + // }; + + // // TODO: is more compute state needed? 
+ // enum ComputeState { + // NotComputed, + // // Allocated, + // // Initialized, + // // ComputedPartial, + // ComputedFull, + // }; + + private: + int hid; + // uint64_t hash; + // Shape shape; + int dim; + + vector> inputOf; + WRef outputOf; + Ref data; + DataType dtype; + // ComputeState computed; + // static int random_seed[256 * 16]; + // static bool random_inited; + + public: + // Tensor(TensorType type = Input, DataType dtype = Float32) + // : guid(generateGuid()), hash(generateHash()), outputOf(nullptr), + // data(nullptr), dtype(dtype), type(type), computed(NotComputed) {} + // Tensor(const Dim &dims, TensorType type = Input, DataType dtype = + // Float32) + // : guid(generateGuid()), hash(generateHash()), dims(dims), + // outputOf(nullptr), data(nullptr), dtype(dtype), type(type), + // computed(NotComputed) { + // itInit(); + // } + // Tensor(const Tensor &rhs) : Tensor(rhs.dims, rhs.type, rhs.dtype) { + // outputOf = nullptr; + // data = nullptr; + // hash = rhs.hash; + // dimPenalty = rhs.dimPenalty; + // itInit(); + // } + // Tensor(VType scalar, TensorType type = Weight, DataType dtype = Float32) + // : guid(generateGuid()), hash(generateHash()), outputOf(nullptr), + // data(nullptr), dtype(dtype), type(type), computed(ComputedFull) { + // assert(size() == 1); + // dataMalloc(); + // data[0] = scalar; + // } + virtual ~TensorBaseNode() {} + string toString() const override; + + // // inputOf and outputOf will not be cloned + // Tensor *clone() { + // Tensor *t = new Tensor(*this); + // return t; + // } + + // void clone(Tensor *t) { + // dims = t->dims; + // dtype = t->dtype; + // type = t->type; + // hash = t->hash; + // dimPenalty = t->dimPenalty; + // } + + DataType getDType() const { return dtype; } + + // uint64_t getHash() const { return hash; } + + // void setInputOf(const OpVec &ops) { + // inputOf.clear(); + // for (const auto &op : ops) + // inputOf.emplace_back(op); + // } + // void addInputOf(Operator op) { inputOf.emplace_back(op); } + 
// void setOutputOf(Operator op) { outputOf = op; } + + // const OpVec &getInputOf() { return inputOf; } + // Operator *getOutputOf() { return outputOf; } + // std::pair getOutputOfWithIndex(); + + // bool dataMalloc() { + // if (data == nullptr) + // data = new VType[size()]; + // return data != nullptr; + // } + + // const Dim &getDims() const { return dims; } + // void setDims(const Dim &dms) { dims = dms; } + + // bool dataRand(int seed = 0) { + // if (data == nullptr) + // data = new VType[size()]; + // if (!random_inited) + // initFastrand(); + // // srand(seed); + // // faster rand generator; parallel + // size_t iEnd = size(); + // // std::cerr << "Init beginned " << std::endl; + // #pragma omp parallel for + // for (size_t i = 0; i < iEnd; ++i) + // data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) % + // 10000; + // // std::cerr << "Init finished" << std::endl; + // computed = ComputedFull; + // return true; + // } + + // bool setData(VType *dptr) { + // if (dptr == nullptr) + // return false; + // auto sz = size(); + // #pragma omp parallel for + // for (size_t i = 0; i < sz; ++i) + // data[i] = dptr[i]; + // computed = ComputedFull; + // return true; + // } + + // bool setScalar(VType val) { + // if (data == nullptr || !dims.empty()) + // return false; + // data[0] = val; + // return true; + // } + + // bool setData(const Dim &ds, VType val) { + // if (data == nullptr || ds.size() != dims.size()) + // return false; + // data[getOffset(ds)] = val; + // return true; + // } + + // bool setData(size_t pos, VType val) { + // if (data == nullptr || pos >= size()) + // return false; + // data[pos] = val; + // return true; + // } + + // VType getScalar() { return data == nullptr ? 0 : data[0]; } + + // VType getData(const Dim &ds) { + // assert(data != nullptr); + // auto offset = getOffset(ds); + // return offset == (size_t)-1 ? 
0 : data[getOffset(ds)]; + // } + + // VType getData(size_t pos) { + // assert(data != nullptr); + // assert(pos < size()); + // return data[pos]; + // } + + // VType *getDataPtr() const { return data; } + + // size_t getOffset(const Dim &ds) { + // auto nDim = ds.size(); + // assert(dims.size() == nDim); + // if (ds.empty()) + // return 0; + // for (size_t i = 0; i < nDim; ++i) + // if (ds[i] < 0 || ds[i] >= dims[i]) + // return (size_t)-1; + // size_t idx = ds[0]; + // size_t dm = 0; + // while (++dm < nDim) + // idx = idx * dims[dm] + ds[dm]; + // return idx; + // } + + // VType getBroadcastData(const Dim &ds) { + // assert(data != nullptr); + // auto offset = getBroadcastOffset(ds); + // return offset == (size_t)-1 ? 0 : data[getOffset(ds)]; + // } + + // VType getBroadcastData(size_t pos) { + // assert(data != nullptr); + // return data[pos % size()]; + // } + + // size_t getBroadcastOffset(const Dim &ds) { + // assert(ds.size() >= dims.size()); + // auto nDim = dims.size(); + // auto nBroadcastDim = ds.size() - nDim; + // for (size_t i = 0; i < nDim; ++i) + // if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >= + // dims[i]) + // return (size_t)-1; + // size_t idx = 0; + // for (size_t i = 0; i < nDim; ++i) + // idx = idx * dims[i] + ds[nBroadcastDim + i]; + // return idx; + // } + + // void itInit() { it = Dim(dims.size(), 0); } + + // void itReset() { + // itInit(); + // for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i) + // it[i] = 0; + // } + + // bool itValid() { + // if (it.size() != dims.size()) + // return false; + // for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i) + // if (it[i] >= dims[i]) + // return false; + // return true; + // } + + // const Dim &itGet() { return it; } + + // void itNext() { + // auto p = it.size() - 1; + // it[p] += 1; + // while (p >= 1) { + // if (it[p] == dims[p]) { + // it[p] = 0; + // it[--p] += 1; + // } else + // break; + // } + // } + + // size_t size() const { + // size_t sz = 1; + // auto dm = dims.size(); 
+ // while (dm > 0) + // sz *= dims[--dm]; + // return sz; + // } + + // TensorType getType() const { return type; } + // void setType(TensorType ty) { type = ty; } + + // void print() { + // if (type == Invalid) { + // std::cout << "Invalid tensor" << std::endl; + // return; + // } + + // if (data == nullptr || dims.size() == 0) { + // std::cout << "Empty tensor" << std::endl; + // return; + // } + + // // TODO: can be uncommented after tensor's compute type is + // correctly set if (computed == NotComputed) { + // std::cout << "Uncomputed tensor" << std::endl; + // return; + // } + + // std::cout << "Tensor: " << guid << std::endl; + // auto numDims = dims.size(); + // auto dimSzVec = std::vector(numDims, 1); + // dimSzVec[numDims - 1] = dims[numDims - 1]; + // for (int i = numDims - 1; i != 0; --i) + // dimSzVec[i - 1] = dimSzVec[i] * dims[i - 1]; + // for (size_t i = 0, iEnd = size(); i < iEnd; ++i) { + // for (size_t j = 0; j < numDims; ++j) { + // if (i % dimSzVec[j] == 0) { + // std::cout << "["; + // } + // } + // std::cout << data[i]; + // for (size_t j = 0; j < numDims; ++j) { + // if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) { + // std::cout << "]"; + // } + // } + // if (i != size() - 1) + // std::cout << ", "; + // if ((int)i % dimSzVec[numDims - 1] == dimSzVec[numDims - 1] - + // 1) + // std::cout << std::endl; + // } + // } + + // static inline void initFastrand() { + // assert(omp_get_max_threads() <= 256); + // // srand(0); // constant seed for test + // // align random_seed to avoid false sharing + // for (int i = 0; i < 256 * 16; ++i) { + // // random_seed[i] = rand(); + // // constant random seed for test + // random_seed[i] = i; + // } + // random_inited = true; + // } + + // static inline int fastrand(int &g_seed) { + // g_seed = (214013 * g_seed + 2531011); + // return (g_seed >> 16) & 0x7FFF; + // } + + // std::vector> const *getSplittingPoints() const { + // assert(!splittingPoints.empty()); + // return &splittingPoints; + // } + + // bool 
setSplittingPoints(std::vector> value) { + // assert(!value.empty()); + // splittingPoints = value; + // return true; + // } + + // void printSplittingPoints() { + // if (splittingPoints.empty()) + // printf("Empty SplittingPoints"); + // else { + // printf("["); + // for (auto &vs : splittingPoints) { + // printf("["); + // for (auto v : vs) + // printf("%2d,", v); + // printf("],"); + // } + // printf("]"); + // } + // } + + // void initSplittingPoints() { + // splittingPoints.resize(getDims().size()); } + + // void printShape(); +}; + +} // namespace it \ No newline at end of file diff --git a/include/test.h b/include/test.h new file mode 100644 index 00000000..05bcdcb1 --- /dev/null +++ b/include/test.h @@ -0,0 +1,3 @@ +#pragma once +#include "core/common.h" +#include "gtest/gtest.h" diff --git a/src/core/graph.cc b/src/core/graph.cc new file mode 100644 index 00000000..bf78d5a2 --- /dev/null +++ b/src/core/graph.cc @@ -0,0 +1,15 @@ +#include "core/graph.h" + +namespace it { + +void GraphNode::updateConnection() { + // TODO +} + +string GraphNode::toString() const { + std::ostringstream oss; + oss << "GraphNode: "; + return oss.str(); +} + +} // namespace it \ No newline at end of file diff --git a/src/core/operator.cc b/src/core/operator.cc new file mode 100644 index 00000000..5817d55c --- /dev/null +++ b/src/core/operator.cc @@ -0,0 +1,11 @@ +#include "core/operator.h" + +namespace it { + +string OperatorNode::toString() const { + std::ostringstream oss; + oss << "Operator: "; + return oss.str(); +} + +} // namespace it \ No newline at end of file diff --git a/src/core/tensor.cc b/src/core/tensor.cc new file mode 100644 index 00000000..60710c34 --- /dev/null +++ b/src/core/tensor.cc @@ -0,0 +1,8 @@ +#include +namespace it { + +string TensorBaseNode::toString() const { + return "TensorBaseNode " + std::to_string(guid); +} + +}; // namespace it \ No newline at end of file diff --git a/test/core/test_graph.cc b/test/core/test_graph.cc new file mode 100644 index 
00000000..41587725 --- /dev/null +++ b/test/core/test_graph.cc @@ -0,0 +1,12 @@ +#include "core/graph.h" +#include "test.h" + +namespace it { + +TEST(Graph, build) { + Graph g = make_ref(); + g->addOp(make_ref(TensorVec{}, TensorVec{})); + g->print(); +} + +} // namespace it \ No newline at end of file