diff --git a/include/core/common.h b/include/core/common.h index 0fe7344e..46eb6922 100644 --- a/include/core/common.h +++ b/include/core/common.h @@ -16,6 +16,7 @@ namespace infini { using std::list; using std::map; +using std::optional; using std::pair; using std::set; using std::string; @@ -27,7 +28,7 @@ using std::vector; // Aliases using dtype = float; -using HashType = size_t; // compatible with std::hash +using HashType = uint64_t; // compatible with std::hash // Metaprogramming utilities #define _CAT(A, B) A##B diff --git a/include/core/graph.h b/include/core/graph.h index 9c87310a..8cb8cea1 100644 --- a/include/core/graph.h +++ b/include/core/graph.h @@ -5,7 +5,7 @@ namespace infini { // TODO: graph should be attached to a context -class GraphNode : public Object { +class GraphObj : public Object { protected: TensorVec tensors; TensorVec inputs; @@ -16,7 +16,28 @@ class GraphNode : public Object { // Graph(OpVec oplist); string toString() const override; - void addOp(Operator op) { ops.push_back(op); }; + Tensor addTensor(Shape dim, DataType dtype = DataType::Int32); + + /** + * @brief Add an operator and create its outputs. Output tensor arguments + * should be empty Refs (e.g., nullptr). + */ + template Ref addOp(Args &&...args) { + Ref op = make_ref(this, std::forward(args)...); + ops.push_back(op); + return op; + } + + /** + * @brief Add an operator with its outputs specified. + */ + template + Ref addOpWithOutputs(Args &&...args) { + Ref op = make_ref(nullptr, std::forward(args)...); + ops.push_back(op); + return op; + } + const TensorVec &getTensors() const { return tensors; } const TensorVec &getInputs() const { return inputs; } const TensorVec &getOutputs() const { return outputs; } @@ -24,12 +45,6 @@ class GraphNode : public Object { // TensorVec &getInputs(); // TensorVec &getOutputs(); - Tensor addTensor(Shape dim, DataType dtype = DataType::Int32) { - Tensor tensor = make_ref(dim, dtype); - tensors.emplace_back(tensor); - return tensor; - } - void dataMalloc(); private: diff --git a/include/core/hash.h b/include/core/hash.h new file mode 100644 index 00000000..3963af91 --- /dev/null +++ b/include/core/hash.h @@ -0,0 +1,18 @@ +#include "core/common.h" + +namespace infini { + +inline HashType hashAppend(HashType a, HashType b) { + return (a * 10000019 + b * 10000079) % 2147483647; +} + +// inline HashType hashPack(HashType x) { return (x * 10000103) % 2147483647; } + +template inline HashType hashVector(const vector &vec) { + HashType ret = 0; + for (auto v : vec) + ret = hashAppend(ret, v); + return ret; +} + +} // namespace infini \ No newline at end of file diff --git a/include/core/operator.h b/include/core/operator.h index a24f240d..3be1c57f 100644 --- a/include/core/operator.h +++ b/include/core/operator.h @@ -127,9 +127,7 @@ struct OpPerfKey { } }; -class OperatorNode : public Object { - friend class Kernel; - +class OperatorObj : public Object { protected: OpType type; TensorVec inputs; @@ -138,10 +136,24 @@ class OperatorNode : public Object { // vector> successors; public: - OperatorNode(OpType opType, TensorVec inputs, TensorVec outputs) + OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs) : type(opType), inputs(inputs), outputs(outputs) {} - virtual vector computeShape() const = 0; - virtual OpPerfKey getOpPerfKey() const = 0; + virtual optional> + inferShape(const TensorVec &inputs) const = 0; + /** + * @brief Constructs outputs (if requried) and check whether the operator is + * valid. + * + * @param graph If graph is not nullptr, outputs should be created in this + * function. + */ + bool checkValid(GraphObj *graph); + OpPerfKey getOpPerfKey() const; + /** + * @brief Hash operator attributes. Input and output shapes are not + * considered. + */ + HashType hash() const; public: // check Op type bool isLinearOp() const; @@ -167,8 +179,22 @@ class OperatorNode : public Object { virtual int numInputs() const = 0; virtual int numOutputs() const = 0; - virtual HashType hash() const { IT_TODO_HALT(); } - virtual HashType hashWithShape() const { IT_TODO_HALT(); } + + protected: + optional> inferShape() const; + + private: + /** + * @brief The returned vector includes operator attributes, such as paddings + * in Conv and transpose in Matmul. However, the input and output shapes are + * not taken into consideration. + */ + virtual vector getOpAttrVector() const { IT_TODO_HALT(); } + /** + * @brief Besides operator attributes, the returned vector includes input + * and output shapes. + */ + virtual vector getWorkloadVector() const { IT_TODO_HALT(); } }; } // namespace infini diff --git a/include/core/tensor.h b/include/core/tensor.h index 67544753..e4bbc6c2 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -6,13 +6,13 @@ namespace infini { // TODO: how to deal with this using ShapeElem = int; using Shape = vector; -class TensorNode : public TensorBaseNode { +class TensorObj : public TensorBaseObj { private: Shape shape; public: - TensorNode(const Shape &shape, DataType dtype); - virtual ~TensorNode() {} + TensorObj(const Shape &shape, DataType dtype); + virtual ~TensorObj() {} string toString() const override; size_t size() const; @@ -21,7 +21,7 @@ class TensorNode : public TensorBaseNode { Shape getDims() const { return shape; } size_t getOffset(const Shape &ds) const; - using TensorBaseNode::getData; + using TensorBaseObj::getData; VType getData(const Shape &pos) const; void copyData(VType *dptr); void printData() const; diff --git a/include/core/tensor_base.h b/include/core/tensor_base.h index cafea062..eefd300f 100644 --- a/include/core/tensor_base.h +++ b/include/core/tensor_base.h @@ -5,15 +5,15 @@ namespace infini { // class Tensor; -class TensorBaseNode; -class TensorNode; -class OperatorNode; -class GraphNode; +class TensorBaseObj; +class TensorObj; +class OperatorObj; +class GraphObj; -using TensorBase = Ref; -using Tensor = Ref; -using Operator = Ref; -using Graph = Ref; +using TensorBase = Ref; +using Tensor = Ref; +using Operator = Ref; +using Graph = Ref; using TensorVec = vector; using OpVec = vector; @@ -25,7 +25,7 @@ enum class DataType { Int32, }; -class TensorBaseNode : public Object { +class TensorBaseObj : public Object { public: // enum TensorType { // Input, @@ -38,8 +38,8 @@ class TensorBaseNode : public Object { int dim; DataType dtype; - vector> inputOf; - WRef outputOf; + vector> inputOf; + WRef outputOf; // TODO: Ref -> Ref Ref data; // ComputeState computed; @@ -47,8 +47,8 @@ class TensorBaseNode : public Object { // static bool random_inited; public: - TensorBaseNode(int dim, DataType dtype); - virtual ~TensorBaseNode() {} + TensorBaseObj(int dim, DataType dtype); + virtual ~TensorBaseObj() {} Ref getDataPtr() const { return data; } VType getData(size_t offset) const; diff --git a/include/operators/matmul.h b/include/operators/matmul.h index b94dabe0..328756b0 100644 --- a/include/operators/matmul.h +++ b/include/operators/matmul.h @@ -3,26 +3,37 @@ namespace infini { -class MatmulNode : public OperatorNode { +class MatmulObj : public OperatorObj { private: - // InfiniTensor assume a row-major tensor layout. transA=false means default - // dims, true means A should be transposed before matmul. This is in - // oppsite to column-major BLAS. + // InfiniTensor assumes a row-major tensor layout. `transA`=false means + // default dims, true means A should be transposed before matmul. This is in + // oppsite to the column-major BLAS. bool transA, transB; ActType act; - // Auxiliary attributes + // Auxiliary attributes which are not a part of operator attributes. int b, m, n, k; public: - MatmulNode(Tensor A, Tensor B, Tensor C, bool transA = false, - bool transB = false, Tensor bias = nullptr, - ActType act = ActType::None); + /** + * @brief This comments show how operators is defined in InfiniTensor. The + * constructor can create output tensors for the operator or not, which + * depends on `graph`. + * + * @param graph If graph is not empty, create outputs in the constructor. + * Otherwise, check the provided shape with the results of `inferShape` in + * `checkValid`. + * @param C C is the output of Matmul. If outputs are going to be created in + * the constructor, C should be an empty Ref. + */ + MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, + bool transA = false, bool transB = false, Tensor bias = nullptr, + ActType act = ActType::None); std::string toString() const override; - vector computeShape() const override; + optional> inferShape(const TensorVec &inputs) const override; - int numInputs() const override { return 2; } + int numInputs() const override { return 3; } int numOutputs() const override { return 1; } Tensor getBias() const { return inputs[2]; } @@ -34,14 +45,9 @@ class MatmulNode : public OperatorNode { int getN() const { return n; } int getK() const { return k; } - HashType hashWithShape() const override; - OpPerfKey getOpPerfKey() const override; - private: - // Q: whether to check the output? Since we can build an Op first and then - // assure output. - // Fix 1: make shape inference a static method. But OpPerfKey are required. - bool checkValid(const TensorVec &inputs) const; + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; }; } // namespace infini diff --git a/src/core/graph.cc b/src/core/graph.cc index 0f6fb180..f8d85122 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -2,19 +2,25 @@ namespace infini { -void GraphNode::updateConnection() { IT_TODO_HALT(); } +void GraphObj::updateConnection() { IT_TODO_HALT(); } -string GraphNode::toString() const { +string GraphObj::toString() const { std::ostringstream oss; - oss << "GraphNode operators:\n"; + oss << "Graph operators:\n"; for (const auto &op : ops) oss << op << "\n"; return oss.str(); } -void GraphNode::dataMalloc() { +void GraphObj::dataMalloc() { for (auto &tensor : tensors) tensor->dataMalloc(); } +Tensor GraphObj::addTensor(Shape dim, DataType dtype) { + Tensor tensor = make_ref(dim, dtype); + tensors.emplace_back(tensor); + return tensor; +} + } // namespace infini \ No newline at end of file diff --git a/src/core/operator.cc b/src/core/operator.cc index b215cb8a..e81c004b 100644 --- a/src/core/operator.cc +++ b/src/core/operator.cc @@ -1,32 +1,77 @@ #include "core/operator.h" +#include "core/graph.h" +#include "core/hash.h" namespace infini { -bool OperatorNode::isLinearOp() const { +bool OperatorObj::isLinearOp() const { return enum_to_underlying(type) >= 100 && enum_to_underlying(type) < 200; } -bool OperatorNode::isElementWiseOp() const { +bool OperatorObj::isElementWiseOp() const { return enum_to_underlying(type) >= 200 && enum_to_underlying(type) < 300; } -bool OperatorNode::isSplitOp() const { return type == OpType::Split; } +bool OperatorObj::isSplitOp() const { return type == OpType::Split; } -bool OperatorNode::isConcatOp() const { return type == OpType::Concat; } +bool OperatorObj::isConcatOp() const { return type == OpType::Concat; } -bool OperatorNode::isComputeOp() const { +bool OperatorObj::isComputeOp() const { return type == OpType::Conv || type == OpType::Matmul || type == OpType::ConvTrans || type == OpType::G2BMM || type == OpType::GBMML; } -bool OperatorNode::isTransposeOp() const { return type == OpType::Transpose; } +bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; } -bool OperatorNode::isReshapeOp() const { return type == OpType::Reshape; } +bool OperatorObj::isReshapeOp() const { return type == OpType::Reshape; } -bool OperatorNode::isMemBoundOp() const { +bool OperatorObj::isMemBoundOp() const { return type == OpType::MemBound || type == OpType::Activation || type == OpType::Transpose; } +OpPerfKey OperatorObj::getOpPerfKey() const { + auto workloadVector = getWorkloadVector(); + // Calculate hash of workload, i.e. hash with shape. This is different from + // Operator::hash, which hashes operator attributes and ignores tensor + // shapes. + HashType hash = 0; + hash = hashAppend(hash, enum_to_underlying(type)); + hash = hashAppend(hash, hashVector(workloadVector)); + return OpPerfKey(hash, type, workloadVector); +} + +HashType OperatorObj::hash() const { + HashType hash = 0; + hash = hashAppend(hash, enum_to_underlying(type)); + hash = hashAppend(hash, hashVector(getOpAttrVector())); + return hash; +} + +bool OperatorObj::checkValid(GraphObj *graph) { + auto optShapes = inferShape(); + if (!optShapes) // shape inference failed + return false; + const vector &shapes = *optShapes; + if (shapes.size() != outputs.size()) + return false; + if (graph) { // if graph != nullptr, outputs should be created + for (size_t i = 0; i < outputs.size(); i++) { + IT_ASSERT(!outputs[i]); + outputs[i] = graph->addTensor(shapes[i]); + } + } else { // if graph is not empty, check outputs match inferred shapes + for (size_t i = 0; i < shapes.size(); ++i) { + if (shapes[i] != outputs[i]->getDims()) + return false; + } + } + return true; +} + +optional> OperatorObj::inferShape() const { + return inferShape(inputs); +} + } // namespace infini \ No newline at end of file diff --git a/src/core/tensor.cc b/src/core/tensor.cc index 18460986..41aa0aac 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -1,24 +1,22 @@ #include namespace infini { -TensorNode::TensorNode(const Shape &shape, DataType dtype) - : TensorBaseNode(shape.size(), dtype), shape(shape) {} +TensorObj::TensorObj(const Shape &shape, DataType dtype) + : TensorBaseObj(shape.size(), dtype), shape(shape) {} -void TensorNode::dataMalloc() { +void TensorObj::dataMalloc() { IT_ASSERT(data == nullptr); // initialized to zero data.reset(reinterpret_cast(calloc(size(), sizeof(VType)))); } -VType TensorNode::getData(const Shape &pos) const { +VType TensorObj::getData(const Shape &pos) const { return getData(getOffset(pos)); } -string TensorNode::toString() const { - return "TensorNode " + std::to_string(guid); -} +string TensorObj::toString() const { return "Tensor " + std::to_string(guid); } -size_t TensorNode::getOffset(const Shape &pos) const { +size_t TensorObj::getOffset(const Shape &pos) const { auto nDim = pos.size(); IT_ASSERT(shape.size() == nDim); if (pos.empty()) @@ -32,14 +30,14 @@ size_t TensorNode::getOffset(const Shape &pos) const { return idx; } -size_t TensorNode::size() const { +size_t TensorObj::size() const { size_t ret = 1; for (const auto &d : shape) ret *= d; return ret; } -void TensorNode::copyData(VType *dptr) { +void TensorObj::copyData(VType *dptr) { IT_ASSERT(data != nullptr); size_t sz = size(); #pragma omp parallel for @@ -48,7 +46,7 @@ void TensorNode::copyData(VType *dptr) { } } -void TensorNode::printData() const { +void TensorObj::printData() const { IT_ASSERT(data != nullptr); std::cout << "Tensor: " << guid << std::endl; auto numDims = shape.size(); @@ -75,7 +73,7 @@ void TensorNode::printData() const { } } -bool TensorNode::equalData(const Tensor &rhs) const { +bool TensorObj::equalData(const Tensor &rhs) const { IT_ASSERT(data != nullptr); IT_ASSERT(rhs->data != nullptr); if (shape != rhs->getDims()) diff --git a/src/core/tensor_base.cc b/src/core/tensor_base.cc index 72297ce0..84cad959 100644 --- a/src/core/tensor_base.cc +++ b/src/core/tensor_base.cc @@ -1,9 +1,9 @@ #include namespace infini { -TensorBaseNode::TensorBaseNode(int dim, DataType dtype) +TensorBaseObj::TensorBaseObj(int dim, DataType dtype) : dim(dim), dtype(dtype) {} -VType TensorBaseNode::getData(size_t offset) const { return data[offset]; } +VType TensorBaseObj::getData(size_t offset) const { return data[offset]; } }; // namespace infini \ No newline at end of file diff --git a/src/kerels/cpu/matmul.cc b/src/kerels/cpu/matmul.cc index 84fa53a3..527fd66c 100644 --- a/src/kerels/cpu/matmul.cc +++ b/src/kerels/cpu/matmul.cc @@ -5,7 +5,7 @@ namespace infini { template class NaiveMatmul : public Kernel { void compute(const Operator &_op, const PerfRecord &record) const override { - auto op = as(_op); + auto op = as(_op); T *A = reinterpret_cast(op->getInputs(0)->getDataPtr().get()); T *B = reinterpret_cast(op->getInputs(1)->getDataPtr().get()); T *C = reinterpret_cast(op->getOutput()->getDataPtr().get()); diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc index 9f15bc5c..20f60914 100644 --- a/src/operators/matmul.cc +++ b/src/operators/matmul.cc @@ -2,19 +2,17 @@ namespace infini { -vector MatmulNode::computeShape() const { return {{b, m, n}}; } - -MatmulNode::MatmulNode(Tensor A, Tensor B, Tensor C, bool transA, bool transB, - Tensor bias, ActType act) - : OperatorNode(OpType::Matmul, {A, B, bias}, {C}), transA(transA), +MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA, + bool transB, Tensor bias, ActType act) + : OperatorObj(OpType::Matmul, {A, B, bias}, {C}), transA(transA), transB(transB), act(act), b(A->getDims()[0]), m(transA ? A->getDims()[2] : A->getDims()[1]), n(transB ? B->getDims()[1] : B->getDims()[2]), k(transA ? A->getDims()[1] : A->getDims()[2]) { - IT_ASSERT(checkValid(inputs)); + IT_ASSERT(checkValid(graph)); } -string MatmulNode::toString() const { +string MatmulObj::toString() const { std::ostringstream os; os << "Matmul([" << (transA ? "A^T" : "A") << "," << (transB ? "B^T" : "B") << ",act=" << enum_to_underlying(act) << "],A=" << inputs[0]->getGuid() @@ -23,34 +21,29 @@ string MatmulNode::toString() const { return os.str(); } -bool MatmulNode::checkValid(const TensorVec &inputs) const { +optional> MatmulObj::inferShape(const TensorVec &inputs) const { auto A = inputs[0], B = inputs[1]; // if (A->getType() == Tensor::Weight && B->getType() == Tensor::Weight) // return false; - IT_ASSERT(A->getDims().size() == 3 && B->getDims().size() == 3); - IT_ASSERT(A->getDims()[0] == B->getDims()[0]); - IT_ASSERT((transA ? A->getDims()[1] : A->getDims()[2]) == - (transB ? B->getDims()[2] : B->getDims()[1])); - // if (A->getDims().size() != 3 || B->getDims().size() != 3) { - // return false; - // } - // if (A->getDims()[0] != B->getDims()[0]) { - // return false; - // } - // if ((args.transA ? A->getDims()[1] : A->getDims()[2]) != - // (args.transB ? B->getDims()[2] : B->getDims()[1])) { - // return false; - // } - return true; + if (!(A->getDims().size() == 3 && B->getDims().size() == 3)) + return {}; + if (!(A->getDims()[0] == B->getDims()[0])) + return {}; + if (!((transA ? A->getDims()[1] : A->getDims()[2]) == + (transB ? B->getDims()[2] : B->getDims()[1]))) + return {}; + int b(A->getDims()[0]), m(transA ? A->getDims()[2] : A->getDims()[1]), + n(transB ? B->getDims()[1] : B->getDims()[2]); + return {{{b, m, n}}}; } -HashType MatmulNode::hashWithShape() const { - // TODO: use a real hash - return b + m + n + k + transA + transB + enum_to_underlying(act); +vector MatmulObj::getWorkloadVector() const { + return {enum_to_underlying(type), b, m, n, k, transA, transB, + enum_to_underlying(act)}; } -OpPerfKey MatmulNode::getOpPerfKey() const { - return OpPerfKey(hashWithShape(), type, - {b, m, n, k, transA, transB, enum_to_underlying(act)}); +vector MatmulObj::getOpAttrVector() const { + return {enum_to_underlying(type), transA, transB, enum_to_underlying(act)}; } + } // namespace infini \ No newline at end of file diff --git a/test/core/test_graph.cc b/test/core/test_graph.cc index b8a12333..21db982b 100644 --- a/test/core/test_graph.cc +++ b/test/core/test_graph.cc @@ -6,41 +6,41 @@ namespace infini { TEST(Graph, build_and_run) { - Graph g = make_ref(); + Graph g = make_ref(); Tensor i0 = g->addTensor({1, 2, 3}, DataType::Int32); Tensor w0 = g->addTensor({1, 3, 4}, DataType::Int32); Tensor o0 = g->addTensor({1, 2, 4}, DataType::Int32); g->dataMalloc(); i0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}.data()); w0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}.data()); - g->addOp(make_ref(i0, w0, o0)); + g->addOpWithOutputs(i0, w0, o0); RunEngine(Device::CPU).run(g); // check answer - auto ans = make_ref(Shape{1, 2, 4}, DataType::Int32); + auto ans = make_ref(Shape{1, 2, 4}, DataType::Int32); ans->dataMalloc(); ans->copyData(vector{38, 44, 50, 56, 83, 98, 113, 128}.data()); EXPECT_TRUE(o0->equalData(ans)); } TEST(Graph, perf_engine) { - Graph g = make_ref(); + Graph g = make_ref(); Tensor i0 = g->addTensor({1, 2, 3}, DataType::Int32); Tensor w0 = g->addTensor({1, 3, 4}, DataType::Int32); - Tensor o0 = g->addTensor({1, 2, 4}, DataType::Int32); + auto matmul = g->addOp(i0, w0, nullptr); + g->dataMalloc(); i0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}.data()); w0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}.data()); - g->addOp(make_ref(i0, w0, o0)); RunEngine(Device::CPU).run(g, true, true); double perfTime = RunEngine(Device::CPU).getPerfTime(g); // The example matmul takes 0.0036ms with one core EXPECT_GT(perfTime, 0); EXPECT_LT(perfTime, 0.01); // check answer - auto ans = make_ref(Shape{1, 2, 4}, DataType::Int32); + auto ans = make_ref(Shape{1, 2, 4}, DataType::Int32); ans->dataMalloc(); ans->copyData(vector{38, 44, 50, 56, 83, 98, 113, 128}.data()); - EXPECT_TRUE(o0->equalData(ans)); + EXPECT_TRUE(matmul->getOutput()->equalData(ans)); } } // namespace infini \ No newline at end of file diff --git a/test/core/test_hash.cc b/test/core/test_hash.cc new file mode 100644 index 00000000..22955ec0 --- /dev/null +++ b/test/core/test_hash.cc @@ -0,0 +1,31 @@ +#include "core/graph.h" +#include "core/run_enigne.h" +#include "operators/matmul.h" +#include "test.h" + +namespace infini { + +TEST(Hash, OperatorHash) { + OpPerfKey key1(0, OpType::Unknown), key2(0, OpType::Unknown); + { // build with addOpWithOutputs + Graph g = make_ref(); + Tensor i0 = g->addTensor({1, 2, 3}, DataType::Int32); + Tensor w0 = g->addTensor({1, 3, 4}, DataType::Int32); + Tensor o0 = g->addTensor({1, 2, 4}, DataType::Int32); + auto matmul = g->addOpWithOutputs(i0, w0, o0); + key1 = matmul->getOpPerfKey(); + EXPECT_NE(key1.hash, 0); + EXPECT_GT(key1.attrs.size(), 5); + } + { // build with addOp + Graph g = make_ref(); + Tensor i0 = g->addTensor({2, 2, 3}, DataType::Int32); + Tensor w0 = g->addTensor({2, 3, 4}, DataType::Int32); + auto matmul = g->addOp(i0, w0, nullptr); + key2 = matmul->getOpPerfKey(); + EXPECT_NE(key2.hash, 0); + } + EXPECT_NE(key1.hash, key2.hash); +} + +} // namespace infini \ No newline at end of file