Add: Matmul operator

This commit is contained in:
Liyan Zheng 2022-08-05 12:50:34 +08:00
parent e6101b0336
commit 559be5866d
12 changed files with 642 additions and 160 deletions

View File

@ -29,4 +29,24 @@ using std::vector;
// Aliases // Aliases
using dtype = float; using dtype = float;
// Utilities
#define _CAT(A, B) A##B
#define _SELECT(NAME, NUM) _CAT(NAME##_, NUM)
#define _GET_COUNT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, COUNT, ...) COUNT
#define _VA_SIZE(...) _GET_COUNT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1)
#define _VA_SELECT(NAME, ...) _SELECT(NAME, _VA_SIZE(__VA_ARGS__))(__VA_ARGS__)
// Assert
#define _IT_ASSERT_2(name, info) \
(static_cast<bool>(name) \
? void(0) \
: throw std::runtime_error( \
std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \
"] Assertion failed (" + #name + "): " + #info))
#define _IT_ASSERT_1(name) _IT_ASSERT_2(name, "");
#define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__)
#define IT_TODO_HALT(...) IT_ASSERT(false, "Unimplemented")
#define IT_TODO_SKIP(...) puts("Unimplemented " __FILE__ ":" __LINE__)
} // namespace it } // namespace it

View File

@ -23,6 +23,12 @@ class GraphNode : public Object {
// TensorVec &getInputs(); // TensorVec &getInputs();
// TensorVec &getOutputs(); // TensorVec &getOutputs();
Tensor addTensor(Shape dim) {
Tensor tensor = make_ref<TensorNode>(dim);
tensors.emplace_back(tensor);
return tensor;
}
void updateConnection(); void updateConnection();
// TODO // TODO

View File

@ -1,5 +1,6 @@
#pragma once #pragma once
#include "core/common.h" #include "core/common.h"
#include "ref.h"
namespace it { namespace it {
@ -42,4 +43,12 @@ inline std::ostream &operator<<(std::ostream &os, const Object &obj) {
return os; return os;
} }
// Overload for Ref-wrapped Object
template <typename T,
typename std::enable_if_t<std::is_base_of_v<Object, T>> * = nullptr>
inline std::ostream &operator<<(std::ostream &os, const Ref<T> &obj) {
os << obj->toString();
return os;
} }
} // namespace it

View File

@ -3,19 +3,176 @@
namespace it { namespace it {
enum OpType {
Unknown = 0,
// linear
Conv = 100,
Matmul,
ConvTrans,
G2BMM,
GBMML,
Pad,
Slice,
Concat,
Split,
Transpose,
Extend,
MaxPool,
AvgPool,
Add,
Sub,
Mul,
Div,
Pow,
Gather,
ReduceMean,
Reshape,
Identity,
// element wise
BatchNorm = 200,
Softmax,
Activation,
Resize,
//
MemBound = 300,
};
class OpRegistry {
public:
std::string getOpName(OpType opType) {
#define FOP(op) \
case op: \
return #op
switch (opType) {
FOP(Unknown);
// linear
FOP(Conv);
FOP(Matmul);
FOP(ConvTrans);
FOP(G2BMM);
FOP(GBMML);
FOP(Pad);
FOP(Slice);
FOP(Concat);
FOP(Split);
FOP(Transpose);
FOP(Extend);
FOP(MaxPool);
FOP(AvgPool);
FOP(Add);
FOP(Sub);
FOP(Mul);
FOP(Div);
FOP(Pow);
FOP(Gather);
FOP(ReduceMean);
FOP(Reshape);
FOP(Identity);
// element wise
FOP(BatchNorm);
FOP(Softmax);
FOP(Activation);
//
FOP(MemBound);
default:
IT_ASSERT(false);
break;
}
#undef FOP
}
};
enum ActType {
None,
Relu,
Sigmoid,
Tanh,
};
class OperatorNode : public Object { class OperatorNode : public Object {
public:
protected: protected:
// OpType type; OpType type;
TensorVec inputs; TensorVec inputs;
TensorVec outputs; TensorVec outputs;
// vector<WRef<Operator>> predecessors; // vector<WRef<Operator>> predecessors;
// vector<WRef<Operator>> successors; // vector<WRef<Operator>> successors;
public: public:
OperatorNode(TensorVec inputs, TensorVec outputs) OperatorNode(TensorVec inputs, TensorVec outputs)
: inputs(inputs), outputs(outputs) {} : inputs(inputs), outputs(outputs) {}
string toString() const override; virtual vector<Shape> computeShape() const = 0;
// Operator(TensorVec inputs) : inputs(inputs) {}
virtual ~OperatorNode() {} public: // check Op type
bool isLinearOp() const { return type >= 100 && type < 200; }
bool isElementWiseOp() const { return type >= 200 && type < 300; }
bool isSplitOp() const { return type == Split; }
bool isConcatOp() const { return type == Concat; }
bool isComputeOp() const {
return type == Conv || type == Matmul || type == ConvTrans ||
type == G2BMM || type == GBMML;
}
bool isTransposeOp() const { return type == Transpose; }
bool isReshapeOp() const { return type == Reshape; }
bool isMemBoundOp() const {
return type == MemBound || type == Activation || type == Transpose;
}
public: // getter and setter
// TensorVec getInputs() { return inputs; }
const TensorVec &getInputs() const { return inputs; }
// TensorVec getOutputs() { return outputs; }
const TensorVec &getOutputs() const { return outputs; }
Tensor getInputs(size_t i) { return inputs.at(i); }
Tensor getOutput() const {
IT_ASSERT(outputs.size() == 1, "Unimplemented");
return outputs[0];
}
virtual int numInputs() const = 0;
virtual int numOutputs() const = 0;
}; };
class MatmulNode : public OperatorNode {
public:
struct MatmulArgs {
int b, m, n, k;
// PET assume a row-major tensor layout. transA=false means default
// dims, true means A should be transposed before matmul. This is in
// oppsite to column-major BLAS.
bool transA, transB;
ActType act;
};
private:
MatmulArgs args;
public:
MatmulNode(Tensor A, Tensor B, Tensor C, bool transA = false,
bool transB = false, Tensor bias = nullptr, ActType act = None);
std::string toString() const override;
vector<Shape> computeShape() const override;
int numInputs() const override { return 2; }
int numOutputs() const override { return 1; }
Tensor getBias() const { return inputs[2]; }
void setAct(ActType act) { this->args.act = act; }
ActType getAct() const { return args.act; }
bool getTransA() const { return args.transA; }
bool getTransB() const { return args.transB; }
MatmulArgs getArgs() const { return args; }
private:
// Q: whether to check the output? Since we can build an Op first and then
// assure output.
// Fix 1: make shape inference a static method.
bool checkValid(const TensorVec &inputs) const;
};
} // namespace it } // namespace it

View File

@ -1,5 +1,4 @@
#pragma once #pragma once
#include "common.h"
#include <functional> // hash #include <functional> // hash
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
@ -24,4 +23,13 @@ Ref<T> as(const Ref<U> &ref) {
return std::dynamic_pointer_cast<T>(ref); return std::dynamic_pointer_cast<T>(ref);
} }
template <typename T>
std::vector<WRef<T>> get_wref_vec(const std::vector<Ref<T>> &vec) {
std::vector<WRef<T>> wref_vec;
wref_vec.reserve(vec.size());
for (const auto &ref : vec)
wref_vec.emplace_back(ref);
return wref_vec;
}
} // namespace it } // namespace it

View File

@ -1,128 +1,32 @@
#pragma once #pragma once
#include "core/object.h" #include "core/tensor_base.h"
#include "core/ref.h"
namespace it { namespace it {
// class Tensor; // TODO: how to deal with this
class TensorBaseNode; using ShapeElem = int;
class OperatorNode; using Shape = vector<ShapeElem>;
class GraphNode; class TensorNode : public TensorBaseNode {
using TensorBase = Ref<TensorBaseNode>;
using Operator = Ref<OperatorNode>;
using Graph = Ref<GraphNode>;
using TensorVec = vector<TensorBase>;
using OpVec = vector<Operator>;
// using TensorMap = std::map<size_t, Tensor *>;
// using OpMap = std::map<size_t, Operator *>;
using VType = uint32_t;
// using SplittingPoints = std::vector<std::vector<int>>;
class TensorBaseNode : public Object {
public:
enum DataType {
Float32,
Int32,
};
// enum TensorType {
// Input,
// Weight,
// Invalid,
// NotCounted,
// };
// // TODO: is more compute state needed?
// enum ComputeState {
// NotComputed,
// // Allocated,
// // Initialized,
// // ComputedPartial,
// ComputedFull,
// };
private: private:
int hid; Shape shape;
// uint64_t hash;
// Shape shape;
int dim;
vector<WRef<TensorBase>> inputOf;
WRef<TensorBase> outputOf;
Ref<VType> data;
DataType dtype;
// ComputeState computed;
// static int random_seed[256 * 16];
// static bool random_inited;
public: public:
// Tensor(TensorType type = Input, DataType dtype = Float32) TensorNode(const Shape &shape, DataType dtype = DataType::Float32);
// : guid(generateGuid()), hash(generateHash()), outputOf(nullptr), virtual ~TensorNode() {}
// data(nullptr), dtype(dtype), type(type), computed(NotComputed) {}
// Tensor(const Dim &dims, TensorType type = Input, DataType dtype =
// Float32)
// : guid(generateGuid()), hash(generateHash()), dims(dims),
// outputOf(nullptr), data(nullptr), dtype(dtype), type(type),
// computed(NotComputed) {
// itInit();
// }
// Tensor(const Tensor &rhs) : Tensor(rhs.dims, rhs.type, rhs.dtype) {
// outputOf = nullptr;
// data = nullptr;
// hash = rhs.hash;
// dimPenalty = rhs.dimPenalty;
// itInit();
// }
// Tensor(VType scalar, TensorType type = Weight, DataType dtype = Float32)
// : guid(generateGuid()), hash(generateHash()), outputOf(nullptr),
// data(nullptr), dtype(dtype), type(type), computed(ComputedFull) {
// assert(size() == 1);
// dataMalloc();
// data[0] = scalar;
// }
virtual ~TensorBaseNode() {}
string toString() const override; string toString() const override;
// // inputOf and outputOf will not be cloned int size();
// Tensor *clone() {
// Tensor *t = new Tensor(*this);
// return t;
// }
// void clone(Tensor *t) { void dataMalloc(size_t size) {
// dims = t->dims; IT_ASSERT(data == nullptr);
// dtype = t->dtype; data = make_ref<vector<VType>>(size);
// type = t->type; }
// hash = t->hash;
// dimPenalty = t->dimPenalty;
// }
DataType getDType() const { return dtype; } Shape getDims() const { return shape; }
// uint64_t getHash() const { return hash; } size_t getOffset(const Shape &ds) const;
using TensorBaseNode::getData;
// void setInputOf(const OpVec &ops) { VType getData(const Shape &pos) const;
// inputOf.clear();
// for (const auto &op : ops)
// inputOf.emplace_back(op);
// }
// void addInputOf(Operator op) { inputOf.emplace_back(op); }
// void setOutputOf(Operator op) { outputOf = op; }
// const OpVec &getInputOf() { return inputOf; }
// Operator *getOutputOf() { return outputOf; }
// std::pair<Operator *, int> getOutputOfWithIndex();
// bool dataMalloc() {
// if (data == nullptr)
// data = new VType[size()];
// return data != nullptr;
// }
// const Dim &getDims() const { return dims; }
// void setDims(const Dim &dms) { dims = dms; } // void setDims(const Dim &dms) { dims = dms; }
// bool dataRand(int seed = 0) { // bool dataRand(int seed = 0) {
@ -177,35 +81,6 @@ class TensorBaseNode : public Object {
// VType getScalar() { return data == nullptr ? 0 : data[0]; } // VType getScalar() { return data == nullptr ? 0 : data[0]; }
// VType getData(const Dim &ds) {
// assert(data != nullptr);
// auto offset = getOffset(ds);
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
// }
// VType getData(size_t pos) {
// assert(data != nullptr);
// assert(pos < size());
// return data[pos];
// }
// VType *getDataPtr() const { return data; }
// size_t getOffset(const Dim &ds) {
// auto nDim = ds.size();
// assert(dims.size() == nDim);
// if (ds.empty())
// return 0;
// for (size_t i = 0; i < nDim; ++i)
// if (ds[i] < 0 || ds[i] >= dims[i])
// return (size_t)-1;
// size_t idx = ds[0];
// size_t dm = 0;
// while (++dm < nDim)
// idx = idx * dims[dm] + ds[dm];
// return idx;
// }
// VType getBroadcastData(const Dim &ds) { // VType getBroadcastData(const Dim &ds) {
// assert(data != nullptr); // assert(data != nullptr);
// auto offset = getBroadcastOffset(ds); // auto offset = getBroadcastOffset(ds);

329
include/core/tensor_base.h Normal file
View File

@ -0,0 +1,329 @@
#pragma once
#include "core/object.h"
#include "core/ref.h"
namespace it {
// class Tensor;
class TensorBaseNode;
class TensorNode;
class OperatorNode;
class GraphNode;
using TensorBase = Ref<TensorBaseNode>;
using Tensor = Ref<TensorNode>;
using Operator = Ref<OperatorNode>;
using Graph = Ref<GraphNode>;
using TensorVec = vector<Tensor>;
using OpVec = vector<Operator>;
using VType = uint32_t;
class TensorBaseNode : public Object {
public:
enum DataType {
Float32,
Int32,
};
// enum TensorType {
// Input,
// Weight,
// Invalid,
// NotCounted,
// };
// // TODO: is more compute state needed?
// enum ComputeState {
// NotComputed,
// // Allocated,
// // Initialized,
// // ComputedPartial,
// ComputedFull,
// };
protected:
int dim;
DataType dtype;
vector<WRef<TensorBaseNode>> inputOf;
WRef<TensorBaseNode> outputOf;
Ref<vector<VType>> data;
// ComputeState computed;
// static int random_seed[256 * 16];
// static bool random_inited;
public:
TensorBaseNode(int dim, DataType dtype);
virtual ~TensorBaseNode() {}
// Ref<vector<VType>> getDataPtr() const { return data; }
VType getData(size_t offset) const;
DataType getDType() const { return dtype; }
// uint64_t getHash() const { return hash; }
// void setInputOf(const OpVec &ops) {
// inputOf.clear();
// for (const auto &op : ops)
// inputOf.emplace_back(op);
// }
// void addInputOf(Operator op) { inputOf.emplace_back(op); }
// void setOutputOf(Operator op) { outputOf = op; }
// const OpVec &getInputOf() { return inputOf; }
// Operator *getOutputOf() { return outputOf; }
// std::pair<Operator *, int> getOutputOfWithIndex();
// bool dataMalloc() {
// if (data == nullptr)
// data = new VType[size()];
// return data != nullptr;
// }
// const Dim &getDims() const { return dims; }
// void setDims(const Dim &dms) { dims = dms; }
// bool dataRand(int seed = 0) {
// if (data == nullptr)
// data = new VType[size()];
// if (!random_inited)
// initFastrand();
// // srand(seed);
// // faster rand generator; parallel
// size_t iEnd = size();
// // std::cerr << "Init beginned " << std::endl;
// #pragma omp parallel for
// for (size_t i = 0; i < iEnd; ++i)
// data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) %
// 10000;
// // std::cerr << "Init finished" << std::endl;
// computed = ComputedFull;
// return true;
// }
// bool setData(VType *dptr) {
// if (dptr == nullptr)
// return false;
// auto sz = size();
// #pragma omp parallel for
// for (size_t i = 0; i < sz; ++i)
// data[i] = dptr[i];
// computed = ComputedFull;
// return true;
// }
// bool setScalar(VType val) {
// if (data == nullptr || !dims.empty())
// return false;
// data[0] = val;
// return true;
// }
// bool setData(const Dim &ds, VType val) {
// if (data == nullptr || ds.size() != dims.size())
// return false;
// data[getOffset(ds)] = val;
// return true;
// }
// bool setData(size_t pos, VType val) {
// if (data == nullptr || pos >= size())
// return false;
// data[pos] = val;
// return true;
// }
// VType getScalar() { return data == nullptr ? 0 : data[0]; }
// VType getData(const Dim &ds) {
// assert(data != nullptr);
// auto offset = getOffset(ds);
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
// }
// VType getData(size_t pos) {
// assert(data != nullptr);
// assert(pos < size());
// return data[pos];
// }
// VType *getDataPtr() const { return data; }
// size_t getOffset(const Dim &ds) {
// auto nDim = ds.size();
// assert(dims.size() == nDim);
// if (ds.empty())
// return 0;
// for (size_t i = 0; i < nDim; ++i)
// if (ds[i] < 0 || ds[i] >= dims[i])
// return (size_t)-1;
// size_t idx = ds[0];
// size_t dm = 0;
// while (++dm < nDim)
// idx = idx * dims[dm] + ds[dm];
// return idx;
// }
// VType getBroadcastData(const Dim &ds) {
// assert(data != nullptr);
// auto offset = getBroadcastOffset(ds);
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
// }
// VType getBroadcastData(size_t pos) {
// assert(data != nullptr);
// return data[pos % size()];
// }
// size_t getBroadcastOffset(const Dim &ds) {
// assert(ds.size() >= dims.size());
// auto nDim = dims.size();
// auto nBroadcastDim = ds.size() - nDim;
// for (size_t i = 0; i < nDim; ++i)
// if (ds[nBroadcastDim + i] < 0 || ds[nBroadcastDim + i] >=
// dims[i])
// return (size_t)-1;
// size_t idx = 0;
// for (size_t i = 0; i < nDim; ++i)
// idx = idx * dims[i] + ds[nBroadcastDim + i];
// return idx;
// }
// void itInit() { it = Dim(dims.size(), 0); }
// void itReset() {
// itInit();
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
// it[i] = 0;
// }
// bool itValid() {
// if (it.size() != dims.size())
// return false;
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
// if (it[i] >= dims[i])
// return false;
// return true;
// }
// const Dim &itGet() { return it; }
// void itNext() {
// auto p = it.size() - 1;
// it[p] += 1;
// while (p >= 1) {
// if (it[p] == dims[p]) {
// it[p] = 0;
// it[--p] += 1;
// } else
// break;
// }
// }
// size_t size() const {
// size_t sz = 1;
// auto dm = dims.size();
// while (dm > 0)
// sz *= dims[--dm];
// return sz;
// }
// TensorType getType() const { return type; }
// void setType(TensorType ty) { type = ty; }
// void print() {
// if (type == Invalid) {
// std::cout << "Invalid tensor" << std::endl;
// return;
// }
// if (data == nullptr || dims.size() == 0) {
// std::cout << "Empty tensor" << std::endl;
// return;
// }
// // TODO: can be uncommented after tensor's compute type is
// correctly set if (computed == NotComputed) {
// std::cout << "Uncomputed tensor" << std::endl;
// return;
// }
// std::cout << "Tensor: " << guid << std::endl;
// auto numDims = dims.size();
// auto dimSzVec = std::vector<int>(numDims, 1);
// dimSzVec[numDims - 1] = dims[numDims - 1];
// for (int i = numDims - 1; i != 0; --i)
// dimSzVec[i - 1] = dimSzVec[i] * dims[i - 1];
// for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
// for (size_t j = 0; j < numDims; ++j) {
// if (i % dimSzVec[j] == 0) {
// std::cout << "[";
// }
// }
// std::cout << data[i];
// for (size_t j = 0; j < numDims; ++j) {
// if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) {
// std::cout << "]";
// }
// }
// if (i != size() - 1)
// std::cout << ", ";
// if ((int)i % dimSzVec[numDims - 1] == dimSzVec[numDims - 1] -
// 1)
// std::cout << std::endl;
// }
// }
// static inline void initFastrand() {
// assert(omp_get_max_threads() <= 256);
// // srand(0); // constant seed for test
// // align random_seed to avoid false sharing
// for (int i = 0; i < 256 * 16; ++i) {
// // random_seed[i] = rand();
// // constant random seed for test
// random_seed[i] = i;
// }
// random_inited = true;
// }
// static inline int fastrand(int &g_seed) {
// g_seed = (214013 * g_seed + 2531011);
// return (g_seed >> 16) & 0x7FFF;
// }
// std::vector<std::vector<int>> const *getSplittingPoints() const {
// assert(!splittingPoints.empty());
// return &splittingPoints;
// }
// bool setSplittingPoints(std::vector<std::vector<int>> value) {
// assert(!value.empty());
// splittingPoints = value;
// return true;
// }
// void printSplittingPoints() {
// if (splittingPoints.empty())
// printf("Empty SplittingPoints");
// else {
// printf("[");
// for (auto &vs : splittingPoints) {
// printf("[");
// for (auto v : vs)
// printf("%2d,", v);
// printf("],");
// }
// printf("]");
// }
// }
// void initSplittingPoints() {
// splittingPoints.resize(getDims().size()); }
// void printShape();
};
} // namespace it

View File

@ -2,13 +2,13 @@
namespace it { namespace it {
void GraphNode::updateConnection() { void GraphNode::updateConnection() { IT_TODO_HALT(); }
// TODO
}
string GraphNode::toString() const { string GraphNode::toString() const {
std::ostringstream oss; std::ostringstream oss;
oss << "GraphNode: "; oss << "GraphNode operators:\n";
for (const auto &op : ops)
oss << op << "\n";
return oss.str(); return oss.str();
} }

View File

@ -2,10 +2,55 @@
namespace it { namespace it {
string OperatorNode::toString() const { vector<Shape> MatmulNode::computeShape() const {
std::ostringstream oss; Shape ret{args.b, args.m, args.n};
oss << "Operator: "; return {ret};
return oss.str(); }
MatmulNode::MatmulNode(Tensor A, Tensor B, Tensor C, bool transA, bool transB,
Tensor bias, ActType act)
: OperatorNode({A, B, bias}, {C}), args{.b = A->getDims()[0],
.m = transA ? A->getDims()[2]
: A->getDims()[1],
.n = transB ? B->getDims()[1]
: B->getDims()[2],
.k = transA ? A->getDims()[1]
: A->getDims()[2],
.transA = transA,
.transB = transB,
.act = act} {
IT_ASSERT(checkValid(inputs));
}
string MatmulNode::toString() const {
std::ostringstream os;
MatmulArgs args = getArgs();
os << "Matmul([" << (args.transA ? "A^T" : "A") << ","
<< (args.transB ? "B^T" : "B") << ",act=" << (int)args.act
<< "],A=" << inputs[0]->getGuid() << ",B=" << inputs[1]->getGuid()
<< ",C=" << outputs[0]->getGuid() << ")";
return os.str();
}
bool MatmulNode::checkValid(const TensorVec &inputs) const {
auto A = inputs[0], B = inputs[1];
// if (A->getType() == Tensor::Weight && B->getType() == Tensor::Weight)
// return false;
IT_ASSERT(A->getDims().size() == 3 && B->getDims().size() == 3);
IT_ASSERT(A->getDims()[0] == B->getDims()[0]);
IT_ASSERT((args.transA ? A->getDims()[1] : A->getDims()[2]) ==
(args.transB ? B->getDims()[2] : B->getDims()[1]));
// if (A->getDims().size() != 3 || B->getDims().size() != 3) {
// return false;
// }
// if (A->getDims()[0] != B->getDims()[0]) {
// return false;
// }
// if ((args.transA ? A->getDims()[1] : A->getDims()[2]) !=
// (args.transB ? B->getDims()[2] : B->getDims()[1])) {
// return false;
// }
return true;
} }
} // namespace it } // namespace it

View File

@ -1,8 +1,29 @@
#include <core/tensor.h> #include <core/tensor.h>
namespace it { namespace it {
string TensorBaseNode::toString() const { TensorNode::TensorNode(const Shape &shape, DataType dtype)
return "TensorBaseNode " + std::to_string(guid); : TensorBaseNode(shape.size(), dtype), shape(shape) {}
VType TensorNode::getData(const Shape &pos) const {
return getData(getOffset(pos));
}
string TensorNode::toString() const {
return "TensorNode " + std::to_string(guid);
}
size_t TensorNode::getOffset(const Shape &pos) const {
auto nDim = pos.size();
IT_ASSERT(shape.size() == nDim);
if (pos.empty())
return 0;
for (size_t i = 0; i < nDim; ++i)
IT_ASSERT(pos[i] < 0 || pos[i] >= shape[i]);
size_t idx = pos[0];
size_t dm = 0;
while (++dm < nDim)
idx = idx * shape[dm] + pos[dm];
return idx;
} }
}; // namespace it }; // namespace it

9
src/core/tensor_base.cc Normal file
View File

@ -0,0 +1,9 @@
#include <core/tensor_base.h>
namespace it {
TensorBaseNode::TensorBaseNode(int dim, DataType dtype)
: dim(dim), dtype(dtype) {}
VType TensorBaseNode::getData(size_t offset) const { return data->at(offset); }
}; // namespace it

View File

@ -5,7 +5,10 @@ namespace it {
TEST(Graph, build) { TEST(Graph, build) {
Graph g = make_ref<GraphNode>(); Graph g = make_ref<GraphNode>();
g->addOp(make_ref<OperatorNode>(TensorVec{}, TensorVec{})); Tensor i0 = g->addTensor({1, 2, 3});
Tensor w0 = g->addTensor({1, 3, 4});
Tensor o0 = g->addTensor({1, 2, 4});
g->addOp(make_ref<MatmulNode>(i0, w0, o0));
g->print(); g->print();
} }