forked from jiuyuan/InfiniTensor
Add: graph, tensor, and operator
This commit is contained in:
parent
b89495a782
commit
e6101b0336
|
@ -0,0 +1,90 @@
|
||||||
|
---
|
||||||
|
Language: Cpp
|
||||||
|
# BasedOnStyle: LLVM
|
||||||
|
AccessModifierOffset: -2
|
||||||
|
AlignAfterOpenBracket: Align
|
||||||
|
AlignConsecutiveAssignments: false
|
||||||
|
AlignConsecutiveDeclarations: false
|
||||||
|
AlignEscapedNewlinesLeft: false
|
||||||
|
AlignOperands: true
|
||||||
|
AlignTrailingComments: true
|
||||||
|
AllowAllParametersOfDeclarationOnNextLine: true
|
||||||
|
AllowShortBlocksOnASingleLine: false
|
||||||
|
AllowShortCaseLabelsOnASingleLine: false
|
||||||
|
AllowShortFunctionsOnASingleLine: All
|
||||||
|
AllowShortIfStatementsOnASingleLine: false
|
||||||
|
AllowShortLoopsOnASingleLine: false
|
||||||
|
AlwaysBreakAfterDefinitionReturnType: None
|
||||||
|
AlwaysBreakAfterReturnType: None
|
||||||
|
AlwaysBreakBeforeMultilineStrings: false
|
||||||
|
AlwaysBreakTemplateDeclarations: false
|
||||||
|
BinPackArguments: true
|
||||||
|
BinPackParameters: true
|
||||||
|
BraceWrapping:
|
||||||
|
AfterClass: false
|
||||||
|
AfterControlStatement: false
|
||||||
|
AfterEnum: false
|
||||||
|
AfterFunction: false
|
||||||
|
AfterNamespace: false
|
||||||
|
AfterObjCDeclaration: false
|
||||||
|
AfterStruct: false
|
||||||
|
AfterUnion: false
|
||||||
|
BeforeCatch: false
|
||||||
|
BeforeElse: false
|
||||||
|
IndentBraces: false
|
||||||
|
BreakBeforeBinaryOperators: None
|
||||||
|
BreakBeforeBraces: Attach
|
||||||
|
BreakBeforeTernaryOperators: true
|
||||||
|
BreakConstructorInitializersBeforeComma: false
|
||||||
|
ColumnLimit: 80
|
||||||
|
CommentPragmas: '^ IWYU pragma:'
|
||||||
|
ConstructorInitializerAllOnOneLineOrOnePerLine: false
|
||||||
|
ConstructorInitializerIndentWidth: 4
|
||||||
|
ContinuationIndentWidth: 4
|
||||||
|
Cpp11BracedListStyle: true
|
||||||
|
DerivePointerAlignment: false
|
||||||
|
DisableFormat: false
|
||||||
|
ExperimentalAutoDetectBinPacking: false
|
||||||
|
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
|
||||||
|
IncludeCategories:
|
||||||
|
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
|
||||||
|
Priority: 2
|
||||||
|
- Regex: '^(<|"(gtest|isl|json)/)'
|
||||||
|
Priority: 3
|
||||||
|
- Regex: '.*'
|
||||||
|
Priority: 1
|
||||||
|
IndentCaseLabels: false
|
||||||
|
IndentWidth: 4
|
||||||
|
IndentWrappedFunctionNames: false
|
||||||
|
KeepEmptyLinesAtTheStartOfBlocks: true
|
||||||
|
MacroBlockBegin: ''
|
||||||
|
MacroBlockEnd: ''
|
||||||
|
MaxEmptyLinesToKeep: 1
|
||||||
|
NamespaceIndentation: None
|
||||||
|
ObjCBlockIndentWidth: 2
|
||||||
|
ObjCSpaceAfterProperty: false
|
||||||
|
ObjCSpaceBeforeProtocolList: true
|
||||||
|
PenaltyBreakBeforeFirstCallParameter: 19
|
||||||
|
PenaltyBreakComment: 300
|
||||||
|
PenaltyBreakFirstLessLess: 120
|
||||||
|
PenaltyBreakString: 1000
|
||||||
|
PenaltyExcessCharacter: 1000000
|
||||||
|
PenaltyReturnTypeOnItsOwnLine: 60
|
||||||
|
PointerAlignment: Right
|
||||||
|
ReflowComments: true
|
||||||
|
SortIncludes: true
|
||||||
|
SpaceAfterCStyleCast: false
|
||||||
|
SpaceBeforeAssignmentOperators: true
|
||||||
|
SpaceBeforeParens: ControlStatements
|
||||||
|
SpaceInEmptyParentheses: false
|
||||||
|
SpacesBeforeTrailingComments: 1
|
||||||
|
SpacesInAngles: false
|
||||||
|
SpacesInContainerLiterals: true
|
||||||
|
SpacesInCStyleCastParentheses: false
|
||||||
|
SpacesInParentheses: false
|
||||||
|
SpacesInSquareBrackets: false
|
||||||
|
Standard: Cpp11
|
||||||
|
TabWidth: 8
|
||||||
|
UseTab: Never
|
||||||
|
...
|
||||||
|
|
|
@ -0,0 +1,311 @@
|
||||||
|
{
|
||||||
|
"_help_parse": "Options affecting listfile parsing",
|
||||||
|
"parse": {
|
||||||
|
"_help_additional_commands": [
|
||||||
|
"Specify structure for custom cmake functions"
|
||||||
|
],
|
||||||
|
"additional_commands": {
|
||||||
|
"foo": {
|
||||||
|
"flags": [
|
||||||
|
"BAR",
|
||||||
|
"BAZ"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"HEADERS": "*",
|
||||||
|
"SOURCES": "*",
|
||||||
|
"DEPENDS": "*"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"_help_override_spec": [
|
||||||
|
"Override configurations per-command where available"
|
||||||
|
],
|
||||||
|
"override_spec": {},
|
||||||
|
"_help_vartags": [
|
||||||
|
"Specify variable tags."
|
||||||
|
],
|
||||||
|
"vartags": [],
|
||||||
|
"_help_proptags": [
|
||||||
|
"Specify property tags."
|
||||||
|
],
|
||||||
|
"proptags": []
|
||||||
|
},
|
||||||
|
"_help_format": "Options affecting formatting.",
|
||||||
|
"format": {
|
||||||
|
"_help_disable": [
|
||||||
|
"Disable formatting entirely, making cmake-format a no-op"
|
||||||
|
],
|
||||||
|
"disable": false,
|
||||||
|
"_help_line_width": [
|
||||||
|
"How wide to allow formatted cmake files"
|
||||||
|
],
|
||||||
|
"line_width": 1500,
|
||||||
|
"_help_tab_size": [
|
||||||
|
"How many spaces to tab for indent"
|
||||||
|
],
|
||||||
|
"tab_size": 2,
|
||||||
|
"_help_use_tabchars": [
|
||||||
|
"If true, lines are indented using tab characters (utf-8",
|
||||||
|
"0x09) instead of <tab_size> space characters (utf-8 0x20).",
|
||||||
|
"In cases where the layout would require a fractional tab",
|
||||||
|
"character, the behavior of the fractional indentation is",
|
||||||
|
"governed by <fractional_tab_policy>"
|
||||||
|
],
|
||||||
|
"use_tabchars": false,
|
||||||
|
"_help_fractional_tab_policy": [
|
||||||
|
"If <use_tabchars> is True, then the value of this variable",
|
||||||
|
"indicates how fractional indentions are handled during",
|
||||||
|
"whitespace replacement. If set to 'use-space', fractional",
|
||||||
|
"indentation is left as spaces (utf-8 0x20). If set to",
|
||||||
|
"`round-up` fractional indentation is replaced with a single",
|
||||||
|
"tab character (utf-8 0x09) effectively shifting the column",
|
||||||
|
"to the next tabstop"
|
||||||
|
],
|
||||||
|
"fractional_tab_policy": "use-space",
|
||||||
|
"_help_max_subgroups_hwrap": [
|
||||||
|
"If an argument group contains more than this many sub-groups",
|
||||||
|
"(parg or kwarg groups) then force it to a vertical layout."
|
||||||
|
],
|
||||||
|
"max_subgroups_hwrap": 2,
|
||||||
|
"_help_max_pargs_hwrap": [
|
||||||
|
"If a positional argument group contains more than this many",
|
||||||
|
"arguments, then force it to a vertical layout."
|
||||||
|
],
|
||||||
|
"max_pargs_hwrap": 6,
|
||||||
|
"_help_max_rows_cmdline": [
|
||||||
|
"If a cmdline positional group consumes more than this many",
|
||||||
|
"lines without nesting, then invalidate the layout (and nest)"
|
||||||
|
],
|
||||||
|
"max_rows_cmdline": 2,
|
||||||
|
"_help_separate_ctrl_name_with_space": [
|
||||||
|
"If true, separate flow control names from their parentheses",
|
||||||
|
"with a space"
|
||||||
|
],
|
||||||
|
"separate_ctrl_name_with_space": false,
|
||||||
|
"_help_separate_fn_name_with_space": [
|
||||||
|
"If true, separate function names from parentheses with a",
|
||||||
|
"space"
|
||||||
|
],
|
||||||
|
"separate_fn_name_with_space": false,
|
||||||
|
"_help_dangle_parens": [
|
||||||
|
"If a statement is wrapped to more than one line, than dangle",
|
||||||
|
"the closing parenthesis on its own line."
|
||||||
|
],
|
||||||
|
"dangle_parens": false,
|
||||||
|
"_help_dangle_align": [
|
||||||
|
"If the trailing parenthesis must be 'dangled' on its on",
|
||||||
|
"line, then align it to this reference: `prefix`: the start",
|
||||||
|
"of the statement, `prefix-indent`: the start of the",
|
||||||
|
"statement, plus one indentation level, `child`: align to",
|
||||||
|
"the column of the arguments"
|
||||||
|
],
|
||||||
|
"dangle_align": "prefix",
|
||||||
|
"_help_min_prefix_chars": [
|
||||||
|
"If the statement spelling length (including space and",
|
||||||
|
"parenthesis) is smaller than this amount, then force reject",
|
||||||
|
"nested layouts."
|
||||||
|
],
|
||||||
|
"min_prefix_chars": 4,
|
||||||
|
"_help_max_prefix_chars": [
|
||||||
|
"If the statement spelling length (including space and",
|
||||||
|
"parenthesis) is larger than the tab width by more than this",
|
||||||
|
"amount, then force reject un-nested layouts."
|
||||||
|
],
|
||||||
|
"max_prefix_chars": 10,
|
||||||
|
"_help_max_lines_hwrap": [
|
||||||
|
"If a candidate layout is wrapped horizontally but it exceeds",
|
||||||
|
"this many lines, then reject the layout."
|
||||||
|
],
|
||||||
|
"max_lines_hwrap": 2,
|
||||||
|
"_help_line_ending": [
|
||||||
|
"What style line endings to use in the output."
|
||||||
|
],
|
||||||
|
"line_ending": "unix",
|
||||||
|
"_help_command_case": [
|
||||||
|
"Format command names consistently as 'lower' or 'upper' case"
|
||||||
|
],
|
||||||
|
"command_case": "canonical",
|
||||||
|
"_help_keyword_case": [
|
||||||
|
"Format keywords consistently as 'lower' or 'upper' case"
|
||||||
|
],
|
||||||
|
"keyword_case": "unchanged",
|
||||||
|
"_help_always_wrap": [
|
||||||
|
"A list of command names which should always be wrapped"
|
||||||
|
],
|
||||||
|
"always_wrap": [],
|
||||||
|
"_help_enable_sort": [
|
||||||
|
"If true, the argument lists which are known to be sortable",
|
||||||
|
"will be sorted lexicographicall"
|
||||||
|
],
|
||||||
|
"enable_sort": true,
|
||||||
|
"_help_autosort": [
|
||||||
|
"If true, the parsers may infer whether or not an argument",
|
||||||
|
"list is sortable (without annotation)."
|
||||||
|
],
|
||||||
|
"autosort": false,
|
||||||
|
"_help_require_valid_layout": [
|
||||||
|
"By default, if cmake-format cannot successfully fit",
|
||||||
|
"everything into the desired linewidth it will apply the",
|
||||||
|
"last, most aggressive attempt that it made. If this flag is",
|
||||||
|
"True, however, cmake-format will print error, exit with non-",
|
||||||
|
"zero status code, and write-out nothing"
|
||||||
|
],
|
||||||
|
"require_valid_layout": false,
|
||||||
|
"_help_layout_passes": [
|
||||||
|
"A dictionary mapping layout nodes to a list of wrap",
|
||||||
|
"decisions. See the documentation for more information."
|
||||||
|
],
|
||||||
|
"layout_passes": {}
|
||||||
|
},
|
||||||
|
"_help_markup": "Options affecting comment reflow and formatting.",
|
||||||
|
"markup": {
|
||||||
|
"_help_bullet_char": [
|
||||||
|
"What character to use for bulleted lists"
|
||||||
|
],
|
||||||
|
"bullet_char": "*",
|
||||||
|
"_help_enum_char": [
|
||||||
|
"What character to use as punctuation after numerals in an",
|
||||||
|
"enumerated list"
|
||||||
|
],
|
||||||
|
"enum_char": ".",
|
||||||
|
"_help_first_comment_is_literal": [
|
||||||
|
"If comment markup is enabled, don't reflow the first comment",
|
||||||
|
"block in each listfile. Use this to preserve formatting of",
|
||||||
|
"your copyright/license statements."
|
||||||
|
],
|
||||||
|
"first_comment_is_literal": false,
|
||||||
|
"_help_literal_comment_pattern": [
|
||||||
|
"If comment markup is enabled, don't reflow any comment block",
|
||||||
|
"which matches this (regex) pattern. Default is `None`",
|
||||||
|
"(disabled)."
|
||||||
|
],
|
||||||
|
"literal_comment_pattern": ".*",
|
||||||
|
"_help_fence_pattern": [
|
||||||
|
"Regular expression to match preformat fences in comments",
|
||||||
|
"default= ``r'^\\s*([`~]{3}[`~]*)(.*)$'``"
|
||||||
|
],
|
||||||
|
"fence_pattern": "^\\s*([`~]{3}[`~]*)(.*)$",
|
||||||
|
"_help_ruler_pattern": [
|
||||||
|
"Regular expression to match rulers in comments default=",
|
||||||
|
"``r'^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$'``"
|
||||||
|
],
|
||||||
|
"ruler_pattern": "^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$",
|
||||||
|
"_help_explicit_trailing_pattern": [
|
||||||
|
"If a comment line matches starts with this pattern then it",
|
||||||
|
"is explicitly a trailing comment for the preceding argument.",
|
||||||
|
"Default is '#<'"
|
||||||
|
],
|
||||||
|
"explicit_trailing_pattern": "#<",
|
||||||
|
"_help_hashruler_min_length": [
|
||||||
|
"If a comment line starts with at least this many consecutive",
|
||||||
|
"hash characters, then don't lstrip() them off. This allows",
|
||||||
|
"for lazy hash rulers where the first hash char is not",
|
||||||
|
"separated by space"
|
||||||
|
],
|
||||||
|
"hashruler_min_length": 10,
|
||||||
|
"_help_canonicalize_hashrulers": [
|
||||||
|
"If true, then insert a space between the first hash char and",
|
||||||
|
"remaining hash chars in a hash ruler, and normalize its",
|
||||||
|
"length to fill the column"
|
||||||
|
],
|
||||||
|
"canonicalize_hashrulers": true,
|
||||||
|
"_help_enable_markup": [
|
||||||
|
"enable comment markup parsing and reflow"
|
||||||
|
],
|
||||||
|
"enable_markup": true
|
||||||
|
},
|
||||||
|
"_help_lint": "Options affecting the linter",
|
||||||
|
"lint": {
|
||||||
|
"_help_disabled_codes": [
|
||||||
|
"a list of lint codes to disable"
|
||||||
|
],
|
||||||
|
"disabled_codes": [],
|
||||||
|
"_help_function_pattern": [
|
||||||
|
"regular expression pattern describing valid function names"
|
||||||
|
],
|
||||||
|
"function_pattern": "[0-9a-z_]+",
|
||||||
|
"_help_macro_pattern": [
|
||||||
|
"regular expression pattern describing valid macro names"
|
||||||
|
],
|
||||||
|
"macro_pattern": "[0-9A-Z_]+",
|
||||||
|
"_help_global_var_pattern": [
|
||||||
|
"regular expression pattern describing valid names for",
|
||||||
|
"variables with global (cache) scope"
|
||||||
|
],
|
||||||
|
"global_var_pattern": "[A-Z][0-9A-Z_]+",
|
||||||
|
"_help_internal_var_pattern": [
|
||||||
|
"regular expression pattern describing valid names for",
|
||||||
|
"variables with global scope (but internal semantic)"
|
||||||
|
],
|
||||||
|
"internal_var_pattern": "_[A-Z][0-9A-Z_]+",
|
||||||
|
"_help_local_var_pattern": [
|
||||||
|
"regular expression pattern describing valid names for",
|
||||||
|
"variables with local scope"
|
||||||
|
],
|
||||||
|
"local_var_pattern": "[a-z][a-z0-9_]+",
|
||||||
|
"_help_private_var_pattern": [
|
||||||
|
"regular expression pattern describing valid names for",
|
||||||
|
"privatedirectory variables"
|
||||||
|
],
|
||||||
|
"private_var_pattern": "_[0-9a-z_]+",
|
||||||
|
"_help_public_var_pattern": [
|
||||||
|
"regular expression pattern describing valid names for public",
|
||||||
|
"directory variables"
|
||||||
|
],
|
||||||
|
"public_var_pattern": "[A-Z][0-9A-Z_]+",
|
||||||
|
"_help_argument_var_pattern": [
|
||||||
|
"regular expression pattern describing valid names for",
|
||||||
|
"function/macro arguments and loop variables."
|
||||||
|
],
|
||||||
|
"argument_var_pattern": "[a-z][a-z0-9_]+",
|
||||||
|
"_help_keyword_pattern": [
|
||||||
|
"regular expression pattern describing valid names for",
|
||||||
|
"keywords used in functions or macros"
|
||||||
|
],
|
||||||
|
"keyword_pattern": "[A-Z][0-9A-Z_]+",
|
||||||
|
"_help_max_conditionals_custom_parser": [
|
||||||
|
"In the heuristic for C0201, how many conditionals to match",
|
||||||
|
"within a loop in before considering the loop a parser."
|
||||||
|
],
|
||||||
|
"max_conditionals_custom_parser": 2,
|
||||||
|
"_help_min_statement_spacing": [
|
||||||
|
"Require at least this many newlines between statements"
|
||||||
|
],
|
||||||
|
"min_statement_spacing": 1,
|
||||||
|
"_help_max_statement_spacing": [
|
||||||
|
"Require no more than this many newlines between statements"
|
||||||
|
],
|
||||||
|
"max_statement_spacing": 2,
|
||||||
|
"max_returns": 6,
|
||||||
|
"max_branches": 12,
|
||||||
|
"max_arguments": 5,
|
||||||
|
"max_localvars": 15,
|
||||||
|
"max_statements": 50
|
||||||
|
},
|
||||||
|
"_help_encode": "Options affecting file encoding",
|
||||||
|
"encode": {
|
||||||
|
"_help_emit_byteorder_mark": [
|
||||||
|
"If true, emit the unicode byte-order mark (BOM) at the start",
|
||||||
|
"of the file"
|
||||||
|
],
|
||||||
|
"emit_byteorder_mark": false,
|
||||||
|
"_help_input_encoding": [
|
||||||
|
"Specify the encoding of the input file. Defaults to utf-8"
|
||||||
|
],
|
||||||
|
"input_encoding": "utf-8",
|
||||||
|
"_help_output_encoding": [
|
||||||
|
"Specify the encoding of the output file. Defaults to utf-8.",
|
||||||
|
"Note that cmake only claims to support utf-8 so be careful",
|
||||||
|
"when using anything else"
|
||||||
|
],
|
||||||
|
"output_encoding": "utf-8"
|
||||||
|
},
|
||||||
|
"_help_misc": "Miscellaneous configurations options.",
|
||||||
|
"misc": {
|
||||||
|
"_help_per_command": [
|
||||||
|
"A dictionary containing any per-command configuration",
|
||||||
|
"overrides. Currently only `command_case` is supported."
|
||||||
|
],
|
||||||
|
"per_command": {}
|
||||||
|
}
|
||||||
|
}
|
|
@ -30,3 +30,8 @@
|
||||||
*.exe
|
*.exe
|
||||||
*.out
|
*.out
|
||||||
*.app
|
*.app
|
||||||
|
|
||||||
|
build/
|
||||||
|
build_debug/
|
||||||
|
|
||||||
|
.vscode/
|
|
@ -0,0 +1,9 @@
|
||||||
|
[submodule "3rd-party/pybind11"]
|
||||||
|
path = 3rd-party/pybind11
|
||||||
|
url = git@github.com:pybind/pybind11.git
|
||||||
|
[submodule "3rd-party/nlohmann_json_cmake_fetchcontent"]
|
||||||
|
path = 3rd-party/nlohmann_json_cmake_fetchcontent
|
||||||
|
url = git@github.com:ArthurSonzogni/nlohmann_json_cmake_fetchcontent.git
|
||||||
|
[submodule "3rd-party/googletest"]
|
||||||
|
path = 3rd-party/googletest
|
||||||
|
url = git@github.com:google/googletest.git
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit e2239ee6043f73722e7aa812a459f54a28552929
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 6aebf09233951e4ce30a63919186a70b2b195756
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 1e3400b6742288429f2069aaf5febf92d0662dae
|
|
@ -0,0 +1,73 @@
|
||||||
|
# TODO: check the minimum cmake version
|
||||||
|
cmake_minimum_required(VERSION 3.9) # Required by find_package(OpenMP)
|
||||||
|
include(CMakeDependentOption)
|
||||||
|
project(InfiniTensor C CXX)
|
||||||
|
|
||||||
|
# Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them.
|
||||||
|
option(BUILD_TEST "Build tests" OFF)
|
||||||
|
cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
|
||||||
|
cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" ON BUILD_TEST OFF)
|
||||||
|
cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" ON BUILD_TEST OFF)
|
||||||
|
|
||||||
|
set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
|
||||||
|
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wno-error=deprecated-declarations")
|
||||||
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
|
||||||
|
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
|
||||||
|
|
||||||
|
find_package(
|
||||||
|
Python
|
||||||
|
COMPONENTS Interpreter Development
|
||||||
|
REQUIRED)
|
||||||
|
find_package(CUDA REQUIRED)
|
||||||
|
# OpenMP
|
||||||
|
find_package(OpenMP)
|
||||||
|
if(OpenMP_C_FOUND)
|
||||||
|
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
||||||
|
endif()
|
||||||
|
if(OpenMP_CXX_FOUND)
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# add_subdirectory(3rd-party/pybind11)
|
||||||
|
# add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent)
|
||||||
|
|
||||||
|
include_directories(include)
|
||||||
|
# include_directories(3rd-party/pybind11/include)
|
||||||
|
# include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include)
|
||||||
|
|
||||||
|
if(BUILD_TEST)
|
||||||
|
set(BUILD_GMOCK off)
|
||||||
|
set(INSTALL_GTEST off)
|
||||||
|
add_subdirectory(3rd-party/googletest)
|
||||||
|
include_directories(3rd-party/googletest/googletest/include)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
file(GLOB_RECURSE SRC src/*.cc src/*.cu)
|
||||||
|
# file(GLOB_RECURSE TEST test/*.cc)
|
||||||
|
# file(GLOB_RECURSE FFI src/ffi/ffi_pet.cc)
|
||||||
|
# list(REMOVE_ITEM SRC ${TEST} ${FFI})
|
||||||
|
|
||||||
|
add_library(InfiniTensor SHARED ${SRC})
|
||||||
|
# Target
|
||||||
|
# cuda_add_library(it SHARED ${SRC})
|
||||||
|
# cuda_add_cublas_to_target(it) # cublas
|
||||||
|
# # target_link_libraries(infini_cpp cudnn curand nlohmann_json::nlohmann_json pybind11::embed)
|
||||||
|
|
||||||
|
# # Python bindings
|
||||||
|
# pybind11_add_module(infini MODULE ${FFI})
|
||||||
|
# target_link_libraries(infini PRIVATE infini_cpp)
|
||||||
|
|
||||||
|
if(BUILD_TEST)
|
||||||
|
enable_testing()
|
||||||
|
# Build all tests file( GLOB TEST_SOURCES test/test_sg2bmm.cc )
|
||||||
|
file(GLOB_RECURSE TEST_SOURCES test/*.cc)
|
||||||
|
foreach(testsourcefile ${TEST_SOURCES})
|
||||||
|
get_filename_component(testname ${testsourcefile} NAME_WE)
|
||||||
|
add_executable(${testname} ${testsourcefile})
|
||||||
|
target_link_libraries(${testname} InfiniTensor gtest_main gtest)
|
||||||
|
add_test(NAME ${testname} COMMAND ${testname})
|
||||||
|
endforeach(testsourcefile ${TEST_SOURCES})
|
||||||
|
endif()
|
|
@ -0,0 +1,32 @@
|
||||||
|
#pragma once
|
||||||
|
#include <cassert>
|
||||||
|
#include <iostream>
|
||||||
|
#include <list>
|
||||||
|
#include <map>
|
||||||
|
#include <optional>
|
||||||
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace it {
|
||||||
|
using std::list;
|
||||||
|
using std::map;
|
||||||
|
using std::nullopt;
|
||||||
|
using std::optional;
|
||||||
|
using std::pair;
|
||||||
|
using std::set;
|
||||||
|
using std::string;
|
||||||
|
using std::tie;
|
||||||
|
using std::to_string;
|
||||||
|
using std::tuple;
|
||||||
|
using std::unordered_map;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
// Aliases
|
||||||
|
using dtype = float;
|
||||||
|
|
||||||
|
} // namespace it
|
|
@ -0,0 +1,36 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
#include "core/tensor.h"
|
||||||
|
|
||||||
|
namespace it {
|
||||||
|
|
||||||
|
class GraphNode : public Object {
|
||||||
|
protected:
|
||||||
|
TensorVec tensors;
|
||||||
|
TensorVec inputs;
|
||||||
|
TensorVec outputs;
|
||||||
|
OpVec ops;
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Graph(OpVec oplist);
|
||||||
|
string toString() const override;
|
||||||
|
|
||||||
|
void addOp(Operator op) { ops.push_back(op); };
|
||||||
|
const TensorVec &getTensors() const { return tensors; }
|
||||||
|
const TensorVec &getInputs() const { return inputs; }
|
||||||
|
const TensorVec &getOutputs() const { return outputs; }
|
||||||
|
const OpVec &getOperators() const { return ops; }
|
||||||
|
// TensorVec &getInputs();
|
||||||
|
// TensorVec &getOutputs();
|
||||||
|
|
||||||
|
void updateConnection();
|
||||||
|
|
||||||
|
// TODO
|
||||||
|
// bool compute();
|
||||||
|
|
||||||
|
// TODO: move to another class
|
||||||
|
// bool exportOnnx(const char *path);
|
||||||
|
// bool importOnnx(const char *net);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace it
|
|
@ -0,0 +1,45 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/common.h"
|
||||||
|
|
||||||
|
namespace it {
|
||||||
|
|
||||||
|
using GuidBaseType = int;
|
||||||
|
|
||||||
|
class Guid {
|
||||||
|
private:
|
||||||
|
GuidBaseType guid;
|
||||||
|
|
||||||
|
private:
|
||||||
|
GuidBaseType generateGuid() {
|
||||||
|
static GuidBaseType guidCnt = 0;
|
||||||
|
return ++guidCnt;
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
Guid() { guid = generateGuid(); }
|
||||||
|
Guid(const Guid &rhs) { guid = generateGuid(); }
|
||||||
|
Guid &operator=(const Guid &rhs) {
|
||||||
|
guid = generateGuid();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
operator GuidBaseType() const { return guid; }
|
||||||
|
};
|
||||||
|
|
||||||
|
class Object {
|
||||||
|
protected:
|
||||||
|
Guid guid;
|
||||||
|
|
||||||
|
public:
|
||||||
|
virtual ~Object(){};
|
||||||
|
virtual string toString() const = 0;
|
||||||
|
void print() { std::cout << toString() << std::endl; }
|
||||||
|
Guid getGuid() const { return guid; }
|
||||||
|
};
|
||||||
|
|
||||||
|
inline std::ostream &operator<<(std::ostream &os, const Object &obj) {
|
||||||
|
os << obj.toString();
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/tensor.h"
|
||||||
|
|
||||||
|
namespace it {
|
||||||
|
|
||||||
|
class OperatorNode : public Object {
|
||||||
|
protected:
|
||||||
|
// OpType type;
|
||||||
|
TensorVec inputs;
|
||||||
|
TensorVec outputs;
|
||||||
|
// vector<WRef<Operator>> predecessors;
|
||||||
|
// vector<WRef<Operator>> successors;
|
||||||
|
public:
|
||||||
|
OperatorNode(TensorVec inputs, TensorVec outputs)
|
||||||
|
: inputs(inputs), outputs(outputs) {}
|
||||||
|
string toString() const override;
|
||||||
|
// Operator(TensorVec inputs) : inputs(inputs) {}
|
||||||
|
|
||||||
|
virtual ~OperatorNode() {}
|
||||||
|
};
|
||||||
|
} // namespace it
|
|
@ -0,0 +1,27 @@
|
||||||
|
#pragma once
|
||||||
|
#include "common.h"
|
||||||
|
#include <functional> // hash
|
||||||
|
#include <memory>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
namespace it {
|
||||||
|
|
||||||
|
template <typename T> using Ref = std::shared_ptr<T>;
|
||||||
|
template <typename T> using WRef = std::weak_ptr<T>;
|
||||||
|
|
||||||
|
template <typename T> struct is_ref : std::false_type {};
|
||||||
|
template <typename T> struct is_ref<Ref<T>> : std::true_type {};
|
||||||
|
template <typename T> struct is_ref<WRef<T>> : std::true_type {};
|
||||||
|
|
||||||
|
template <typename T, typename... Params> Ref<T> make_ref(Params &&...params) {
|
||||||
|
static_assert(is_ref<T>::value == false, "Ref should not be nested");
|
||||||
|
return std::make_shared<T>(std::forward<Params>(params)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T, class U,
|
||||||
|
typename std::enable_if_t<std::is_base_of_v<U, T>> * = nullptr>
|
||||||
|
Ref<T> as(const Ref<U> &ref) {
|
||||||
|
return std::dynamic_pointer_cast<T>(ref);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace it
|
|
@ -0,0 +1,368 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/object.h"
|
||||||
|
#include "core/ref.h"
|
||||||
|
|
||||||
|
namespace it {
|
||||||
|
|
||||||
|
// class Tensor;
|
||||||
|
class TensorBaseNode;
|
||||||
|
class OperatorNode;
|
||||||
|
class GraphNode;
|
||||||
|
|
||||||
|
using TensorBase = Ref<TensorBaseNode>;
|
||||||
|
using Operator = Ref<OperatorNode>;
|
||||||
|
using Graph = Ref<GraphNode>;
|
||||||
|
|
||||||
|
using TensorVec = vector<TensorBase>;
|
||||||
|
using OpVec = vector<Operator>;
|
||||||
|
|
||||||
|
// using TensorMap = std::map<size_t, Tensor *>;
|
||||||
|
// using OpMap = std::map<size_t, Operator *>;
|
||||||
|
using VType = uint32_t;
|
||||||
|
// using SplittingPoints = std::vector<std::vector<int>>;
|
||||||
|
|
||||||
|
class TensorBaseNode : public Object {
|
||||||
|
public:
|
||||||
|
enum DataType {
|
||||||
|
Float32,
|
||||||
|
Int32,
|
||||||
|
};
|
||||||
|
|
||||||
|
// enum TensorType {
|
||||||
|
// Input,
|
||||||
|
// Weight,
|
||||||
|
// Invalid,
|
||||||
|
// NotCounted,
|
||||||
|
// };
|
||||||
|
|
||||||
|
// // TODO: is more compute state needed?
|
||||||
|
// enum ComputeState {
|
||||||
|
// NotComputed,
|
||||||
|
// // Allocated,
|
||||||
|
// // Initialized,
|
||||||
|
// // ComputedPartial,
|
||||||
|
// ComputedFull,
|
||||||
|
// };
|
||||||
|
|
||||||
|
private:
|
||||||
|
int hid;
|
||||||
|
// uint64_t hash;
|
||||||
|
// Shape shape;
|
||||||
|
int dim;
|
||||||
|
|
||||||
|
vector<WRef<TensorBase>> inputOf;
|
||||||
|
WRef<TensorBase> outputOf;
|
||||||
|
Ref<VType> data;
|
||||||
|
DataType dtype;
|
||||||
|
// ComputeState computed;
|
||||||
|
// static int random_seed[256 * 16];
|
||||||
|
// static bool random_inited;
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Tensor(TensorType type = Input, DataType dtype = Float32)
|
||||||
|
// : guid(generateGuid()), hash(generateHash()), outputOf(nullptr),
|
||||||
|
// data(nullptr), dtype(dtype), type(type), computed(NotComputed) {}
|
||||||
|
// Tensor(const Dim &dims, TensorType type = Input, DataType dtype =
|
||||||
|
// Float32)
|
||||||
|
// : guid(generateGuid()), hash(generateHash()), dims(dims),
|
||||||
|
// outputOf(nullptr), data(nullptr), dtype(dtype), type(type),
|
||||||
|
// computed(NotComputed) {
|
||||||
|
// itInit();
|
||||||
|
// }
|
||||||
|
// Tensor(const Tensor &rhs) : Tensor(rhs.dims, rhs.type, rhs.dtype) {
|
||||||
|
// outputOf = nullptr;
|
||||||
|
// data = nullptr;
|
||||||
|
// hash = rhs.hash;
|
||||||
|
// dimPenalty = rhs.dimPenalty;
|
||||||
|
// itInit();
|
||||||
|
// }
|
||||||
|
// Tensor(VType scalar, TensorType type = Weight, DataType dtype = Float32)
|
||||||
|
// : guid(generateGuid()), hash(generateHash()), outputOf(nullptr),
|
||||||
|
// data(nullptr), dtype(dtype), type(type), computed(ComputedFull) {
|
||||||
|
// assert(size() == 1);
|
||||||
|
// dataMalloc();
|
||||||
|
// data[0] = scalar;
|
||||||
|
// }
|
||||||
|
virtual ~TensorBaseNode() {}
|
||||||
|
string toString() const override;
|
||||||
|
|
||||||
|
// // inputOf and outputOf will not be cloned
|
||||||
|
// Tensor *clone() {
|
||||||
|
// Tensor *t = new Tensor(*this);
|
||||||
|
// return t;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// void clone(Tensor *t) {
|
||||||
|
// dims = t->dims;
|
||||||
|
// dtype = t->dtype;
|
||||||
|
// type = t->type;
|
||||||
|
// hash = t->hash;
|
||||||
|
// dimPenalty = t->dimPenalty;
|
||||||
|
// }
|
||||||
|
|
||||||
|
DataType getDType() const { return dtype; }
|
||||||
|
|
||||||
|
// uint64_t getHash() const { return hash; }
|
||||||
|
|
||||||
|
// void setInputOf(const OpVec &ops) {
|
||||||
|
// inputOf.clear();
|
||||||
|
// for (const auto &op : ops)
|
||||||
|
// inputOf.emplace_back(op);
|
||||||
|
// }
|
||||||
|
// void addInputOf(Operator op) { inputOf.emplace_back(op); }
|
||||||
|
// void setOutputOf(Operator op) { outputOf = op; }
|
||||||
|
|
||||||
|
// const OpVec &getInputOf() { return inputOf; }
|
||||||
|
// Operator *getOutputOf() { return outputOf; }
|
||||||
|
// std::pair<Operator *, int> getOutputOfWithIndex();
|
||||||
|
|
||||||
|
// bool dataMalloc() {
|
||||||
|
// if (data == nullptr)
|
||||||
|
// data = new VType[size()];
|
||||||
|
// return data != nullptr;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const Dim &getDims() const { return dims; }
|
||||||
|
// void setDims(const Dim &dms) { dims = dms; }
|
||||||
|
|
||||||
|
// bool dataRand(int seed = 0) {
|
||||||
|
// if (data == nullptr)
|
||||||
|
// data = new VType[size()];
|
||||||
|
// if (!random_inited)
|
||||||
|
// initFastrand();
|
||||||
|
// // srand(seed);
|
||||||
|
// // faster rand generator; parallel
|
||||||
|
// size_t iEnd = size();
|
||||||
|
// // std::cerr << "Init beginned " << std::endl;
|
||||||
|
// #pragma omp parallel for
|
||||||
|
// for (size_t i = 0; i < iEnd; ++i)
|
||||||
|
// data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) %
|
||||||
|
// 10000;
|
||||||
|
// // std::cerr << "Init finished" << std::endl;
|
||||||
|
// computed = ComputedFull;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool setData(VType *dptr) {
|
||||||
|
// if (dptr == nullptr)
|
||||||
|
// return false;
|
||||||
|
// auto sz = size();
|
||||||
|
// #pragma omp parallel for
|
||||||
|
// for (size_t i = 0; i < sz; ++i)
|
||||||
|
// data[i] = dptr[i];
|
||||||
|
// computed = ComputedFull;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool setScalar(VType val) {
|
||||||
|
// if (data == nullptr || !dims.empty())
|
||||||
|
// return false;
|
||||||
|
// data[0] = val;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool setData(const Dim &ds, VType val) {
|
||||||
|
// if (data == nullptr || ds.size() != dims.size())
|
||||||
|
// return false;
|
||||||
|
// data[getOffset(ds)] = val;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool setData(size_t pos, VType val) {
|
||||||
|
// if (data == nullptr || pos >= size())
|
||||||
|
// return false;
|
||||||
|
// data[pos] = val;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// VType getScalar() { return data == nullptr ? 0 : data[0]; }
|
||||||
|
|
||||||
|
// VType getData(const Dim &ds) {
|
||||||
|
// assert(data != nullptr);
|
||||||
|
// auto offset = getOffset(ds);
|
||||||
|
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
|
||||||
|
// }
|
||||||
|
|
||||||
|
// VType getData(size_t pos) {
|
||||||
|
// assert(data != nullptr);
|
||||||
|
// assert(pos < size());
|
||||||
|
// return data[pos];
|
||||||
|
// }
|
||||||
|
|
||||||
|
// VType *getDataPtr() const { return data; }
|
||||||
|
|
||||||
|
// size_t getOffset(const Dim &ds) {
|
||||||
|
// auto nDim = ds.size();
|
||||||
|
// assert(dims.size() == nDim);
|
||||||
|
// if (ds.empty())
|
||||||
|
// return 0;
|
||||||
|
// for (size_t i = 0; i < nDim; ++i)
|
||||||
|
// if (ds[i] < 0 || ds[i] >= dims[i])
|
||||||
|
// return (size_t)-1;
|
||||||
|
// size_t idx = ds[0];
|
||||||
|
// size_t dm = 0;
|
||||||
|
// while (++dm < nDim)
|
||||||
|
// idx = idx * dims[dm] + ds[dm];
|
||||||
|
// return idx;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// VType getBroadcastData(const Dim &ds) {
|
||||||
|
// assert(data != nullptr);
|
||||||
|
// auto offset = getBroadcastOffset(ds);
|
||||||
|
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
|
||||||
|
// }
|
||||||
|
|
||||||
|
// VType getBroadcastData(size_t pos) {
|
||||||
|
// assert(data != nullptr);
|
||||||
|
// return data[pos % size()];
|
||||||
|
// }
|
||||||
|
|
||||||
|
// size_t getBroadcastOffset(const Dim &ds) {
|
||||||
|
// assert(ds.size() >= dims.size());
|
||||||
|
// auto nDim = dims.size();
|
||||||
|
// auto nBroadcastDim = ds.size() - nDim;
|
||||||
|
// for (size_t i = 0; i < nDim; ++i)
|
||||||
|
// if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >=
|
||||||
|
// dims[i])
|
||||||
|
// return (size_t)-1;
|
||||||
|
// size_t idx = 0;
|
||||||
|
// for (size_t i = 0; i < nDim; ++i)
|
||||||
|
// idx = idx * dims[i] + ds[nBroadcastDim + i];
|
||||||
|
// return idx;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// void itInit() { it = Dim(dims.size(), 0); }
|
||||||
|
|
||||||
|
// void itReset() {
|
||||||
|
// itInit();
|
||||||
|
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
|
||||||
|
// it[i] = 0;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool itValid() {
|
||||||
|
// if (it.size() != dims.size())
|
||||||
|
// return false;
|
||||||
|
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
|
||||||
|
// if (it[i] >= dims[i])
|
||||||
|
// return false;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const Dim &itGet() { return it; }
|
||||||
|
|
||||||
|
// void itNext() {
|
||||||
|
// auto p = it.size() - 1;
|
||||||
|
// it[p] += 1;
|
||||||
|
// while (p >= 1) {
|
||||||
|
// if (it[p] == dims[p]) {
|
||||||
|
// it[p] = 0;
|
||||||
|
// it[--p] += 1;
|
||||||
|
// } else
|
||||||
|
// break;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// size_t size() const {
|
||||||
|
// size_t sz = 1;
|
||||||
|
// auto dm = dims.size();
|
||||||
|
// while (dm > 0)
|
||||||
|
// sz *= dims[--dm];
|
||||||
|
// return sz;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// TensorType getType() const { return type; }
|
||||||
|
// void setType(TensorType ty) { type = ty; }
|
||||||
|
|
||||||
|
// void print() {
|
||||||
|
// if (type == Invalid) {
|
||||||
|
// std::cout << "Invalid tensor" << std::endl;
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if (data == nullptr || dims.size() == 0) {
|
||||||
|
// std::cout << "Empty tensor" << std::endl;
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // TODO: can be uncommented after tensor's compute type is
|
||||||
|
// correctly set if (computed == NotComputed) {
|
||||||
|
// std::cout << "Uncomputed tensor" << std::endl;
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// std::cout << "Tensor: " << guid << std::endl;
|
||||||
|
// auto numDims = dims.size();
|
||||||
|
// auto dimSzVec = std::vector<int>(numDims, 1);
|
||||||
|
// dimSzVec[numDims - 1] = dims[numDims - 1];
|
||||||
|
// for (int i = numDims - 1; i != 0; --i)
|
||||||
|
// dimSzVec[i - 1] = dimSzVec[i] * dims[i - 1];
|
||||||
|
// for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
|
||||||
|
// for (size_t j = 0; j < numDims; ++j) {
|
||||||
|
// if (i % dimSzVec[j] == 0) {
|
||||||
|
// std::cout << "[";
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// std::cout << data[i];
|
||||||
|
// for (size_t j = 0; j < numDims; ++j) {
|
||||||
|
// if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) {
|
||||||
|
// std::cout << "]";
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// if (i != size() - 1)
|
||||||
|
// std::cout << ", ";
|
||||||
|
// if ((int)i % dimSzVec[numDims - 1] == dimSzVec[numDims - 1] -
|
||||||
|
// 1)
|
||||||
|
// std::cout << std::endl;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// static inline void initFastrand() {
|
||||||
|
// assert(omp_get_max_threads() <= 256);
|
||||||
|
// // srand(0); // constant seed for test
|
||||||
|
// // align random_seed to avoid false sharing
|
||||||
|
// for (int i = 0; i < 256 * 16; ++i) {
|
||||||
|
// // random_seed[i] = rand();
|
||||||
|
// // constant random seed for test
|
||||||
|
// random_seed[i] = i;
|
||||||
|
// }
|
||||||
|
// random_inited = true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// static inline int fastrand(int &g_seed) {
|
||||||
|
// g_seed = (214013 * g_seed + 2531011);
|
||||||
|
// return (g_seed >> 16) & 0x7FFF;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// std::vector<std::vector<int>> const *getSplittingPoints() const {
|
||||||
|
// assert(!splittingPoints.empty());
|
||||||
|
// return &splittingPoints;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// bool setSplittingPoints(std::vector<std::vector<int>> value) {
|
||||||
|
// assert(!value.empty());
|
||||||
|
// splittingPoints = value;
|
||||||
|
// return true;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// void printSplittingPoints() {
|
||||||
|
// if (splittingPoints.empty())
|
||||||
|
// printf("Empty SplittingPoints");
|
||||||
|
// else {
|
||||||
|
// printf("[");
|
||||||
|
// for (auto &vs : splittingPoints) {
|
||||||
|
// printf("[");
|
||||||
|
// for (auto v : vs)
|
||||||
|
// printf("%2d,", v);
|
||||||
|
// printf("],");
|
||||||
|
// }
|
||||||
|
// printf("]");
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// void initSplittingPoints() {
|
||||||
|
// splittingPoints.resize(getDims().size()); }
|
||||||
|
|
||||||
|
// void printShape();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace it
|
|
@ -0,0 +1,3 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/common.h"
|
||||||
|
#include "gtest/gtest.h"
|
|
@ -0,0 +1,15 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
|
||||||
|
namespace it {
|
||||||
|
|
||||||
|
void GraphNode::updateConnection() {
|
||||||
|
// TODO
|
||||||
|
}
|
||||||
|
|
||||||
|
string GraphNode::toString() const {
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "GraphNode: ";
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace it
|
|
@ -0,0 +1,11 @@
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace it {
|
||||||
|
|
||||||
|
string OperatorNode::toString() const {
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "Operator: ";
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace it
|
|
@ -0,0 +1,8 @@
|
||||||
|
#include <core/tensor.h>
|
||||||
|
namespace it {
|
||||||
|
|
||||||
|
string TensorBaseNode::toString() const {
|
||||||
|
return "TensorBaseNode " + std::to_string(guid);
|
||||||
|
}
|
||||||
|
|
||||||
|
}; // namespace it
|
|
@ -0,0 +1,12 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace it {
|
||||||
|
|
||||||
|
TEST(Graph, build) {
|
||||||
|
Graph g = make_ref<GraphNode>();
|
||||||
|
g->addOp(make_ref<OperatorNode>(TensorVec{}, TensorVec{}));
|
||||||
|
g->print();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace it
|
Loading…
Reference in New Issue