forked from jiuyuan/InfiniTensor
Add: graph, tensor, and operator
This commit is contained in:
parent
b89495a782
commit
e6101b0336
|
@ -0,0 +1,90 @@
|
|||
---
|
||||
Language: Cpp
|
||||
# BasedOnStyle: LLVM
|
||||
AccessModifierOffset: -2
|
||||
AlignAfterOpenBracket: Align
|
||||
AlignConsecutiveAssignments: false
|
||||
AlignConsecutiveDeclarations: false
|
||||
AlignEscapedNewlinesLeft: false
|
||||
AlignOperands: true
|
||||
AlignTrailingComments: true
|
||||
AllowAllParametersOfDeclarationOnNextLine: true
|
||||
AllowShortBlocksOnASingleLine: false
|
||||
AllowShortCaseLabelsOnASingleLine: false
|
||||
AllowShortFunctionsOnASingleLine: All
|
||||
AllowShortIfStatementsOnASingleLine: false
|
||||
AllowShortLoopsOnASingleLine: false
|
||||
AlwaysBreakAfterDefinitionReturnType: None
|
||||
AlwaysBreakAfterReturnType: None
|
||||
AlwaysBreakBeforeMultilineStrings: false
|
||||
AlwaysBreakTemplateDeclarations: false
|
||||
BinPackArguments: true
|
||||
BinPackParameters: true
|
||||
BraceWrapping:
|
||||
AfterClass: false
|
||||
AfterControlStatement: false
|
||||
AfterEnum: false
|
||||
AfterFunction: false
|
||||
AfterNamespace: false
|
||||
AfterObjCDeclaration: false
|
||||
AfterStruct: false
|
||||
AfterUnion: false
|
||||
BeforeCatch: false
|
||||
BeforeElse: false
|
||||
IndentBraces: false
|
||||
BreakBeforeBinaryOperators: None
|
||||
BreakBeforeBraces: Attach
|
||||
BreakBeforeTernaryOperators: true
|
||||
BreakConstructorInitializersBeforeComma: false
|
||||
ColumnLimit: 80
|
||||
CommentPragmas: '^ IWYU pragma:'
|
||||
ConstructorInitializerAllOnOneLineOrOnePerLine: false
|
||||
ConstructorInitializerIndentWidth: 4
|
||||
ContinuationIndentWidth: 4
|
||||
Cpp11BracedListStyle: true
|
||||
DerivePointerAlignment: false
|
||||
DisableFormat: false
|
||||
ExperimentalAutoDetectBinPacking: false
|
||||
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
|
||||
IncludeCategories:
|
||||
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
|
||||
Priority: 2
|
||||
- Regex: '^(<|"(gtest|isl|json)/)'
|
||||
Priority: 3
|
||||
- Regex: '.*'
|
||||
Priority: 1
|
||||
IndentCaseLabels: false
|
||||
IndentWidth: 4
|
||||
IndentWrappedFunctionNames: false
|
||||
KeepEmptyLinesAtTheStartOfBlocks: true
|
||||
MacroBlockBegin: ''
|
||||
MacroBlockEnd: ''
|
||||
MaxEmptyLinesToKeep: 1
|
||||
NamespaceIndentation: None
|
||||
ObjCBlockIndentWidth: 2
|
||||
ObjCSpaceAfterProperty: false
|
||||
ObjCSpaceBeforeProtocolList: true
|
||||
PenaltyBreakBeforeFirstCallParameter: 19
|
||||
PenaltyBreakComment: 300
|
||||
PenaltyBreakFirstLessLess: 120
|
||||
PenaltyBreakString: 1000
|
||||
PenaltyExcessCharacter: 1000000
|
||||
PenaltyReturnTypeOnItsOwnLine: 60
|
||||
PointerAlignment: Right
|
||||
ReflowComments: true
|
||||
SortIncludes: true
|
||||
SpaceAfterCStyleCast: false
|
||||
SpaceBeforeAssignmentOperators: true
|
||||
SpaceBeforeParens: ControlStatements
|
||||
SpaceInEmptyParentheses: false
|
||||
SpacesBeforeTrailingComments: 1
|
||||
SpacesInAngles: false
|
||||
SpacesInContainerLiterals: true
|
||||
SpacesInCStyleCastParentheses: false
|
||||
SpacesInParentheses: false
|
||||
SpacesInSquareBrackets: false
|
||||
Standard: Cpp11
|
||||
TabWidth: 8
|
||||
UseTab: Never
|
||||
...
|
||||
|
|
@ -0,0 +1,311 @@
|
|||
{
|
||||
"_help_parse": "Options affecting listfile parsing",
|
||||
"parse": {
|
||||
"_help_additional_commands": [
|
||||
"Specify structure for custom cmake functions"
|
||||
],
|
||||
"additional_commands": {
|
||||
"foo": {
|
||||
"flags": [
|
||||
"BAR",
|
||||
"BAZ"
|
||||
],
|
||||
"kwargs": {
|
||||
"HEADERS": "*",
|
||||
"SOURCES": "*",
|
||||
"DEPENDS": "*"
|
||||
}
|
||||
}
|
||||
},
|
||||
"_help_override_spec": [
|
||||
"Override configurations per-command where available"
|
||||
],
|
||||
"override_spec": {},
|
||||
"_help_vartags": [
|
||||
"Specify variable tags."
|
||||
],
|
||||
"vartags": [],
|
||||
"_help_proptags": [
|
||||
"Specify property tags."
|
||||
],
|
||||
"proptags": []
|
||||
},
|
||||
"_help_format": "Options affecting formatting.",
|
||||
"format": {
|
||||
"_help_disable": [
|
||||
"Disable formatting entirely, making cmake-format a no-op"
|
||||
],
|
||||
"disable": false,
|
||||
"_help_line_width": [
|
||||
"How wide to allow formatted cmake files"
|
||||
],
|
||||
"line_width": 1500,
|
||||
"_help_tab_size": [
|
||||
"How many spaces to tab for indent"
|
||||
],
|
||||
"tab_size": 2,
|
||||
"_help_use_tabchars": [
|
||||
"If true, lines are indented using tab characters (utf-8",
|
||||
"0x09) instead of <tab_size> space characters (utf-8 0x20).",
|
||||
"In cases where the layout would require a fractional tab",
|
||||
"character, the behavior of the fractional indentation is",
|
||||
"governed by <fractional_tab_policy>"
|
||||
],
|
||||
"use_tabchars": false,
|
||||
"_help_fractional_tab_policy": [
|
||||
"If <use_tabchars> is True, then the value of this variable",
|
||||
"indicates how fractional indentions are handled during",
|
||||
"whitespace replacement. If set to 'use-space', fractional",
|
||||
"indentation is left as spaces (utf-8 0x20). If set to",
|
||||
"`round-up` fractional indentation is replaced with a single",
|
||||
"tab character (utf-8 0x09) effectively shifting the column",
|
||||
"to the next tabstop"
|
||||
],
|
||||
"fractional_tab_policy": "use-space",
|
||||
"_help_max_subgroups_hwrap": [
|
||||
"If an argument group contains more than this many sub-groups",
|
||||
"(parg or kwarg groups) then force it to a vertical layout."
|
||||
],
|
||||
"max_subgroups_hwrap": 2,
|
||||
"_help_max_pargs_hwrap": [
|
||||
"If a positional argument group contains more than this many",
|
||||
"arguments, then force it to a vertical layout."
|
||||
],
|
||||
"max_pargs_hwrap": 6,
|
||||
"_help_max_rows_cmdline": [
|
||||
"If a cmdline positional group consumes more than this many",
|
||||
"lines without nesting, then invalidate the layout (and nest)"
|
||||
],
|
||||
"max_rows_cmdline": 2,
|
||||
"_help_separate_ctrl_name_with_space": [
|
||||
"If true, separate flow control names from their parentheses",
|
||||
"with a space"
|
||||
],
|
||||
"separate_ctrl_name_with_space": false,
|
||||
"_help_separate_fn_name_with_space": [
|
||||
"If true, separate function names from parentheses with a",
|
||||
"space"
|
||||
],
|
||||
"separate_fn_name_with_space": false,
|
||||
"_help_dangle_parens": [
|
||||
"If a statement is wrapped to more than one line, than dangle",
|
||||
"the closing parenthesis on its own line."
|
||||
],
|
||||
"dangle_parens": false,
|
||||
"_help_dangle_align": [
|
||||
"If the trailing parenthesis must be 'dangled' on its on",
|
||||
"line, then align it to this reference: `prefix`: the start",
|
||||
"of the statement, `prefix-indent`: the start of the",
|
||||
"statement, plus one indentation level, `child`: align to",
|
||||
"the column of the arguments"
|
||||
],
|
||||
"dangle_align": "prefix",
|
||||
"_help_min_prefix_chars": [
|
||||
"If the statement spelling length (including space and",
|
||||
"parenthesis) is smaller than this amount, then force reject",
|
||||
"nested layouts."
|
||||
],
|
||||
"min_prefix_chars": 4,
|
||||
"_help_max_prefix_chars": [
|
||||
"If the statement spelling length (including space and",
|
||||
"parenthesis) is larger than the tab width by more than this",
|
||||
"amount, then force reject un-nested layouts."
|
||||
],
|
||||
"max_prefix_chars": 10,
|
||||
"_help_max_lines_hwrap": [
|
||||
"If a candidate layout is wrapped horizontally but it exceeds",
|
||||
"this many lines, then reject the layout."
|
||||
],
|
||||
"max_lines_hwrap": 2,
|
||||
"_help_line_ending": [
|
||||
"What style line endings to use in the output."
|
||||
],
|
||||
"line_ending": "unix",
|
||||
"_help_command_case": [
|
||||
"Format command names consistently as 'lower' or 'upper' case"
|
||||
],
|
||||
"command_case": "canonical",
|
||||
"_help_keyword_case": [
|
||||
"Format keywords consistently as 'lower' or 'upper' case"
|
||||
],
|
||||
"keyword_case": "unchanged",
|
||||
"_help_always_wrap": [
|
||||
"A list of command names which should always be wrapped"
|
||||
],
|
||||
"always_wrap": [],
|
||||
"_help_enable_sort": [
|
||||
"If true, the argument lists which are known to be sortable",
|
||||
"will be sorted lexicographicall"
|
||||
],
|
||||
"enable_sort": true,
|
||||
"_help_autosort": [
|
||||
"If true, the parsers may infer whether or not an argument",
|
||||
"list is sortable (without annotation)."
|
||||
],
|
||||
"autosort": false,
|
||||
"_help_require_valid_layout": [
|
||||
"By default, if cmake-format cannot successfully fit",
|
||||
"everything into the desired linewidth it will apply the",
|
||||
"last, most aggressive attempt that it made. If this flag is",
|
||||
"True, however, cmake-format will print error, exit with non-",
|
||||
"zero status code, and write-out nothing"
|
||||
],
|
||||
"require_valid_layout": false,
|
||||
"_help_layout_passes": [
|
||||
"A dictionary mapping layout nodes to a list of wrap",
|
||||
"decisions. See the documentation for more information."
|
||||
],
|
||||
"layout_passes": {}
|
||||
},
|
||||
"_help_markup": "Options affecting comment reflow and formatting.",
|
||||
"markup": {
|
||||
"_help_bullet_char": [
|
||||
"What character to use for bulleted lists"
|
||||
],
|
||||
"bullet_char": "*",
|
||||
"_help_enum_char": [
|
||||
"What character to use as punctuation after numerals in an",
|
||||
"enumerated list"
|
||||
],
|
||||
"enum_char": ".",
|
||||
"_help_first_comment_is_literal": [
|
||||
"If comment markup is enabled, don't reflow the first comment",
|
||||
"block in each listfile. Use this to preserve formatting of",
|
||||
"your copyright/license statements."
|
||||
],
|
||||
"first_comment_is_literal": false,
|
||||
"_help_literal_comment_pattern": [
|
||||
"If comment markup is enabled, don't reflow any comment block",
|
||||
"which matches this (regex) pattern. Default is `None`",
|
||||
"(disabled)."
|
||||
],
|
||||
"literal_comment_pattern": ".*",
|
||||
"_help_fence_pattern": [
|
||||
"Regular expression to match preformat fences in comments",
|
||||
"default= ``r'^\\s*([`~]{3}[`~]*)(.*)$'``"
|
||||
],
|
||||
"fence_pattern": "^\\s*([`~]{3}[`~]*)(.*)$",
|
||||
"_help_ruler_pattern": [
|
||||
"Regular expression to match rulers in comments default=",
|
||||
"``r'^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$'``"
|
||||
],
|
||||
"ruler_pattern": "^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$",
|
||||
"_help_explicit_trailing_pattern": [
|
||||
"If a comment line matches starts with this pattern then it",
|
||||
"is explicitly a trailing comment for the preceding argument.",
|
||||
"Default is '#<'"
|
||||
],
|
||||
"explicit_trailing_pattern": "#<",
|
||||
"_help_hashruler_min_length": [
|
||||
"If a comment line starts with at least this many consecutive",
|
||||
"hash characters, then don't lstrip() them off. This allows",
|
||||
"for lazy hash rulers where the first hash char is not",
|
||||
"separated by space"
|
||||
],
|
||||
"hashruler_min_length": 10,
|
||||
"_help_canonicalize_hashrulers": [
|
||||
"If true, then insert a space between the first hash char and",
|
||||
"remaining hash chars in a hash ruler, and normalize its",
|
||||
"length to fill the column"
|
||||
],
|
||||
"canonicalize_hashrulers": true,
|
||||
"_help_enable_markup": [
|
||||
"enable comment markup parsing and reflow"
|
||||
],
|
||||
"enable_markup": true
|
||||
},
|
||||
"_help_lint": "Options affecting the linter",
|
||||
"lint": {
|
||||
"_help_disabled_codes": [
|
||||
"a list of lint codes to disable"
|
||||
],
|
||||
"disabled_codes": [],
|
||||
"_help_function_pattern": [
|
||||
"regular expression pattern describing valid function names"
|
||||
],
|
||||
"function_pattern": "[0-9a-z_]+",
|
||||
"_help_macro_pattern": [
|
||||
"regular expression pattern describing valid macro names"
|
||||
],
|
||||
"macro_pattern": "[0-9A-Z_]+",
|
||||
"_help_global_var_pattern": [
|
||||
"regular expression pattern describing valid names for",
|
||||
"variables with global (cache) scope"
|
||||
],
|
||||
"global_var_pattern": "[A-Z][0-9A-Z_]+",
|
||||
"_help_internal_var_pattern": [
|
||||
"regular expression pattern describing valid names for",
|
||||
"variables with global scope (but internal semantic)"
|
||||
],
|
||||
"internal_var_pattern": "_[A-Z][0-9A-Z_]+",
|
||||
"_help_local_var_pattern": [
|
||||
"regular expression pattern describing valid names for",
|
||||
"variables with local scope"
|
||||
],
|
||||
"local_var_pattern": "[a-z][a-z0-9_]+",
|
||||
"_help_private_var_pattern": [
|
||||
"regular expression pattern describing valid names for",
|
||||
"privatedirectory variables"
|
||||
],
|
||||
"private_var_pattern": "_[0-9a-z_]+",
|
||||
"_help_public_var_pattern": [
|
||||
"regular expression pattern describing valid names for public",
|
||||
"directory variables"
|
||||
],
|
||||
"public_var_pattern": "[A-Z][0-9A-Z_]+",
|
||||
"_help_argument_var_pattern": [
|
||||
"regular expression pattern describing valid names for",
|
||||
"function/macro arguments and loop variables."
|
||||
],
|
||||
"argument_var_pattern": "[a-z][a-z0-9_]+",
|
||||
"_help_keyword_pattern": [
|
||||
"regular expression pattern describing valid names for",
|
||||
"keywords used in functions or macros"
|
||||
],
|
||||
"keyword_pattern": "[A-Z][0-9A-Z_]+",
|
||||
"_help_max_conditionals_custom_parser": [
|
||||
"In the heuristic for C0201, how many conditionals to match",
|
||||
"within a loop in before considering the loop a parser."
|
||||
],
|
||||
"max_conditionals_custom_parser": 2,
|
||||
"_help_min_statement_spacing": [
|
||||
"Require at least this many newlines between statements"
|
||||
],
|
||||
"min_statement_spacing": 1,
|
||||
"_help_max_statement_spacing": [
|
||||
"Require no more than this many newlines between statements"
|
||||
],
|
||||
"max_statement_spacing": 2,
|
||||
"max_returns": 6,
|
||||
"max_branches": 12,
|
||||
"max_arguments": 5,
|
||||
"max_localvars": 15,
|
||||
"max_statements": 50
|
||||
},
|
||||
"_help_encode": "Options affecting file encoding",
|
||||
"encode": {
|
||||
"_help_emit_byteorder_mark": [
|
||||
"If true, emit the unicode byte-order mark (BOM) at the start",
|
||||
"of the file"
|
||||
],
|
||||
"emit_byteorder_mark": false,
|
||||
"_help_input_encoding": [
|
||||
"Specify the encoding of the input file. Defaults to utf-8"
|
||||
],
|
||||
"input_encoding": "utf-8",
|
||||
"_help_output_encoding": [
|
||||
"Specify the encoding of the output file. Defaults to utf-8.",
|
||||
"Note that cmake only claims to support utf-8 so be careful",
|
||||
"when using anything else"
|
||||
],
|
||||
"output_encoding": "utf-8"
|
||||
},
|
||||
"_help_misc": "Miscellaneous configurations options.",
|
||||
"misc": {
|
||||
"_help_per_command": [
|
||||
"A dictionary containing any per-command configuration",
|
||||
"overrides. Currently only `command_case` is supported."
|
||||
],
|
||||
"per_command": {}
|
||||
}
|
||||
}
|
|
@ -30,3 +30,8 @@
|
|||
*.exe
|
||||
*.out
|
||||
*.app
|
||||
|
||||
build/
|
||||
build_debug/
|
||||
|
||||
.vscode/
|
|
@ -0,0 +1,9 @@
|
|||
[submodule "3rd-party/pybind11"]
|
||||
path = 3rd-party/pybind11
|
||||
url = git@github.com:pybind/pybind11.git
|
||||
[submodule "3rd-party/nlohmann_json_cmake_fetchcontent"]
|
||||
path = 3rd-party/nlohmann_json_cmake_fetchcontent
|
||||
url = git@github.com:ArthurSonzogni/nlohmann_json_cmake_fetchcontent.git
|
||||
[submodule "3rd-party/googletest"]
|
||||
path = 3rd-party/googletest
|
||||
url = git@github.com:google/googletest.git
|
|
@ -0,0 +1 @@
|
|||
Subproject commit e2239ee6043f73722e7aa812a459f54a28552929
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 6aebf09233951e4ce30a63919186a70b2b195756
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 1e3400b6742288429f2069aaf5febf92d0662dae
|
|
@ -0,0 +1,73 @@
|
|||
# TODO: check the minimum cmake version
|
||||
cmake_minimum_required(VERSION 3.9) # Required by find_package(OpenMP)
|
||||
include(CMakeDependentOption)
|
||||
project(InfiniTensor C CXX)
|
||||
|
||||
# Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them.
|
||||
option(BUILD_TEST "Build tests" OFF)
|
||||
cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
|
||||
cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" ON BUILD_TEST OFF)
|
||||
cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" ON BUILD_TEST OFF)
|
||||
|
||||
set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wno-error=deprecated-declarations")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
|
||||
|
||||
find_package(
|
||||
Python
|
||||
COMPONENTS Interpreter Development
|
||||
REQUIRED)
|
||||
find_package(CUDA REQUIRED)
|
||||
# OpenMP
|
||||
find_package(OpenMP)
|
||||
if(OpenMP_C_FOUND)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
||||
endif()
|
||||
if(OpenMP_CXX_FOUND)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||
endif()
|
||||
|
||||
# add_subdirectory(3rd-party/pybind11)
|
||||
# add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent)
|
||||
|
||||
include_directories(include)
|
||||
# include_directories(3rd-party/pybind11/include)
|
||||
# include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include)
|
||||
|
||||
if(BUILD_TEST)
|
||||
set(BUILD_GMOCK off)
|
||||
set(INSTALL_GTEST off)
|
||||
add_subdirectory(3rd-party/googletest)
|
||||
include_directories(3rd-party/googletest/googletest/include)
|
||||
endif()
|
||||
|
||||
file(GLOB_RECURSE SRC src/*.cc src/*.cu)
|
||||
# file(GLOB_RECURSE TEST test/*.cc)
|
||||
# file(GLOB_RECURSE FFI src/ffi/ffi_pet.cc)
|
||||
# list(REMOVE_ITEM SRC ${TEST} ${FFI})
|
||||
|
||||
add_library(InfiniTensor SHARED ${SRC})
|
||||
# Target
|
||||
# cuda_add_library(it SHARED ${SRC})
|
||||
# cuda_add_cublas_to_target(it) # cublas
|
||||
# # target_link_libraries(infini_cpp cudnn curand nlohmann_json::nlohmann_json pybind11::embed)
|
||||
|
||||
# # Python bindings
|
||||
# pybind11_add_module(infini MODULE ${FFI})
|
||||
# target_link_libraries(infini PRIVATE infini_cpp)
|
||||
|
||||
if(BUILD_TEST)
|
||||
enable_testing()
|
||||
# Build all tests file( GLOB TEST_SOURCES test/test_sg2bmm.cc )
|
||||
file(GLOB_RECURSE TEST_SOURCES test/*.cc)
|
||||
foreach(testsourcefile ${TEST_SOURCES})
|
||||
get_filename_component(testname ${testsourcefile} NAME_WE)
|
||||
add_executable(${testname} ${testsourcefile})
|
||||
target_link_libraries(${testname} InfiniTensor gtest_main gtest)
|
||||
add_test(NAME ${testname} COMMAND ${testname})
|
||||
endforeach(testsourcefile ${TEST_SOURCES})
|
||||
endif()
|
|
@ -0,0 +1,32 @@
|
|||
#pragma once
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace it {
|
||||
using std::list;
|
||||
using std::map;
|
||||
using std::nullopt;
|
||||
using std::optional;
|
||||
using std::pair;
|
||||
using std::set;
|
||||
using std::string;
|
||||
using std::tie;
|
||||
using std::to_string;
|
||||
using std::tuple;
|
||||
using std::unordered_map;
|
||||
using std::vector;
|
||||
|
||||
// Aliases
|
||||
using dtype = float;
|
||||
|
||||
} // namespace it
|
|
@ -0,0 +1,36 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
#include "core/tensor.h"
|
||||
|
||||
namespace it {
|
||||
|
||||
class GraphNode : public Object {
|
||||
protected:
|
||||
TensorVec tensors;
|
||||
TensorVec inputs;
|
||||
TensorVec outputs;
|
||||
OpVec ops;
|
||||
|
||||
public:
|
||||
// Graph(OpVec oplist);
|
||||
string toString() const override;
|
||||
|
||||
void addOp(Operator op) { ops.push_back(op); };
|
||||
const TensorVec &getTensors() const { return tensors; }
|
||||
const TensorVec &getInputs() const { return inputs; }
|
||||
const TensorVec &getOutputs() const { return outputs; }
|
||||
const OpVec &getOperators() const { return ops; }
|
||||
// TensorVec &getInputs();
|
||||
// TensorVec &getOutputs();
|
||||
|
||||
void updateConnection();
|
||||
|
||||
// TODO
|
||||
// bool compute();
|
||||
|
||||
// TODO: move to another class
|
||||
// bool exportOnnx(const char *path);
|
||||
// bool importOnnx(const char *net);
|
||||
};
|
||||
|
||||
} // namespace it
|
|
@ -0,0 +1,45 @@
|
|||
#pragma once
|
||||
#include "core/common.h"
|
||||
|
||||
namespace it {
|
||||
|
||||
using GuidBaseType = int;
|
||||
|
||||
class Guid {
|
||||
private:
|
||||
GuidBaseType guid;
|
||||
|
||||
private:
|
||||
GuidBaseType generateGuid() {
|
||||
static GuidBaseType guidCnt = 0;
|
||||
return ++guidCnt;
|
||||
}
|
||||
|
||||
public:
|
||||
Guid() { guid = generateGuid(); }
|
||||
Guid(const Guid &rhs) { guid = generateGuid(); }
|
||||
Guid &operator=(const Guid &rhs) {
|
||||
guid = generateGuid();
|
||||
return *this;
|
||||
}
|
||||
|
||||
operator GuidBaseType() const { return guid; }
|
||||
};
|
||||
|
||||
class Object {
|
||||
protected:
|
||||
Guid guid;
|
||||
|
||||
public:
|
||||
virtual ~Object(){};
|
||||
virtual string toString() const = 0;
|
||||
void print() { std::cout << toString() << std::endl; }
|
||||
Guid getGuid() const { return guid; }
|
||||
};
|
||||
|
||||
inline std::ostream &operator<<(std::ostream &os, const Object &obj) {
|
||||
os << obj.toString();
|
||||
return os;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
#pragma once
|
||||
#include "core/tensor.h"
|
||||
|
||||
namespace it {
|
||||
|
||||
class OperatorNode : public Object {
|
||||
protected:
|
||||
// OpType type;
|
||||
TensorVec inputs;
|
||||
TensorVec outputs;
|
||||
// vector<WRef<Operator>> predecessors;
|
||||
// vector<WRef<Operator>> successors;
|
||||
public:
|
||||
OperatorNode(TensorVec inputs, TensorVec outputs)
|
||||
: inputs(inputs), outputs(outputs) {}
|
||||
string toString() const override;
|
||||
// Operator(TensorVec inputs) : inputs(inputs) {}
|
||||
|
||||
virtual ~OperatorNode() {}
|
||||
};
|
||||
} // namespace it
|
|
@ -0,0 +1,27 @@
|
|||
#pragma once
|
||||
#include "common.h"
|
||||
#include <functional> // hash
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
namespace it {
|
||||
|
||||
template <typename T> using Ref = std::shared_ptr<T>;
|
||||
template <typename T> using WRef = std::weak_ptr<T>;
|
||||
|
||||
template <typename T> struct is_ref : std::false_type {};
|
||||
template <typename T> struct is_ref<Ref<T>> : std::true_type {};
|
||||
template <typename T> struct is_ref<WRef<T>> : std::true_type {};
|
||||
|
||||
template <typename T, typename... Params> Ref<T> make_ref(Params &&...params) {
|
||||
static_assert(is_ref<T>::value == false, "Ref should not be nested");
|
||||
return std::make_shared<T>(std::forward<Params>(params)...);
|
||||
}
|
||||
|
||||
template <class T, class U,
|
||||
typename std::enable_if_t<std::is_base_of_v<U, T>> * = nullptr>
|
||||
Ref<T> as(const Ref<U> &ref) {
|
||||
return std::dynamic_pointer_cast<T>(ref);
|
||||
}
|
||||
|
||||
} // namespace it
|
|
@ -0,0 +1,368 @@
|
|||
#pragma once
|
||||
#include "core/object.h"
|
||||
#include "core/ref.h"
|
||||
|
||||
namespace it {
|
||||
|
||||
// class Tensor;
|
||||
class TensorBaseNode;
|
||||
class OperatorNode;
|
||||
class GraphNode;
|
||||
|
||||
using TensorBase = Ref<TensorBaseNode>;
|
||||
using Operator = Ref<OperatorNode>;
|
||||
using Graph = Ref<GraphNode>;
|
||||
|
||||
using TensorVec = vector<TensorBase>;
|
||||
using OpVec = vector<Operator>;
|
||||
|
||||
// using TensorMap = std::map<size_t, Tensor *>;
|
||||
// using OpMap = std::map<size_t, Operator *>;
|
||||
using VType = uint32_t;
|
||||
// using SplittingPoints = std::vector<std::vector<int>>;
|
||||
|
||||
class TensorBaseNode : public Object {
|
||||
public:
|
||||
enum DataType {
|
||||
Float32,
|
||||
Int32,
|
||||
};
|
||||
|
||||
// enum TensorType {
|
||||
// Input,
|
||||
// Weight,
|
||||
// Invalid,
|
||||
// NotCounted,
|
||||
// };
|
||||
|
||||
// // TODO: is more compute state needed?
|
||||
// enum ComputeState {
|
||||
// NotComputed,
|
||||
// // Allocated,
|
||||
// // Initialized,
|
||||
// // ComputedPartial,
|
||||
// ComputedFull,
|
||||
// };
|
||||
|
||||
private:
|
||||
int hid;
|
||||
// uint64_t hash;
|
||||
// Shape shape;
|
||||
int dim;
|
||||
|
||||
vector<WRef<TensorBase>> inputOf;
|
||||
WRef<TensorBase> outputOf;
|
||||
Ref<VType> data;
|
||||
DataType dtype;
|
||||
// ComputeState computed;
|
||||
// static int random_seed[256 * 16];
|
||||
// static bool random_inited;
|
||||
|
||||
public:
|
||||
// Tensor(TensorType type = Input, DataType dtype = Float32)
|
||||
// : guid(generateGuid()), hash(generateHash()), outputOf(nullptr),
|
||||
// data(nullptr), dtype(dtype), type(type), computed(NotComputed) {}
|
||||
// Tensor(const Dim &dims, TensorType type = Input, DataType dtype =
|
||||
// Float32)
|
||||
// : guid(generateGuid()), hash(generateHash()), dims(dims),
|
||||
// outputOf(nullptr), data(nullptr), dtype(dtype), type(type),
|
||||
// computed(NotComputed) {
|
||||
// itInit();
|
||||
// }
|
||||
// Tensor(const Tensor &rhs) : Tensor(rhs.dims, rhs.type, rhs.dtype) {
|
||||
// outputOf = nullptr;
|
||||
// data = nullptr;
|
||||
// hash = rhs.hash;
|
||||
// dimPenalty = rhs.dimPenalty;
|
||||
// itInit();
|
||||
// }
|
||||
// Tensor(VType scalar, TensorType type = Weight, DataType dtype = Float32)
|
||||
// : guid(generateGuid()), hash(generateHash()), outputOf(nullptr),
|
||||
// data(nullptr), dtype(dtype), type(type), computed(ComputedFull) {
|
||||
// assert(size() == 1);
|
||||
// dataMalloc();
|
||||
// data[0] = scalar;
|
||||
// }
|
||||
virtual ~TensorBaseNode() {}
|
||||
string toString() const override;
|
||||
|
||||
// // inputOf and outputOf will not be cloned
|
||||
// Tensor *clone() {
|
||||
// Tensor *t = new Tensor(*this);
|
||||
// return t;
|
||||
// }
|
||||
|
||||
// void clone(Tensor *t) {
|
||||
// dims = t->dims;
|
||||
// dtype = t->dtype;
|
||||
// type = t->type;
|
||||
// hash = t->hash;
|
||||
// dimPenalty = t->dimPenalty;
|
||||
// }
|
||||
|
||||
DataType getDType() const { return dtype; }
|
||||
|
||||
// uint64_t getHash() const { return hash; }
|
||||
|
||||
// void setInputOf(const OpVec &ops) {
|
||||
// inputOf.clear();
|
||||
// for (const auto &op : ops)
|
||||
// inputOf.emplace_back(op);
|
||||
// }
|
||||
// void addInputOf(Operator op) { inputOf.emplace_back(op); }
|
||||
// void setOutputOf(Operator op) { outputOf = op; }
|
||||
|
||||
// const OpVec &getInputOf() { return inputOf; }
|
||||
// Operator *getOutputOf() { return outputOf; }
|
||||
// std::pair<Operator *, int> getOutputOfWithIndex();
|
||||
|
||||
// bool dataMalloc() {
|
||||
// if (data == nullptr)
|
||||
// data = new VType[size()];
|
||||
// return data != nullptr;
|
||||
// }
|
||||
|
||||
// const Dim &getDims() const { return dims; }
|
||||
// void setDims(const Dim &dms) { dims = dms; }
|
||||
|
||||
// bool dataRand(int seed = 0) {
|
||||
// if (data == nullptr)
|
||||
// data = new VType[size()];
|
||||
// if (!random_inited)
|
||||
// initFastrand();
|
||||
// // srand(seed);
|
||||
// // faster rand generator; parallel
|
||||
// size_t iEnd = size();
|
||||
// // std::cerr << "Init beginned " << std::endl;
|
||||
// #pragma omp parallel for
|
||||
// for (size_t i = 0; i < iEnd; ++i)
|
||||
// data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) %
|
||||
// 10000;
|
||||
// // std::cerr << "Init finished" << std::endl;
|
||||
// computed = ComputedFull;
|
||||
// return true;
|
||||
// }
|
||||
|
||||
// bool setData(VType *dptr) {
|
||||
// if (dptr == nullptr)
|
||||
// return false;
|
||||
// auto sz = size();
|
||||
// #pragma omp parallel for
|
||||
// for (size_t i = 0; i < sz; ++i)
|
||||
// data[i] = dptr[i];
|
||||
// computed = ComputedFull;
|
||||
// return true;
|
||||
// }
|
||||
|
||||
// bool setScalar(VType val) {
|
||||
// if (data == nullptr || !dims.empty())
|
||||
// return false;
|
||||
// data[0] = val;
|
||||
// return true;
|
||||
// }
|
||||
|
||||
// bool setData(const Dim &ds, VType val) {
|
||||
// if (data == nullptr || ds.size() != dims.size())
|
||||
// return false;
|
||||
// data[getOffset(ds)] = val;
|
||||
// return true;
|
||||
// }
|
||||
|
||||
// bool setData(size_t pos, VType val) {
|
||||
// if (data == nullptr || pos >= size())
|
||||
// return false;
|
||||
// data[pos] = val;
|
||||
// return true;
|
||||
// }
|
||||
|
||||
// VType getScalar() { return data == nullptr ? 0 : data[0]; }
|
||||
|
||||
// VType getData(const Dim &ds) {
|
||||
// assert(data != nullptr);
|
||||
// auto offset = getOffset(ds);
|
||||
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
|
||||
// }
|
||||
|
||||
// VType getData(size_t pos) {
|
||||
// assert(data != nullptr);
|
||||
// assert(pos < size());
|
||||
// return data[pos];
|
||||
// }
|
||||
|
||||
// VType *getDataPtr() const { return data; }
|
||||
|
||||
// size_t getOffset(const Dim &ds) {
|
||||
// auto nDim = ds.size();
|
||||
// assert(dims.size() == nDim);
|
||||
// if (ds.empty())
|
||||
// return 0;
|
||||
// for (size_t i = 0; i < nDim; ++i)
|
||||
// if (ds[i] < 0 || ds[i] >= dims[i])
|
||||
// return (size_t)-1;
|
||||
// size_t idx = ds[0];
|
||||
// size_t dm = 0;
|
||||
// while (++dm < nDim)
|
||||
// idx = idx * dims[dm] + ds[dm];
|
||||
// return idx;
|
||||
// }
|
||||
|
||||
// VType getBroadcastData(const Dim &ds) {
|
||||
// assert(data != nullptr);
|
||||
// auto offset = getBroadcastOffset(ds);
|
||||
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
|
||||
// }
|
||||
|
||||
// VType getBroadcastData(size_t pos) {
|
||||
// assert(data != nullptr);
|
||||
// return data[pos % size()];
|
||||
// }
|
||||
|
||||
// size_t getBroadcastOffset(const Dim &ds) {
|
||||
// assert(ds.size() >= dims.size());
|
||||
// auto nDim = dims.size();
|
||||
// auto nBroadcastDim = ds.size() - nDim;
|
||||
// for (size_t i = 0; i < nDim; ++i)
|
||||
// if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >=
|
||||
// dims[i])
|
||||
// return (size_t)-1;
|
||||
// size_t idx = 0;
|
||||
// for (size_t i = 0; i < nDim; ++i)
|
||||
// idx = idx * dims[i] + ds[nBroadcastDim + i];
|
||||
// return idx;
|
||||
// }
|
||||
|
||||
// void itInit() { it = Dim(dims.size(), 0); }
|
||||
|
||||
// void itReset() {
|
||||
// itInit();
|
||||
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
|
||||
// it[i] = 0;
|
||||
// }
|
||||
|
||||
// bool itValid() {
|
||||
// if (it.size() != dims.size())
|
||||
// return false;
|
||||
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
|
||||
// if (it[i] >= dims[i])
|
||||
// return false;
|
||||
// return true;
|
||||
// }
|
||||
|
||||
// const Dim &itGet() { return it; }
|
||||
|
||||
// void itNext() {
|
||||
// auto p = it.size() - 1;
|
||||
// it[p] += 1;
|
||||
// while (p >= 1) {
|
||||
// if (it[p] == dims[p]) {
|
||||
// it[p] = 0;
|
||||
// it[--p] += 1;
|
||||
// } else
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
|
||||
// size_t size() const {
|
||||
// size_t sz = 1;
|
||||
// auto dm = dims.size();
|
||||
// while (dm > 0)
|
||||
// sz *= dims[--dm];
|
||||
// return sz;
|
||||
// }
|
||||
|
||||
// TensorType getType() const { return type; }
|
||||
// void setType(TensorType ty) { type = ty; }
|
||||
|
||||
// void print() {
|
||||
// if (type == Invalid) {
|
||||
// std::cout << "Invalid tensor" << std::endl;
|
||||
// return;
|
||||
// }
|
||||
|
||||
// if (data == nullptr || dims.size() == 0) {
|
||||
// std::cout << "Empty tensor" << std::endl;
|
||||
// return;
|
||||
// }
|
||||
|
||||
// // TODO: can be uncommented after tensor's compute type is
|
||||
// correctly set if (computed == NotComputed) {
|
||||
// std::cout << "Uncomputed tensor" << std::endl;
|
||||
// return;
|
||||
// }
|
||||
|
||||
// std::cout << "Tensor: " << guid << std::endl;
|
||||
// auto numDims = dims.size();
|
||||
// auto dimSzVec = std::vector<int>(numDims, 1);
|
||||
// dimSzVec[numDims - 1] = dims[numDims - 1];
|
||||
// for (int i = numDims - 1; i != 0; --i)
|
||||
// dimSzVec[i - 1] = dimSzVec[i] * dims[i - 1];
|
||||
// for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
|
||||
// for (size_t j = 0; j < numDims; ++j) {
|
||||
// if (i % dimSzVec[j] == 0) {
|
||||
// std::cout << "[";
|
||||
// }
|
||||
// }
|
||||
// std::cout << data[i];
|
||||
// for (size_t j = 0; j < numDims; ++j) {
|
||||
// if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) {
|
||||
// std::cout << "]";
|
||||
// }
|
||||
// }
|
||||
// if (i != size() - 1)
|
||||
// std::cout << ", ";
|
||||
// if ((int)i % dimSzVec[numDims - 1] == dimSzVec[numDims - 1] -
|
||||
// 1)
|
||||
// std::cout << std::endl;
|
||||
// }
|
||||
// }
|
||||
|
||||
// static inline void initFastrand() {
|
||||
// assert(omp_get_max_threads() <= 256);
|
||||
// // srand(0); // constant seed for test
|
||||
// // align random_seed to avoid false sharing
|
||||
// for (int i = 0; i < 256 * 16; ++i) {
|
||||
// // random_seed[i] = rand();
|
||||
// // constant random seed for test
|
||||
// random_seed[i] = i;
|
||||
// }
|
||||
// random_inited = true;
|
||||
// }
|
||||
|
||||
// static inline int fastrand(int &g_seed) {
|
||||
// g_seed = (214013 * g_seed + 2531011);
|
||||
// return (g_seed >> 16) & 0x7FFF;
|
||||
// }
|
||||
|
||||
// std::vector<std::vector<int>> const *getSplittingPoints() const {
|
||||
// assert(!splittingPoints.empty());
|
||||
// return &splittingPoints;
|
||||
// }
|
||||
|
||||
// bool setSplittingPoints(std::vector<std::vector<int>> value) {
|
||||
// assert(!value.empty());
|
||||
// splittingPoints = value;
|
||||
// return true;
|
||||
// }
|
||||
|
||||
// void printSplittingPoints() {
|
||||
// if (splittingPoints.empty())
|
||||
// printf("Empty SplittingPoints");
|
||||
// else {
|
||||
// printf("[");
|
||||
// for (auto &vs : splittingPoints) {
|
||||
// printf("[");
|
||||
// for (auto v : vs)
|
||||
// printf("%2d,", v);
|
||||
// printf("],");
|
||||
// }
|
||||
// printf("]");
|
||||
// }
|
||||
// }
|
||||
|
||||
// void initSplittingPoints() {
|
||||
// splittingPoints.resize(getDims().size()); }
|
||||
|
||||
// void printShape();
|
||||
};
|
||||
|
||||
} // namespace it
|
|
@ -0,0 +1,3 @@
|
|||
#pragma once
|
||||
#include "core/common.h"
|
||||
#include "gtest/gtest.h"
|
|
@ -0,0 +1,15 @@
|
|||
#include "core/graph.h"
|
||||
|
||||
namespace it {
|
||||
|
||||
void GraphNode::updateConnection() {
|
||||
// TODO
|
||||
}
|
||||
|
||||
string GraphNode::toString() const {
|
||||
std::ostringstream oss;
|
||||
oss << "GraphNode: ";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace it
|
|
@ -0,0 +1,11 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace it {
|
||||
|
||||
string OperatorNode::toString() const {
|
||||
std::ostringstream oss;
|
||||
oss << "Operator: ";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace it
|
|
@ -0,0 +1,8 @@
|
|||
#include <core/tensor.h>
|
||||
namespace it {
|
||||
|
||||
string TensorBaseNode::toString() const {
|
||||
return "TensorBaseNode " + std::to_string(guid);
|
||||
}
|
||||
|
||||
}; // namespace it
|
|
@ -0,0 +1,12 @@
|
|||
#include "core/graph.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace it {
|
||||
|
||||
TEST(Graph, build) {
|
||||
Graph g = make_ref<GraphNode>();
|
||||
g->addOp(make_ref<OperatorNode>(TensorVec{}, TensorVec{}));
|
||||
g->print();
|
||||
}
|
||||
|
||||
} // namespace it
|
Loading…
Reference in New Issue