Add: graph, tensor, and operator

This commit is contained in:
Liyan Zheng 2022-07-31 21:43:26 +08:00
parent b89495a782
commit e6101b0336
19 changed files with 1069 additions and 0 deletions

90
.clang-format Normal file
View File

@ -0,0 +1,90 @@
---
Language: Cpp
# BasedOnStyle: LLVM
AccessModifierOffset: -2
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: false
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: false
BinPackArguments: true
BinPackParameters: true
BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: false
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
IncludeCategories:
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
Priority: 2
- Regex: '^(<|"(gtest|isl|json)/)'
Priority: 3
- Regex: '.*'
Priority: 1
IndentCaseLabels: false
IndentWidth: 4
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: true
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakBeforeFirstCallParameter: 19
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 60
PointerAlignment: Right
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...

311
.cmake-format.json Normal file
View File

@ -0,0 +1,311 @@
{
"_help_parse": "Options affecting listfile parsing",
"parse": {
"_help_additional_commands": [
"Specify structure for custom cmake functions"
],
"additional_commands": {
"foo": {
"flags": [
"BAR",
"BAZ"
],
"kwargs": {
"HEADERS": "*",
"SOURCES": "*",
"DEPENDS": "*"
}
}
},
"_help_override_spec": [
"Override configurations per-command where available"
],
"override_spec": {},
"_help_vartags": [
"Specify variable tags."
],
"vartags": [],
"_help_proptags": [
"Specify property tags."
],
"proptags": []
},
"_help_format": "Options affecting formatting.",
"format": {
"_help_disable": [
"Disable formatting entirely, making cmake-format a no-op"
],
"disable": false,
"_help_line_width": [
"How wide to allow formatted cmake files"
],
"line_width": 1500,
"_help_tab_size": [
"How many spaces to tab for indent"
],
"tab_size": 2,
"_help_use_tabchars": [
"If true, lines are indented using tab characters (utf-8",
"0x09) instead of <tab_size> space characters (utf-8 0x20).",
"In cases where the layout would require a fractional tab",
"character, the behavior of the fractional indentation is",
"governed by <fractional_tab_policy>"
],
"use_tabchars": false,
"_help_fractional_tab_policy": [
"If <use_tabchars> is True, then the value of this variable",
"indicates how fractional indentions are handled during",
"whitespace replacement. If set to 'use-space', fractional",
"indentation is left as spaces (utf-8 0x20). If set to",
"`round-up` fractional indentation is replaced with a single",
"tab character (utf-8 0x09) effectively shifting the column",
"to the next tabstop"
],
"fractional_tab_policy": "use-space",
"_help_max_subgroups_hwrap": [
"If an argument group contains more than this many sub-groups",
"(parg or kwarg groups) then force it to a vertical layout."
],
"max_subgroups_hwrap": 2,
"_help_max_pargs_hwrap": [
"If a positional argument group contains more than this many",
"arguments, then force it to a vertical layout."
],
"max_pargs_hwrap": 6,
"_help_max_rows_cmdline": [
"If a cmdline positional group consumes more than this many",
"lines without nesting, then invalidate the layout (and nest)"
],
"max_rows_cmdline": 2,
"_help_separate_ctrl_name_with_space": [
"If true, separate flow control names from their parentheses",
"with a space"
],
"separate_ctrl_name_with_space": false,
"_help_separate_fn_name_with_space": [
"If true, separate function names from parentheses with a",
"space"
],
"separate_fn_name_with_space": false,
"_help_dangle_parens": [
"If a statement is wrapped to more than one line, than dangle",
"the closing parenthesis on its own line."
],
"dangle_parens": false,
"_help_dangle_align": [
"If the trailing parenthesis must be 'dangled' on its own",
"line, then align it to this reference: `prefix`: the start",
"of the statement, `prefix-indent`: the start of the",
"statement, plus one indentation level, `child`: align to",
"the column of the arguments"
],
"dangle_align": "prefix",
"_help_min_prefix_chars": [
"If the statement spelling length (including space and",
"parenthesis) is smaller than this amount, then force reject",
"nested layouts."
],
"min_prefix_chars": 4,
"_help_max_prefix_chars": [
"If the statement spelling length (including space and",
"parenthesis) is larger than the tab width by more than this",
"amount, then force reject un-nested layouts."
],
"max_prefix_chars": 10,
"_help_max_lines_hwrap": [
"If a candidate layout is wrapped horizontally but it exceeds",
"this many lines, then reject the layout."
],
"max_lines_hwrap": 2,
"_help_line_ending": [
"What style line endings to use in the output."
],
"line_ending": "unix",
"_help_command_case": [
"Format command names consistently as 'lower' or 'upper' case"
],
"command_case": "canonical",
"_help_keyword_case": [
"Format keywords consistently as 'lower' or 'upper' case"
],
"keyword_case": "unchanged",
"_help_always_wrap": [
"A list of command names which should always be wrapped"
],
"always_wrap": [],
"_help_enable_sort": [
"If true, the argument lists which are known to be sortable",
"will be sorted lexicographically"
],
"enable_sort": true,
"_help_autosort": [
"If true, the parsers may infer whether or not an argument",
"list is sortable (without annotation)."
],
"autosort": false,
"_help_require_valid_layout": [
"By default, if cmake-format cannot successfully fit",
"everything into the desired linewidth it will apply the",
"last, most aggressive attempt that it made. If this flag is",
"True, however, cmake-format will print error, exit with non-",
"zero status code, and write-out nothing"
],
"require_valid_layout": false,
"_help_layout_passes": [
"A dictionary mapping layout nodes to a list of wrap",
"decisions. See the documentation for more information."
],
"layout_passes": {}
},
"_help_markup": "Options affecting comment reflow and formatting.",
"markup": {
"_help_bullet_char": [
"What character to use for bulleted lists"
],
"bullet_char": "*",
"_help_enum_char": [
"What character to use as punctuation after numerals in an",
"enumerated list"
],
"enum_char": ".",
"_help_first_comment_is_literal": [
"If comment markup is enabled, don't reflow the first comment",
"block in each listfile. Use this to preserve formatting of",
"your copyright/license statements."
],
"first_comment_is_literal": false,
"_help_literal_comment_pattern": [
"If comment markup is enabled, don't reflow any comment block",
"which matches this (regex) pattern. Default is `None`",
"(disabled)."
],
"literal_comment_pattern": ".*",
"_help_fence_pattern": [
"Regular expression to match preformat fences in comments",
"default= ``r'^\\s*([`~]{3}[`~]*)(.*)$'``"
],
"fence_pattern": "^\\s*([`~]{3}[`~]*)(.*)$",
"_help_ruler_pattern": [
"Regular expression to match rulers in comments default=",
"``r'^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$'``"
],
"ruler_pattern": "^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$",
"_help_explicit_trailing_pattern": [
"If a comment line matches starts with this pattern then it",
"is explicitly a trailing comment for the preceding argument.",
"Default is '#<'"
],
"explicit_trailing_pattern": "#<",
"_help_hashruler_min_length": [
"If a comment line starts with at least this many consecutive",
"hash characters, then don't lstrip() them off. This allows",
"for lazy hash rulers where the first hash char is not",
"separated by space"
],
"hashruler_min_length": 10,
"_help_canonicalize_hashrulers": [
"If true, then insert a space between the first hash char and",
"remaining hash chars in a hash ruler, and normalize its",
"length to fill the column"
],
"canonicalize_hashrulers": true,
"_help_enable_markup": [
"enable comment markup parsing and reflow"
],
"enable_markup": true
},
"_help_lint": "Options affecting the linter",
"lint": {
"_help_disabled_codes": [
"a list of lint codes to disable"
],
"disabled_codes": [],
"_help_function_pattern": [
"regular expression pattern describing valid function names"
],
"function_pattern": "[0-9a-z_]+",
"_help_macro_pattern": [
"regular expression pattern describing valid macro names"
],
"macro_pattern": "[0-9A-Z_]+",
"_help_global_var_pattern": [
"regular expression pattern describing valid names for",
"variables with global (cache) scope"
],
"global_var_pattern": "[A-Z][0-9A-Z_]+",
"_help_internal_var_pattern": [
"regular expression pattern describing valid names for",
"variables with global scope (but internal semantic)"
],
"internal_var_pattern": "_[A-Z][0-9A-Z_]+",
"_help_local_var_pattern": [
"regular expression pattern describing valid names for",
"variables with local scope"
],
"local_var_pattern": "[a-z][a-z0-9_]+",
"_help_private_var_pattern": [
"regular expression pattern describing valid names for",
"private directory variables"
],
"private_var_pattern": "_[0-9a-z_]+",
"_help_public_var_pattern": [
"regular expression pattern describing valid names for public",
"directory variables"
],
"public_var_pattern": "[A-Z][0-9A-Z_]+",
"_help_argument_var_pattern": [
"regular expression pattern describing valid names for",
"function/macro arguments and loop variables."
],
"argument_var_pattern": "[a-z][a-z0-9_]+",
"_help_keyword_pattern": [
"regular expression pattern describing valid names for",
"keywords used in functions or macros"
],
"keyword_pattern": "[A-Z][0-9A-Z_]+",
"_help_max_conditionals_custom_parser": [
"In the heuristic for C0201, how many conditionals to match",
"within a loop in before considering the loop a parser."
],
"max_conditionals_custom_parser": 2,
"_help_min_statement_spacing": [
"Require at least this many newlines between statements"
],
"min_statement_spacing": 1,
"_help_max_statement_spacing": [
"Require no more than this many newlines between statements"
],
"max_statement_spacing": 2,
"max_returns": 6,
"max_branches": 12,
"max_arguments": 5,
"max_localvars": 15,
"max_statements": 50
},
"_help_encode": "Options affecting file encoding",
"encode": {
"_help_emit_byteorder_mark": [
"If true, emit the unicode byte-order mark (BOM) at the start",
"of the file"
],
"emit_byteorder_mark": false,
"_help_input_encoding": [
"Specify the encoding of the input file. Defaults to utf-8"
],
"input_encoding": "utf-8",
"_help_output_encoding": [
"Specify the encoding of the output file. Defaults to utf-8.",
"Note that cmake only claims to support utf-8 so be careful",
"when using anything else"
],
"output_encoding": "utf-8"
},
"_help_misc": "Miscellaneous configurations options.",
"misc": {
"_help_per_command": [
"A dictionary containing any per-command configuration",
"overrides. Currently only `command_case` is supported."
],
"per_command": {}
}
}

5
.gitignore vendored
View File

@ -30,3 +30,8 @@
*.exe *.exe
*.out *.out
*.app *.app
build/
build_debug/
.vscode/

9
.gitmodules vendored Normal file
View File

@ -0,0 +1,9 @@
[submodule "3rd-party/pybind11"]
path = 3rd-party/pybind11
url = https://github.com/pybind/pybind11.git
[submodule "3rd-party/nlohmann_json_cmake_fetchcontent"]
path = 3rd-party/nlohmann_json_cmake_fetchcontent
url = https://github.com/ArthurSonzogni/nlohmann_json_cmake_fetchcontent.git
[submodule "3rd-party/googletest"]
path = 3rd-party/googletest
url = https://github.com/google/googletest.git

1
3rd-party/googletest vendored Submodule

@ -0,0 +1 @@
Subproject commit e2239ee6043f73722e7aa812a459f54a28552929

@ -0,0 +1 @@
Subproject commit 6aebf09233951e4ce30a63919186a70b2b195756

1
3rd-party/pybind11 vendored Submodule

@ -0,0 +1 @@
Subproject commit 1e3400b6742288429f2069aaf5febf92d0662dae

73
CMakeLists.txt Normal file
View File

@ -0,0 +1,73 @@
# TODO: check the minimum cmake version
cmake_minimum_required(VERSION 3.9) # Required by find_package(OpenMP)
include(CMakeDependentOption)
project(InfiniTensor C CXX)
# Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them.
option(BUILD_TEST "Build tests" OFF)
# Sub-test switches only take effect when BUILD_TEST is ON.
cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" ON BUILD_TEST OFF)
cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" ON BUILD_TEST OFF)
# NOTE(review): DEFAULT_BUILD_TYPE is set but never applied to
# CMAKE_BUILD_TYPE — confirm whether a default-build-type guard is intended.
set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wno-error=deprecated-declarations")
# -UNDEBUG keeps assert() active even in optimized builds.
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
find_package(
  Python
  COMPONENTS Interpreter Development
  REQUIRED)
# NOTE(review): find_package(CUDA) is deprecated since CMake 3.10; consider
# enable_language(CUDA) / FindCUDAToolkit once the minimum version is raised.
find_package(CUDA REQUIRED)
# OpenMP
find_package(OpenMP)
if(OpenMP_C_FOUND)
  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
endif()
if(OpenMP_CXX_FOUND)
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
endif()
# add_subdirectory(3rd-party/pybind11)
# add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent)
include_directories(include)
# include_directories(3rd-party/pybind11/include)
# include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include)
if(BUILD_TEST)
  set(BUILD_GMOCK off)
  set(INSTALL_GTEST off)
  add_subdirectory(3rd-party/googletest)
  include_directories(3rd-party/googletest/googletest/include)
endif()
# NOTE(review): GLOB_RECURSE does not re-run automatically when files are
# added; a clean reconfigure is required after adding sources.
file(GLOB_RECURSE SRC src/*.cc src/*.cu)
# file(GLOB_RECURSE TEST test/*.cc)
# file(GLOB_RECURSE FFI src/ffi/ffi_pet.cc)
# list(REMOVE_ITEM SRC ${TEST} ${FFI})
add_library(InfiniTensor SHARED ${SRC})
# Target
# cuda_add_library(it SHARED ${SRC})
# cuda_add_cublas_to_target(it) # cublas
# # target_link_libraries(infini_cpp cudnn curand nlohmann_json::nlohmann_json pybind11::embed)
# # Python bindings
# pybind11_add_module(infini MODULE ${FFI})
# target_link_libraries(infini PRIVATE infini_cpp)
if(BUILD_TEST)
  enable_testing()
  # Build all tests file( GLOB TEST_SOURCES test/test_sg2bmm.cc )
  file(GLOB_RECURSE TEST_SOURCES test/*.cc)
  # One executable + one ctest entry per test source file.
  foreach(testsourcefile ${TEST_SOURCES})
    get_filename_component(testname ${testsourcefile} NAME_WE)
    add_executable(${testname} ${testsourcefile})
    target_link_libraries(${testname} InfiniTensor gtest_main gtest)
    add_test(NAME ${testname} COMMAND ${testname})
  endforeach(testsourcefile ${TEST_SOURCES})
endif()

32
include/core/common.h Normal file
View File

@ -0,0 +1,32 @@
#pragma once
#include <cassert>
#include <iostream>
#include <list>
#include <map>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace it {
// Pull frequently used standard-library names into the project namespace so
// the rest of the codebase can use them unqualified.
using std::list;
using std::map;
using std::nullopt;
using std::optional;
using std::pair;
using std::set;
using std::string;
using std::tie;
using std::to_string;
using std::tuple;
using std::unordered_map;
using std::vector;
// Aliases
// NOTE(review): element type here is float, while VType in core/tensor.h is
// uint32_t — confirm which one compute kernels are meant to use.
using dtype = float;
} // namespace it

36
include/core/graph.h Normal file
View File

@ -0,0 +1,36 @@
#pragma once
#include "core/operator.h"
#include "core/tensor.h"
namespace it {
class GraphNode : public Object {
protected:
TensorVec tensors;
TensorVec inputs;
TensorVec outputs;
OpVec ops;
public:
// Graph(OpVec oplist);
string toString() const override;
void addOp(Operator op) { ops.push_back(op); };
const TensorVec &getTensors() const { return tensors; }
const TensorVec &getInputs() const { return inputs; }
const TensorVec &getOutputs() const { return outputs; }
const OpVec &getOperators() const { return ops; }
// TensorVec &getInputs();
// TensorVec &getOutputs();
void updateConnection();
// TODO
// bool compute();
// TODO: move to another class
// bool exportOnnx(const char *path);
// bool importOnnx(const char *net);
};
} // namespace it

45
include/core/object.h Normal file
View File

@ -0,0 +1,45 @@
#pragma once
#include "core/common.h"
namespace it {
using GuidBaseType = int;

// Process-wide unique identifier. Identity must not be duplicated, so both
// the copy constructor and copy assignment deliberately generate a *fresh*
// guid instead of copying the source's value.
class Guid {
  private:
    GuidBaseType guid;

    // Next id from a process-wide, monotonically increasing counter.
    static GuidBaseType generateGuid() {
        static GuidBaseType guidCnt = 0;
        return ++guidCnt;
    }

  public:
    Guid() : guid(generateGuid()) {}
    Guid(const Guid &) : guid(generateGuid()) {}
    Guid &operator=(const Guid &) {
        guid = generateGuid();
        return *this;
    }

    // Implicit conversion to the underlying integer id.
    operator GuidBaseType() const { return guid; }
};
// Base class for all graph entities (tensors, operators, graphs). Provides a
// process-unique guid and a printable string representation.
class Object {
  protected:
    Guid guid; // unique per constructed object; regenerated on copy

  public:
    virtual ~Object() = default;
    // Human-readable description, implemented by each concrete subclass.
    virtual string toString() const = 0;
    // Prints toString() to stdout; const since it does not mutate the object.
    void print() const { std::cout << toString() << std::endl; }
    Guid getGuid() const { return guid; }
};
// Stream insertion for any Object via its toString() representation.
inline std::ostream &operator<<(std::ostream &os, const Object &obj) {
    return os << obj.toString();
}
}

21
include/core/operator.h Normal file
View File

@ -0,0 +1,21 @@
#pragma once
#include "core/tensor.h"
namespace it {
class OperatorNode : public Object {
protected:
// OpType type;
TensorVec inputs;
TensorVec outputs;
// vector<WRef<Operator>> predecessors;
// vector<WRef<Operator>> successors;
public:
OperatorNode(TensorVec inputs, TensorVec outputs)
: inputs(inputs), outputs(outputs) {}
string toString() const override;
// Operator(TensorVec inputs) : inputs(inputs) {}
virtual ~OperatorNode() {}
};
} // namespace it

27
include/core/ref.h Normal file
View File

@ -0,0 +1,27 @@
#pragma once
#include "common.h"
#include <functional> // hash
#include <memory>
#include <type_traits>
namespace it {
// Project-wide smart-pointer aliases: Ref owns, WRef observes.
template <typename T> using Ref = std::shared_ptr<T>;
template <typename T> using WRef = std::weak_ptr<T>;

// Trait: true exactly when T is itself a Ref/WRef wrapper.
template <typename T> struct is_ref : std::false_type {};
template <typename T> struct is_ref<Ref<T>> : std::true_type {};
template <typename T> struct is_ref<WRef<T>> : std::true_type {};

// Factory for Ref<T>; forwards constructor arguments. Rejects Ref<Ref<...>>
// at compile time.
template <typename T, typename... Args> Ref<T> make_ref(Args &&...args) {
    static_assert(!is_ref<T>::value, "Ref should not be nested");
    return std::make_shared<T>(std::forward<Args>(args)...);
}

// Checked downcast between Refs; T must derive from U. Returns null on
// failure (dynamic_pointer_cast semantics).
template <class T, class U,
          typename std::enable_if_t<std::is_base_of_v<U, T>> * = nullptr>
Ref<T> as(const Ref<U> &ref) {
    return std::dynamic_pointer_cast<T>(ref);
}
} // namespace it

368
include/core/tensor.h Normal file
View File

@ -0,0 +1,368 @@
#pragma once
#include "core/object.h"
#include "core/ref.h"
#include <cstdint> // uint32_t for VType
namespace it {
// class Tensor;
class TensorBaseNode;
class OperatorNode;
class GraphNode;
// Shared-ownership handles: graph entities are always managed through Refs.
using TensorBase = Ref<TensorBaseNode>;
using Operator = Ref<OperatorNode>;
using Graph = Ref<GraphNode>;
using TensorVec = vector<TensorBase>;
using OpVec = vector<Operator>;
// using TensorMap = std::map<size_t, Tensor *>;
// using OpMap = std::map<size_t, Operator *>;
// Raw element type of tensor data buffers (declared in <cstdint>).
using VType = uint32_t;
// using SplittingPoints = std::vector<std::vector<int>>;
// Base class for tensors. The large commented-out sections below are the
// legacy Tensor implementation kept as a porting reference; they are being
// migrated incrementally.
class TensorBaseNode : public Object {
  public:
    enum DataType {
        Float32,
        Int32,
    };
    // enum TensorType {
    //     Input,
    //     Weight,
    //     Invalid,
    //     NotCounted,
    // };
    // // TODO: is more compute state needed?
    // enum ComputeState {
    //     NotComputed,
    //     // Allocated,
    //     // Initialized,
    //     // ComputedPartial,
    //     ComputedFull,
    // };

  private:
    // Zero-initialized to avoid UB reads; exact semantics of hid are not
    // evident from this file — TODO confirm (hardware id?).
    int hid{};
    // uint64_t hash;
    // Shape shape;
    int dim{};
    // Operators that consume / produce this tensor. Weak references avoid
    // ownership cycles with OperatorNode. The original declarations used
    // WRef<TensorBase>, i.e. weak_ptr<shared_ptr<TensorBaseNode>> — a nested
    // smart pointer that could never observe a live object; the commented
    // accessors below (setInputOf/addInputOf/setOutputOf) show these fields
    // are meant to reference operators.
    vector<WRef<OperatorNode>> inputOf;
    WRef<OperatorNode> outputOf;
    // NOTE(review): a Ref to a single VType, while the legacy code below
    // indexes data as an array — presumably a placeholder until a real
    // buffer type lands. Confirm before use.
    Ref<VType> data;
    DataType dtype = Float32; // matches the legacy constructors' default
    // ComputeState computed;
    // static int random_seed[256 * 16];
    // static bool random_inited;

  public:
    // Tensor(TensorType type = Input, DataType dtype = Float32)
    //     : guid(generateGuid()), hash(generateHash()), outputOf(nullptr),
    //       data(nullptr), dtype(dtype), type(type), computed(NotComputed) {}
    // Tensor(const Dim &dims, TensorType type = Input, DataType dtype =
    // Float32)
    //     : guid(generateGuid()), hash(generateHash()), dims(dims),
    //       outputOf(nullptr), data(nullptr), dtype(dtype), type(type),
    //       computed(NotComputed) {
    //     itInit();
    // }
    // Tensor(const Tensor &rhs) : Tensor(rhs.dims, rhs.type, rhs.dtype) {
    //     outputOf = nullptr;
    //     data = nullptr;
    //     hash = rhs.hash;
    //     dimPenalty = rhs.dimPenalty;
    //     itInit();
    // }
    // Tensor(VType scalar, TensorType type = Weight, DataType dtype = Float32)
    //     : guid(generateGuid()), hash(generateHash()), outputOf(nullptr),
    //       data(nullptr), dtype(dtype), type(type), computed(ComputedFull) {
    //     assert(size() == 1);
    //     dataMalloc();
    //     data[0] = scalar;
    // }
    virtual ~TensorBaseNode() = default;
    string toString() const override;
    // // inputOf and outputOf will not be cloned
    // Tensor *clone() {
    //     Tensor *t = new Tensor(*this);
    //     return t;
    // }
    // void clone(Tensor *t) {
    //     dims = t->dims;
    //     dtype = t->dtype;
    //     type = t->type;
    //     hash = t->hash;
    //     dimPenalty = t->dimPenalty;
    // }
    DataType getDType() const { return dtype; }
    // uint64_t getHash() const { return hash; }
    // void setInputOf(const OpVec &ops) {
    //     inputOf.clear();
    //     for (const auto &op : ops)
    //         inputOf.emplace_back(op);
    // }
    // void addInputOf(Operator op) { inputOf.emplace_back(op); }
    // void setOutputOf(Operator op) { outputOf = op; }
    // const OpVec &getInputOf() { return inputOf; }
    // Operator *getOutputOf() { return outputOf; }
    // std::pair<Operator *, int> getOutputOfWithIndex();
    // bool dataMalloc() {
    //     if (data == nullptr)
    //         data = new VType[size()];
    //     return data != nullptr;
    // }
    // const Dim &getDims() const { return dims; }
    // void setDims(const Dim &dms) { dims = dms; }
    // bool dataRand(int seed = 0) {
    //     if (data == nullptr)
    //         data = new VType[size()];
    //     if (!random_inited)
    //         initFastrand();
    //     // srand(seed);
    //     // faster rand generator; parallel
    //     size_t iEnd = size();
    //     // std::cerr << "Init beginned " << std::endl;
    // #pragma omp parallel for
    //     for (size_t i = 0; i < iEnd; ++i)
    //         data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) %
    //         10000;
    //     // std::cerr << "Init finished" << std::endl;
    //     computed = ComputedFull;
    //     return true;
    // }
    // bool setData(VType *dptr) {
    //     if (dptr == nullptr)
    //         return false;
    //     auto sz = size();
    // #pragma omp parallel for
    //     for (size_t i = 0; i < sz; ++i)
    //         data[i] = dptr[i];
    //     computed = ComputedFull;
    //     return true;
    // }
    // bool setScalar(VType val) {
    //     if (data == nullptr || !dims.empty())
    //         return false;
    //     data[0] = val;
    //     return true;
    // }
    // bool setData(const Dim &ds, VType val) {
    //     if (data == nullptr || ds.size() != dims.size())
    //         return false;
    //     data[getOffset(ds)] = val;
    //     return true;
    // }
    // bool setData(size_t pos, VType val) {
    //     if (data == nullptr || pos >= size())
    //         return false;
    //     data[pos] = val;
    //     return true;
    // }
    // VType getScalar() { return data == nullptr ? 0 : data[0]; }
    // VType getData(const Dim &ds) {
    //     assert(data != nullptr);
    //     auto offset = getOffset(ds);
    //     return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
    // }
    // VType getData(size_t pos) {
    //     assert(data != nullptr);
    //     assert(pos < size());
    //     return data[pos];
    // }
    // VType *getDataPtr() const { return data; }
    // size_t getOffset(const Dim &ds) {
    //     auto nDim = ds.size();
    //     assert(dims.size() == nDim);
    //     if (ds.empty())
    //         return 0;
    //     for (size_t i = 0; i < nDim; ++i)
    //         if (ds[i] < 0 || ds[i] >= dims[i])
    //             return (size_t)-1;
    //     size_t idx = ds[0];
    //     size_t dm = 0;
    //     while (++dm < nDim)
    //         idx = idx * dims[dm] + ds[dm];
    //     return idx;
    // }
    // VType getBroadcastData(const Dim &ds) {
    //     assert(data != nullptr);
    //     auto offset = getBroadcastOffset(ds);
    //     return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
    // }
    // VType getBroadcastData(size_t pos) {
    //     assert(data != nullptr);
    //     return data[pos % size()];
    // }
    // size_t getBroadcastOffset(const Dim &ds) {
    //     assert(ds.size() >= dims.size());
    //     auto nDim = dims.size();
    //     auto nBroadcastDim = ds.size() - nDim;
    //     for (size_t i = 0; i < nDim; ++i)
    //         if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >=
    //         dims[i])
    //             return (size_t)-1;
    //     size_t idx = 0;
    //     for (size_t i = 0; i < nDim; ++i)
    //         idx = idx * dims[i] + ds[nBroadcastDim + i];
    //     return idx;
    // }
    // void itInit() { it = Dim(dims.size(), 0); }
    // void itReset() {
    //     itInit();
    //     for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
    //         it[i] = 0;
    // }
    // bool itValid() {
    //     if (it.size() != dims.size())
    //         return false;
    //     for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
    //         if (it[i] >= dims[i])
    //             return false;
    //     return true;
    // }
    // const Dim &itGet() { return it; }
    // void itNext() {
    //     auto p = it.size() - 1;
    //     it[p] += 1;
    //     while (p >= 1) {
    //         if (it[p] == dims[p]) {
    //             it[p] = 0;
    //             it[--p] += 1;
    //         } else
    //             break;
    //     }
    // }
    // size_t size() const {
    //     size_t sz = 1;
    //     auto dm = dims.size();
    //     while (dm > 0)
    //         sz *= dims[--dm];
    //     return sz;
    // }
    // TensorType getType() const { return type; }
    // void setType(TensorType ty) { type = ty; }
    // void print() {
    //     if (type == Invalid) {
    //         std::cout << "Invalid tensor" << std::endl;
    //         return;
    //     }
    //     if (data == nullptr || dims.size() == 0) {
    //         std::cout << "Empty tensor" << std::endl;
    //         return;
    //     }
    //     // TODO: can be uncommented after tensor's compute type is
    //     correctly set if (computed == NotComputed) {
    //         std::cout << "Uncomputed tensor" << std::endl;
    //         return;
    //     }
    //     std::cout << "Tensor: " << guid << std::endl;
    //     auto numDims = dims.size();
    //     auto dimSzVec = std::vector<int>(numDims, 1);
    //     dimSzVec[numDims - 1] = dims[numDims - 1];
    //     for (int i = numDims - 1; i != 0; --i)
    //         dimSzVec[i - 1] = dimSzVec[i] * dims[i - 1];
    //     for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
    //         for (size_t j = 0; j < numDims; ++j) {
    //             if (i % dimSzVec[j] == 0) {
    //                 std::cout << "[";
    //             }
    //         }
    //         std::cout << data[i];
    //         for (size_t j = 0; j < numDims; ++j) {
    //             if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) {
    //                 std::cout << "]";
    //             }
    //         }
    //         if (i != size() - 1)
    //             std::cout << ", ";
    //         if ((int)i % dimSzVec[numDims - 1] == dimSzVec[numDims - 1] -
    //         1)
    //             std::cout << std::endl;
    //     }
    // }
    // static inline void initFastrand() {
    //     assert(omp_get_max_threads() <= 256);
    //     // srand(0); // constant seed for test
    //     // align random_seed to avoid false sharing
    //     for (int i = 0; i < 256 * 16; ++i) {
    //         // random_seed[i] = rand();
    //         // constant random seed for test
    //         random_seed[i] = i;
    //     }
    //     random_inited = true;
    // }
    // static inline int fastrand(int &g_seed) {
    //     g_seed = (214013 * g_seed + 2531011);
    //     return (g_seed >> 16) & 0x7FFF;
    // }
    // std::vector<std::vector<int>> const *getSplittingPoints() const {
    //     assert(!splittingPoints.empty());
    //     return &splittingPoints;
    // }
    // bool setSplittingPoints(std::vector<std::vector<int>> value) {
    //     assert(!value.empty());
    //     splittingPoints = value;
    //     return true;
    // }
    // void printSplittingPoints() {
    //     if (splittingPoints.empty())
    //         printf("Empty SplittingPoints");
    //     else {
    //         printf("[");
    //         for (auto &vs : splittingPoints) {
    //             printf("[");
    //             for (auto v : vs)
    //                 printf("%2d,", v);
    //             printf("],");
    //         }
    //         printf("]");
    //     }
    // }
    // void initSplittingPoints() {
    //     splittingPoints.resize(getDims().size()); }
    // void printShape();
};
} // namespace it

3
include/test.h Normal file
View File

@ -0,0 +1,3 @@
#pragma once
#include "core/common.h"
#include "gtest/gtest.h"

15
src/core/graph.cc Normal file
View File

@ -0,0 +1,15 @@
#include "core/graph.h"
namespace it {
// Rebuilds the connectivity between tensors and operators after the graph is
// mutated. Currently a stub.
void GraphNode::updateConnection() {
    // TODO
}
// Debug summary of the graph; currently only the class tag.
string GraphNode::toString() const { return "GraphNode: "; }
} // namespace it

11
src/core/operator.cc Normal file
View File

@ -0,0 +1,11 @@
#include "core/operator.h"
namespace it {
// Debug summary of the operator; currently only the class tag.
string OperatorNode::toString() const { return "Operator: "; }
} // namespace it

8
src/core/tensor.cc Normal file
View File

@ -0,0 +1,8 @@
#include <core/tensor.h>
namespace it {
// Debug summary: class tag plus the object's guid (printed via its implicit
// integer conversion), matching the stream-based style of the sibling
// toString implementations.
string TensorBaseNode::toString() const {
    std::ostringstream oss;
    oss << "TensorBaseNode " << guid;
    return oss.str();
}
}; // namespace it

12
test/core/test_graph.cc Normal file
View File

@ -0,0 +1,12 @@
#include "core/graph.h"
#include "test.h"
namespace it {
// Smoke test: construct an empty graph, add one operator with no inputs or
// outputs, and print it (exercises make_ref, addOp, and Object::print).
TEST(Graph, build) {
    Graph g = make_ref<GraphNode>();
    g->addOp(make_ref<OperatorNode>(TensorVec{}, TensorVec{}));
    g->print();
}
} // namespace it