From 559be5866dd7c5d25d2f09d2c9f92c656437f0d6 Mon Sep 17 00:00:00 2001
From: Liyan Zheng
Date: Fri, 5 Aug 2022 12:50:34 +0800
Subject: [PATCH] Add: Matmul operator

---
 include/core/common.h      |  20 +++
 include/core/graph.h       |   6 +
 include/core/object.h      |  11 +-
 include/core/operator.h    | 165 ++++++++++++++++++-
 include/core/ref.h         |  10 +-
 include/core/tensor.h      | 161 ++----------------
 include/core/tensor_base.h | 329 +++++++++++++++++++++++++++++++++++++
 src/core/graph.cc          |   8 +-
 src/core/operator.cc       |  53 +++++-
 src/core/tensor.cc         |  25 ++-
 src/core/tensor_base.cc    |   9 +
 test/core/test_graph.cc    |   5 +-
 12 files changed, 642 insertions(+), 160 deletions(-)
 create mode 100644 include/core/tensor_base.h
 create mode 100644 src/core/tensor_base.cc

diff --git a/include/core/common.h b/include/core/common.h
index f473db9e..801ce908 100644
--- a/include/core/common.h
+++ b/include/core/common.h
@@ -29,4 +29,24 @@ using std::vector;
 // Aliases
 using dtype = float;
 
+// Utilities
+#define _CAT(A, B) A##B
+#define _SELECT(NAME, NUM) _CAT(NAME##_, NUM)
+#define _GET_COUNT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, COUNT, ...) COUNT
+#define _VA_SIZE(...) _GET_COUNT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1)
+#define _VA_SELECT(NAME, ...) _SELECT(NAME, _VA_SIZE(__VA_ARGS__))(__VA_ARGS__)
+
+// Assert
+#define _IT_ASSERT_2(name, info)                                               \
+    (static_cast<bool>(name)                                                   \
+         ? void(0)                                                             \
+         : throw std::runtime_error(                                           \
+               std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) +  \
+               "] Assertion failed (" + #name + "): " + #info))
+#define _IT_ASSERT_1(name) _IT_ASSERT_2(name, "")
+
+#define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__)
+#define IT_TODO_HALT(...) IT_ASSERT(false, "Unimplemented")
+#define IT_TODO_SKIP(...) printf("Unimplemented %s:%d\n", __FILE__, __LINE__)
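+
+// A sketch of how the arity dispatch above expands (it supports up to 10
+// arguments, matching _GET_COUNT):
+//   IT_ASSERT(x)         -> _VA_SIZE(x) == 1         -> _IT_ASSERT_1(x)
+//   IT_ASSERT(x, "info") -> _VA_SIZE(x, "info") == 2 -> _IT_ASSERT_2(x, "info")
+// _VA_SIZE prepends its arguments to the descending list 10...1, so
+// _GET_COUNT always picks the actual argument count as its 11th parameter.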
+
 } // namespace it
diff --git a/include/core/graph.h b/include/core/graph.h
index 6f2f6495..a3db32bb 100644
--- a/include/core/graph.h
+++ b/include/core/graph.h
@@ -23,6 +23,12 @@ class GraphNode : public Object {
     // TensorVec &getInputs();
     // TensorVec &getOutputs();
 
+    Tensor addTensor(Shape dim) {
+        Tensor tensor = make_ref<TensorNode>(dim);
+        tensors.emplace_back(tensor);
+        return tensor;
+    }
+
     void updateConnection();
 
     // TODO
diff --git a/include/core/object.h b/include/core/object.h
index 2ce7db01..10a0d46d 100644
--- a/include/core/object.h
+++ b/include/core/object.h
@@ -1,5 +1,6 @@
 #pragma once
 #include "core/common.h"
+#include "ref.h"
 
 namespace it {
 
@@ -42,4 +43,12 @@ inline std::ostream &operator<<(std::ostream &os, const Object &obj) {
     return os;
 }
 
-}
\ No newline at end of file
+// Overload for Ref-wrapped Object
+template <typename T,
+          typename std::enable_if_t<std::is_base_of_v<Object, T>> * = nullptr>
+inline std::ostream &operator<<(std::ostream &os, const Ref<T> &obj) {
+    os << obj->toString();
+    return os;
+}
+
+} // namespace it
\ No newline at end of file
diff --git a/include/core/operator.h b/include/core/operator.h
index 467c5806..9565e78f 100644
--- a/include/core/operator.h
+++ b/include/core/operator.h
@@ -3,19 +3,176 @@
 
 namespace it {
 
+enum OpType {
+    Unknown = 0,
+    // linear
+    Conv = 100,
+    Matmul,
+    ConvTrans,
+    G2BMM,
+    GBMML,
+    Pad,
+    Slice,
+    Concat,
+    Split,
+    Transpose,
+    Extend,
+    MaxPool,
+    AvgPool,
+    Add,
+    Sub,
+    Mul,
+    Div,
+    Pow,
+    Gather,
+    ReduceMean,
+    Reshape,
+    Identity,
+    // element wise
+    BatchNorm = 200,
+    Softmax,
+    Activation,
+    Resize,
+    //
+    MemBound = 300,
+};
+
+class OpRegistry {
+  public:
+    std::string getOpName(OpType opType) {
+#define FOP(op)                                                                \
+    case op:                                                                   \
+        return #op
+
+        switch (opType) {
+            FOP(Unknown);
+            // linear
+            FOP(Conv);
+            FOP(Matmul);
+            FOP(ConvTrans);
+            FOP(G2BMM);
+            FOP(GBMML);
+            FOP(Pad);
+            FOP(Slice);
+            FOP(Concat);
+            FOP(Split);
+            FOP(Transpose);
+            FOP(Extend);
+            FOP(MaxPool);
+            FOP(AvgPool);
+            FOP(Add);
+            FOP(Sub);
+            FOP(Mul);
+            FOP(Div);
+            FOP(Pow);
+            FOP(Gather);
+            FOP(ReduceMean);
+            FOP(Reshape);
+            FOP(Identity);
+            // element wise
+            FOP(BatchNorm);
+            FOP(Softmax);
+            FOP(Activation);
+            FOP(Resize);
+            //
+            FOP(MemBound);
+        default:
+            IT_ASSERT(false);
+            break;
+        }
+#undef FOP
+    }
+};
+
+enum ActType {
+    None,
+    Relu,
+    Sigmoid,
+    Tanh,
+};
+
 class OperatorNode : public Object {
   protected:
-    // OpType type;
+    OpType type;
     TensorVec inputs;
     TensorVec outputs;
     // vector<WRef<Operator>> predecessors;
     // vector<WRef<Operator>> successors;
 
+  public:
-    OperatorNode(TensorVec inputs, TensorVec outputs)
-        : inputs(inputs), outputs(outputs) {}
-    string toString() const override;
-    // Operator(TensorVec inputs) : inputs(inputs) {}
-    virtual ~OperatorNode() {}
+    OperatorNode(OpType opType, TensorVec inputs, TensorVec outputs)
+        : type(opType), inputs(inputs), outputs(outputs) {}
+    virtual vector<Shape> computeShape() const = 0;
 
+  public: // check Op type
+    bool isLinearOp() const { return type >= 100 && type < 200; }
+    bool isElementWiseOp() const { return type >= 200 && type < 300; }
+    bool isSplitOp() const { return type == Split; }
+    bool isConcatOp() const { return type == Concat; }
+    bool isComputeOp() const {
+        return type == Conv || type == Matmul || type == ConvTrans ||
+               type == G2BMM || type == GBMML;
+    }
+    bool isTransposeOp() const { return type == Transpose; }
+
+    bool isReshapeOp() const { return type == Reshape; }
+
+    bool isMemBoundOp() const {
+        return type == MemBound || type == Activation || type == Transpose;
+    }
+
+  public: // getter and setter
+    // TensorVec getInputs() { return inputs; }
+    const TensorVec &getInputs() const { return inputs; }
+    // TensorVec getOutputs() { return outputs; }
+    const TensorVec &getOutputs() const { return outputs; }
+    Tensor getInputs(size_t i) { return inputs.at(i); }
+    Tensor getOutput() const {
+        IT_ASSERT(outputs.size() == 1, "Unimplemented");
+        return outputs[0];
+    }
+
+    virtual int numInputs() const = 0;
+    virtual int numOutputs() const = 0;
 };
+
+class MatmulNode : public OperatorNode {
+  public:
+    struct MatmulArgs {
+        int b, m, n, k;
+        // PET assumes a row-major tensor layout. transA=false means default
+        // dims; transA=true means A should be transposed before matmul. This
+        // is the opposite of the column-major BLAS convention.
+        bool transA, transB;
+        ActType act;
+    };
+
+  private:
+    MatmulArgs args;
+
+  public:
+    MatmulNode(Tensor A, Tensor B, Tensor C, bool transA = false,
+               bool transB = false, Tensor bias = nullptr, ActType act = None);
+
+    std::string toString() const override;
+    vector<Shape> computeShape() const override;
+
+    int numInputs() const override { return 2; }
+    int numOutputs() const override { return 1; }
+
+    Tensor getBias() const { return inputs[2]; }
+    void setAct(ActType act) { this->args.act = act; }
+    ActType getAct() const { return args.act; }
+    bool getTransA() const { return args.transA; }
+    bool getTransB() const { return args.transB; }
+
+    MatmulArgs getArgs() const { return args; }
+
+  private:
+    // Q: whether to check the output? Since we can build an Op first and then
+    // assign its output later.
+    // Fix 1: make shape inference a static method.
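+    // The shape convention above, restated as a sketch with concrete
+    // (hypothetical) sizes:
+    //   transA=false, transB=false: A:[b, m, k] x B:[b, k, n] -> C:[b, m, n]
+    //   transA=true:  A is stored as [b, k, m] and read as its transpose.
+    // checkValid() asserts that both inputs are 3-D, share the batch size b,
+    // and agree on the reduction dimension k under these flags.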
+    bool checkValid(const TensorVec &inputs) const;
+};
+
 } // namespace it
\ No newline at end of file
diff --git a/include/core/ref.h b/include/core/ref.h
index 7799d02a..34546e02 100644
--- a/include/core/ref.h
+++ b/include/core/ref.h
@@ -1,5 +1,4 @@
 #pragma once
-#include "common.h"
 #include <functional> // hash
 #include <memory>
 #include <type_traits>
@@ -24,4 +23,13 @@ Ref<T> as(const Ref<U> &ref) {
     return std::dynamic_pointer_cast<T>(ref);
 }
 
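+// Assuming Ref/WRef wrap std::shared_ptr/std::weak_ptr: back-references such
+// as TensorBaseNode::inputOf are kept weak so tensor and operator ownership
+// does not become cyclic. This helper converts an owning Ref vector into its
+// weak counterpart in one pass.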
+template <typename T>
+std::vector<WRef<T>> get_wref_vec(const std::vector<Ref<T>> &vec) {
+    std::vector<WRef<T>> wref_vec;
+    wref_vec.reserve(vec.size());
+    for (const auto &ref : vec)
+        wref_vec.emplace_back(ref);
+    return wref_vec;
+}
+
 } // namespace it
\ No newline at end of file
diff --git a/include/core/tensor.h b/include/core/tensor.h
index a7e49fcb..569d7294 100644
--- a/include/core/tensor.h
+++ b/include/core/tensor.h
@@ -1,129 +1,33 @@
 #pragma once
-#include "core/object.h"
-#include "core/ref.h"
+#include "core/tensor_base.h"
 
 namespace it {
 
-// class Tensor;
-class TensorBaseNode;
-class OperatorNode;
-class GraphNode;
-
-using TensorBase = Ref<TensorBaseNode>;
-using Operator = Ref<OperatorNode>;
-using Graph = Ref<GraphNode>;
-
-using TensorVec = vector<TensorBase>;
-using OpVec = vector<Operator>;
-
-// using TensorMap = std::map<int, TensorBase>;
-// using OpMap = std::map<int, Operator>;
-using VType = uint32_t;
-// using SplittingPoints = std::vector<std::vector<int>>;
-
-class TensorBaseNode : public Object {
-  public:
-    enum DataType {
-        Float32,
-        Int32,
-    };
-
-    // enum TensorType {
-    //     Input,
-    //     Weight,
-    //     Invalid,
-    //     NotCounted,
-    // };
-
-    // // TODO: is more compute state needed?
-    // enum ComputeState {
-    //     NotComputed,
-    //     // Allocated,
-    //     // Initialized,
-    //     // ComputedPartial,
-    //     ComputedFull,
-    // };
-
+// TODO: how to deal with this
+using ShapeElem = int;
+using Shape = vector<ShapeElem>;
+class TensorNode : public TensorBaseNode {
   private:
-    int hid;
-    // uint64_t hash;
-    // Shape shape;
-    int dim;
-
-    vector<WRef<OperatorNode>> inputOf;
-    WRef<OperatorNode> outputOf;
-    Ref<VType> data;
-    DataType dtype;
-    // ComputeState computed;
-    // static int random_seed[256 * 16];
-    // static bool random_inited;
+    Shape shape;
 
   public:
-    // Tensor(TensorType type = Input, DataType dtype = Float32)
-    //     : guid(generateGuid()), hash(generateHash()), outputOf(nullptr),
-    //       data(nullptr), dtype(dtype), type(type), computed(NotComputed) {}
-    // Tensor(const Dim &dims, TensorType type = Input, DataType dtype =
-    // Float32)
-    //     : guid(generateGuid()), hash(generateHash()), dims(dims),
-    //       outputOf(nullptr), data(nullptr), dtype(dtype), type(type),
-    //       computed(NotComputed) {
-    //     itInit();
-    // }
-    // Tensor(const Tensor &rhs) : Tensor(rhs.dims, rhs.type, rhs.dtype) {
-    //     outputOf = nullptr;
-    //     data = nullptr;
-    //     hash = rhs.hash;
-    //     dimPenalty = rhs.dimPenalty;
-    //     itInit();
-    // }
-    // Tensor(VType scalar, TensorType type = Weight, DataType dtype = Float32)
-    //     : guid(generateGuid()), hash(generateHash()), outputOf(nullptr),
-    //       data(nullptr), dtype(dtype), type(type), computed(ComputedFull) {
-    //     assert(size() == 1);
-    //     dataMalloc();
-    //     data[0] = scalar;
-    // }
-    virtual ~TensorBaseNode() {}
+    TensorNode(const Shape &shape, DataType dtype = DataType::Float32);
+    virtual ~TensorNode() {}
 
     string toString() const override;
 
-    // // inputOf and outputOf will not be cloned
-    // Tensor *clone() {
-    //     Tensor *t = new Tensor(*this);
-    //     return t;
-    // }
+    int size();
 
-    // void clone(Tensor *t) {
-    //     dims = t->dims;
-    //     dtype = t->dtype;
-    //     type = t->type;
-    //     hash = t->hash;
-    //     dimPenalty = t->dimPenalty;
-    // }
+    void dataMalloc(size_t size) {
+        IT_ASSERT(data == nullptr);
+        data = make_ref<vector<VType>>(size);
+    }
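+    // Usage sketch (hypothetical sizes): the argument is an element count,
+    // not a byte count; for a tensor of shape {2, 3}:
+    //   t->dataMalloc(6); // allocates 6 VType elements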
-    DataType getDType() const { return dtype; }
+    Shape getDims() const { return shape; }
 
-    // uint64_t getHash() const { return hash; }
-
-    // void setInputOf(const OpVec &ops) {
-    //     inputOf.clear();
-    //     for (const auto &op : ops)
-    //         inputOf.emplace_back(op);
-    // }
-    // void addInputOf(Operator op) { inputOf.emplace_back(op); }
-    // void setOutputOf(Operator op) { outputOf = op; }
-
-    // const OpVec &getInputOf() { return inputOf; }
-    // Operator *getOutputOf() { return outputOf; }
-    // std::pair<Operator *, int> getOutputOfWithIndex();
-
-    // bool dataMalloc() {
-    //     if (data == nullptr)
-    //         data = new VType[size()];
-    //     return data != nullptr;
-    // }
-
-    // const Dim &getDims() const { return dims; }
-    // void setDims(const Dim &dms) { dims = dms; }
+    size_t getOffset(const Shape &ds) const;
 
+    using TensorBaseNode::getData;
+    VType getData(const Shape &pos) const;
+    // void setDims(const Dim &dms) { dims = dms; }
 
     // bool dataRand(int seed = 0) {
     //     if (data == nullptr)
@@ -177,35 +81,6 @@ class TensorBaseNode : public Object {
 
     // VType getScalar() { return data == nullptr ? 0 : data[0]; }
 
-    // VType getData(const Dim &ds) {
-    //     assert(data != nullptr);
-    //     auto offset = getOffset(ds);
-    //     return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
-    // }
-
-    // VType getData(size_t pos) {
-    //     assert(data != nullptr);
-    //     assert(pos < size());
-    //     return data[pos];
-    // }
-
-    // VType *getDataPtr() const { return data; }
-
-    // size_t getOffset(const Dim &ds) {
-    //     auto nDim = ds.size();
-    //     assert(dims.size() == nDim);
-    //     if (ds.empty())
-    //         return 0;
-    //     for (size_t i = 0; i < nDim; ++i)
-    //         if (ds[i] < 0 || ds[i] >= dims[i])
-    //             return (size_t)-1;
-    //     size_t idx = ds[0];
-    //     size_t dm = 0;
-    //     while (++dm < nDim)
-    //         idx = idx * dims[dm] + ds[dm];
-    //     return idx;
-    // }
-
     // VType getBroadcastData(const Dim &ds) {
     //     assert(data != nullptr);
     //     auto offset = getBroadcastOffset(ds);
diff --git a/include/core/tensor_base.h b/include/core/tensor_base.h
new file mode 100644
index 00000000..368c3a65
--- /dev/null
+++ b/include/core/tensor_base.h
@@ -0,0 +1,329 @@
+#pragma once
+#include "core/object.h"
+#include "core/ref.h"
+
+namespace it {
+
+// class Tensor;
+class TensorBaseNode;
+class TensorNode;
+class OperatorNode;
+class GraphNode;
+
+using TensorBase = Ref<TensorBaseNode>;
+using Tensor = Ref<TensorNode>;
+using Operator = Ref<OperatorNode>;
+using Graph = Ref<GraphNode>;
+
+using TensorVec = vector<Tensor>;
+using OpVec = vector<Operator>;
+
+using VType = uint32_t;
+
+class TensorBaseNode : public Object {
+  public:
+    enum DataType {
+        Float32,
+        Int32,
+    };
+
+    // enum TensorType {
+    //     Input,
+    //     Weight,
+    //     Invalid,
+    //     NotCounted,
+    // };
+
+    // // TODO: is more compute state needed?
+    // enum ComputeState {
+    //     NotComputed,
+    //     // Allocated,
+    //     // Initialized,
+    //     // ComputedPartial,
+    //     ComputedFull,
+    // };
+
+  protected:
+    int dim;
+
+    DataType dtype;
+    vector<WRef<OperatorNode>> inputOf;
+    WRef<OperatorNode> outputOf;
+    Ref<vector<VType>> data;
+    // ComputeState computed;
+    // static int random_seed[256 * 16];
+    // static bool random_inited;
+
+  public:
+    TensorBaseNode(int dim, DataType dtype);
+    virtual ~TensorBaseNode() {}
+
+    // Ref<vector<VType>> getDataPtr() const { return data; }
+    VType getData(size_t offset) const;
+
+    DataType getDType() const { return dtype; }
+
+    // uint64_t getHash() const { return hash; }
+
+    // void setInputOf(const OpVec &ops) {
+    //     inputOf.clear();
+    //     for (const auto &op : ops)
+    //         inputOf.emplace_back(op);
+    // }
+    // void addInputOf(Operator op) { inputOf.emplace_back(op); }
+    // void setOutputOf(Operator op) { outputOf = op; }
+
+    // const OpVec &getInputOf() { return inputOf; }
+    // Operator *getOutputOf() { return outputOf; }
+    // std::pair<Operator *, int> getOutputOfWithIndex();
+
+    // bool dataMalloc() {
+    //     if (data == nullptr)
+    //         data = new VType[size()];
+    //     return data != nullptr;
+    // }
+
+    // const Dim &getDims() const { return dims; }
+    // void setDims(const Dim &dms) { dims = dms; }
+
+    // bool dataRand(int seed = 0) {
+    //     if (data == nullptr)
+    //         data = new VType[size()];
+    //     if (!random_inited)
+    //         initFastrand();
+    //     // srand(seed);
+    //     // faster rand generator; parallel
+    //     size_t iEnd = size();
+    //     // std::cerr << "Init beginned " << std::endl;
+    // #pragma omp parallel for
+    //     for (size_t i = 0; i < iEnd; ++i)
+    //         data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) %
+    //         10000;
+    //     // std::cerr << "Init finished" << std::endl;
+    //     computed = ComputedFull;
+    //     return true;
+    // }
+
+    // bool setData(VType *dptr) {
+    //     if (dptr == nullptr)
+    //         return false;
+    //     auto sz = size();
+    // #pragma omp parallel for
+    //     for (size_t i = 0; i < sz; ++i)
+    //         data[i] = dptr[i];
+    //     computed = ComputedFull;
+    //     return true;
+    // }
+
+    // bool setScalar(VType val) {
+    //     if (data == nullptr || !dims.empty())
+    //         return false;
+    //     data[0] = val;
+    //     return true;
+    // }
+
+    // bool setData(const Dim &ds, VType val) {
+    //     if (data == nullptr || ds.size() != dims.size())
+    //         return false;
+    //     data[getOffset(ds)] = val;
+    //     return true;
+    // }
+
+    // bool setData(size_t pos, VType val) {
+    //     if (data == nullptr || pos >= size())
+    //         return false;
+    //     data[pos] = val;
+    //     return true;
+    // }
+
+    // VType getScalar() { return data == nullptr ? 0 : data[0]; }
+
+    // VType getData(const Dim &ds) {
+    //     assert(data != nullptr);
+    //     auto offset = getOffset(ds);
+    //     return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
+    // }
+
+    // VType getData(size_t pos) {
+    //     assert(data != nullptr);
+    //     assert(pos < size());
+    //     return data[pos];
+    // }
+
+    // VType *getDataPtr() const { return data; }
+
+    // size_t getOffset(const Dim &ds) {
+    //     auto nDim = ds.size();
+    //     assert(dims.size() == nDim);
+    //     if (ds.empty())
+    //         return 0;
+    //     for (size_t i = 0; i < nDim; ++i)
+    //         if (ds[i] < 0 || ds[i] >= dims[i])
+    //             return (size_t)-1;
+    //     size_t idx = ds[0];
+    //     size_t dm = 0;
+    //     while (++dm < nDim)
+    //         idx = idx * dims[dm] + ds[dm];
+    //     return idx;
+    // }
+
+    // VType getBroadcastData(const Dim &ds) {
+    //     assert(data != nullptr);
+    //     auto offset = getBroadcastOffset(ds);
+    //     return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
+    // }
+
+    // VType getBroadcastData(size_t pos) {
+    //     assert(data != nullptr);
+    //     return data[pos % size()];
+    // }
+
+    // size_t getBroadcastOffset(const Dim &ds) {
+    //     assert(ds.size() >= dims.size());
+    //     auto nDim = dims.size();
+    //     auto nBroadcastDim = ds.size() - nDim;
+    //     for (size_t i = 0; i < nDim; ++i)
+    //         if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >=
+    //             dims[i])
+    //             return (size_t)-1;
+    //     size_t idx = 0;
+    //     for (size_t i = 0; i < nDim; ++i)
+    //         idx = idx * dims[i] + ds[nBroadcastDim + i];
+    //     return idx;
+    // }
+
+    // void itInit() { it = Dim(dims.size(), 0); }
+
+    // void itReset() {
+    //     itInit();
+    //     for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
+    //         it[i] = 0;
+    // }
+
+    // bool itValid() {
+    //     if (it.size() != dims.size())
+    //         return false;
+    //     for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
+    //         if (it[i] >= dims[i])
+    //             return false;
+    //     return true;
+    // }
+
+    // const Dim &itGet() { return it; }
+
+    // void itNext() {
+    //     auto p = it.size() - 1;
+    //     it[p] += 1;
+    //     while (p >= 1) {
+    //         if (it[p] == dims[p]) {
+    //             it[p] = 0;
+    //             it[--p] += 1;
+    //         } else
+    //             break;
+    //     }
+    // }
+
+    // size_t size() const {
+    //     size_t sz = 1;
+    //     auto dm = dims.size();
+    //     while (dm > 0)
+    //         sz *= dims[--dm];
+    //     return sz;
+    // }
+
+    // TensorType getType() const { return type; }
+    // void setType(TensorType ty) { type = ty; }
+
+    // void print() {
+    //     if (type == Invalid) {
+    //         std::cout << "Invalid tensor" << std::endl;
+    //         return;
+    //     }
+
+    //     if (data == nullptr || dims.size() == 0) {
+    //         std::cout << "Empty tensor" << std::endl;
+    //         return;
+    //     }
+
+    //     // TODO: can be uncommented after tensor's compute type is
+    //     // correctly set
+    //     if (computed == NotComputed) {
+    //         std::cout << "Uncomputed tensor" << std::endl;
+    //         return;
+    //     }
+
+    //     std::cout << "Tensor: " << guid << std::endl;
+    //     auto numDims = dims.size();
+    //     auto dimSzVec = std::vector<int>(numDims, 1);
+    //     dimSzVec[numDims - 1] = dims[numDims - 1];
+    //     for (int i = numDims - 1; i != 0; --i)
+    //         dimSzVec[i - 1] = dimSzVec[i] * dims[i - 1];
+    //     for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
+    //         for (size_t j = 0; j < numDims; ++j) {
+    //             if (i % dimSzVec[j] == 0) {
+    //                 std::cout << "[";
+    //             }
+    //         }
+    //         std::cout << data[i];
+    //         for (size_t j = 0; j < numDims; ++j) {
+    //             if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) {
+    //                 std::cout << "]";
+    //             }
+    //         }
+    //         if (i != size() - 1)
+    //             std::cout << ", ";
+    //         if ((int)i % dimSzVec[numDims - 1] == dimSzVec[numDims - 1] - 1)
+    //             std::cout << std::endl;
+    //     }
+    // }
+
+    // static inline void initFastrand() {
+    //     assert(omp_get_max_threads() <= 256);
+    //     // srand(0); // constant seed for test
+    //     // align random_seed to avoid false sharing
+    //     for (int i = 0; i < 256 * 16; ++i) {
+    //         // random_seed[i] = rand();
+    //         // constant random seed for test
+    //         random_seed[i] = i;
+    //     }
+    //     random_inited = true;
+    // }
+
+    // static inline int fastrand(int &g_seed) {
+    //     g_seed = (214013 * g_seed + 2531011);
+    //     return (g_seed >> 16) & 0x7FFF;
+    // }
+
+    // std::vector<std::vector<int>> const *getSplittingPoints() const {
+    //     assert(!splittingPoints.empty());
+    //     return &splittingPoints;
+    // }
+
+    // bool setSplittingPoints(std::vector<std::vector<int>> value) {
+    //     assert(!value.empty());
+    //     splittingPoints = value;
+    //     return true;
+    // }
+
+    // void printSplittingPoints() {
+    //     if (splittingPoints.empty())
+    //         printf("Empty SplittingPoints");
+    //     else {
+    //         printf("[");
+    //         for (auto &vs : splittingPoints) {
+    //             printf("[");
+    //             for (auto v : vs)
+    //                 printf("%2d,", v);
+    //             printf("],");
+    //         }
+    //         printf("]");
+    //     }
+    // }
+
+    // void initSplittingPoints() {
+    //     splittingPoints.resize(getDims().size()); }
+
+    // void printShape();
+};
+
+} // namespace it
\ No newline at end of file
diff --git a/src/core/graph.cc b/src/core/graph.cc
index bf78d5a2..5b9bbe23 100644
--- a/src/core/graph.cc
+++ b/src/core/graph.cc
@@ -2,13 +2,13 @@
 
 namespace it {
 
-void GraphNode::updateConnection() {
-    // TODO
-}
+void GraphNode::updateConnection() { IT_TODO_HALT(); }
 
 string GraphNode::toString() const {
     std::ostringstream oss;
-    oss << "GraphNode: ";
+    oss << "GraphNode operators:\n";
+    for (const auto &op : ops)
+        oss << op << "\n";
     return oss.str();
 }
 
diff --git a/src/core/operator.cc b/src/core/operator.cc
index 5817d55c..003c24a8 100644
--- a/src/core/operator.cc
+++ b/src/core/operator.cc
@@ -2,10 +2,55 @@
 
 namespace it {
 
-string OperatorNode::toString() const {
-    std::ostringstream oss;
-    oss << "Operator: ";
-    return oss.str();
+vector<Shape> MatmulNode::computeShape() const {
+    Shape ret{args.b, args.m, args.n};
+    return {ret};
+}
+
+MatmulNode::MatmulNode(Tensor A, Tensor B, Tensor C, bool transA, bool transB,
+                       Tensor bias, ActType act)
+    : OperatorNode(Matmul, {A, B, bias}, {C}),
+      args{.b = A->getDims()[0],
+           .m = transA ? A->getDims()[2] : A->getDims()[1],
+           .n = transB ? B->getDims()[1] : B->getDims()[2],
+           .k = transA ? A->getDims()[1] : A->getDims()[2],
+           .transA = transA,
+           .transB = transB,
+           .act = act} {
+    IT_ASSERT(checkValid(inputs));
+}
+
+string MatmulNode::toString() const {
+    std::ostringstream os;
+    MatmulArgs args = getArgs();
+    os << "Matmul([" << (args.transA ? "A^T" : "A") << ","
+       << (args.transB ? "B^T" : "B") << ",act=" << (int)args.act
+       << "],A=" << inputs[0]->getGuid() << ",B=" << inputs[1]->getGuid()
+       << ",C=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+bool MatmulNode::checkValid(const TensorVec &inputs) const {
+    auto A = inputs[0], B = inputs[1];
+    // if (A->getType() == Tensor::Weight && B->getType() == Tensor::Weight)
+    //     return false;
+    IT_ASSERT(A->getDims().size() == 3 && B->getDims().size() == 3);
+    IT_ASSERT(A->getDims()[0] == B->getDims()[0]);
+    IT_ASSERT((args.transA ? A->getDims()[1] : A->getDims()[2]) ==
+              (args.transB ? B->getDims()[2] : B->getDims()[1]));
+    // if (A->getDims().size() != 3 || B->getDims().size() != 3) {
+    //     return false;
+    // }
+    // if (A->getDims()[0] != B->getDims()[0]) {
+    //     return false;
+    // }
+    // if ((args.transA ? A->getDims()[1] : A->getDims()[2]) !=
+    //     (args.transB ? B->getDims()[2] : B->getDims()[1])) {
+    //     return false;
+    // }
+    return true;
 }
 
 } // namespace it
\ No newline at end of file
diff --git a/src/core/tensor.cc b/src/core/tensor.cc
index 60710c34..9f20c224 100644
--- a/src/core/tensor.cc
+++ b/src/core/tensor.cc
@@ -1,8 +1,29 @@
 #include <core/tensor.h>
 
 namespace it {
 
-string TensorBaseNode::toString() const {
-    return "TensorBaseNode " + std::to_string(guid);
+TensorNode::TensorNode(const Shape &shape, DataType dtype)
+    : TensorBaseNode(shape.size(), dtype), shape(shape) {}
+
+VType TensorNode::getData(const Shape &pos) const {
+    return getData(getOffset(pos));
+}
+
+string TensorNode::toString() const {
+    return "TensorNode " + std::to_string(guid);
+}
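+
+// Worked example of the row-major offset computed below (hypothetical
+// values): for shape [2, 3, 4] and pos [1, 2, 3], the offset is
+// (1 * 3 + 2) * 4 + 3 = 23.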
+size_t TensorNode::getOffset(const Shape &pos) const {
+    auto nDim = pos.size();
+    IT_ASSERT(shape.size() == nDim);
+    if (pos.empty())
+        return 0;
+    for (size_t i = 0; i < nDim; ++i)
+        IT_ASSERT(pos[i] >= 0 && pos[i] < shape[i]);
+    size_t idx = pos[0];
+    size_t dm = 0;
+    while (++dm < nDim)
+        idx = idx * shape[dm] + pos[dm];
+    return idx;
 }
 
 }; // namespace it
\ No newline at end of file
diff --git a/src/core/tensor_base.cc b/src/core/tensor_base.cc
new file mode 100644
index 00000000..29bb40d3
--- /dev/null
+++ b/src/core/tensor_base.cc
@@ -0,0 +1,9 @@
+#include <core/tensor_base.h>
+
+namespace it {
+
+TensorBaseNode::TensorBaseNode(int dim, DataType dtype)
+    : dim(dim), dtype(dtype) {}
+
+VType TensorBaseNode::getData(size_t offset) const { return data->at(offset); }
+
+}; // namespace it
\ No newline at end of file
diff --git a/test/core/test_graph.cc b/test/core/test_graph.cc
index 41587725..54849783 100644
--- a/test/core/test_graph.cc
+++ b/test/core/test_graph.cc
@@ -5,7 +5,10 @@ namespace it {
 
 TEST(Graph, build) {
     Graph g = make_ref<GraphNode>();
-    g->addOp(make_ref<OperatorNode>(TensorVec{}, TensorVec{}));
+    Tensor i0 = g->addTensor({1, 2, 3});
+    Tensor w0 = g->addTensor({1, 3, 4});
+    Tensor o0 = g->addTensor({1, 2, 4});
+    g->addOp(make_ref<MatmulNode>(i0, w0, o0));
     g->print();
 }