Add: kernel registry and naive Matmul kernel

This commit is contained in:
Liyan Zheng 2022-08-06 15:58:40 +08:00
parent 559be5866d
commit 6c356d5b42
14 changed files with 267 additions and 178 deletions

View File

@ -46,7 +46,6 @@ if(BUILD_TEST)
endif()
file(GLOB_RECURSE SRC src/*.cc src/*.cu)
# file(GLOB_RECURSE TEST test/*.cc)
# file(GLOB_RECURSE FFI src/ffi/ffi_pet.cc)
# list(REMOVE_ITEM SRC ${TEST} ${FFI})
@ -62,7 +61,6 @@ add_library(InfiniTensor SHARED ${SRC})
if(BUILD_TEST)
enable_testing()
# Build all tests file( GLOB TEST_SOURCES test/test_sg2bmm.cc )
file(GLOB_RECURSE TEST_SOURCES test/*.cc)
foreach(testsourcefile ${TEST_SOURCES})
get_filename_component(testname ${testsourcefile} NAME_WE)

View File

@ -29,14 +29,14 @@ using std::vector;
// Aliases
using dtype = float;
// Utilities
// Metaprogramming utilities
#define _CAT(A, B) A##B
#define _SELECT(NAME, NUM) _CAT(NAME##_, NUM)
#define _GET_COUNT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, COUNT, ...) COUNT
#define _VA_SIZE(...) _GET_COUNT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1)
#define _VA_SELECT(NAME, ...) _SELECT(NAME, _VA_SIZE(__VA_ARGS__))(__VA_ARGS__)
// Assert
// Assert: conditions should have no side effect
#define _IT_ASSERT_2(name, info) \
(static_cast<bool>(name) \
? void(0) \
@ -49,4 +49,11 @@ using dtype = float;
#define IT_TODO_HALT(...) IT_ASSERT(false, "Unimplemented")
#define IT_TODO_SKIP(...) puts("Unimplemented " __FILE__ ":" __LINE__)
// Other utilities
// std::to_underlying is avaiable since C++23
template <typename T> auto enum_to_underlying(T e) {
return static_cast<std::underlying_type_t<T>>(e);
}
} // namespace it

View File

@ -23,13 +23,14 @@ class GraphNode : public Object {
// TensorVec &getInputs();
// TensorVec &getOutputs();
Tensor addTensor(Shape dim) {
Tensor tensor = make_ref<TensorNode>(dim);
Tensor addTensor(Shape dim, DataType dtype = DataType::Int32) {
Tensor tensor = make_ref<TensorNode>(dim, dtype);
tensors.emplace_back(tensor);
return tensor;
}
void updateConnection();
void dataMalloc();
// TODO
// bool compute();

56
include/core/kernel.h Normal file
View File

@ -0,0 +1,56 @@
#pragma once
#include "core/common.h"
#include "core/operator.h"
#include "core/tensor.h"
namespace it {
enum class Device { CPU = 1, CUDA };
class Kernel {
public:
Kernel() {}
virtual ~Kernel() {}
virtual void compute(const Operator &op) const = 0;
};
class KernelRegistry {
public:
using Key = std::tuple<Device, OpType, DataType>;
public:
~KernelRegistry() {
for (auto &[k, v] : kernels)
delete v;
}
static KernelRegistry &getInstance() {
static KernelRegistry instance;
return instance;
}
bool registerKernel(const Key &key, Kernel *kernel) {
// TODO: kernels with priority
IT_ASSERT(kernels.find(key) == kernels.end(),
"Kernel already registered");
kernels.emplace(key, kernel);
return true;
}
Kernel *getKernel(Device device, OpType opType, DataType dataType) const {
return kernels.at(Key{device, opType, dataType});
}
private:
std::map<Key, Kernel *> kernels;
};
#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, cnt) \
namespace it { \
static const bool _CAT(_register_kernel_, cnt) = \
KernelRegistry::getInstance().registerKernel( \
KernelRegistry::Key{device, opType, dataType}, new kernel()); \
}
#define REGISTER_KERNEL(device, opType, dataType, kernel) \
_REGISTER_KERNEL_1(device, opType, dataType, kernel, __COUNTER__)
} // namespace it

View File

@ -3,7 +3,7 @@
namespace it {
enum OpType {
enum class OpType {
Unknown = 0,
// linear
Conv = 100,
@ -41,7 +41,7 @@ class OpRegistry {
public:
std::string getOpName(OpType opType) {
#define FOP(op) \
case op: \
case OpType::op: \
return #op
switch (opType) {
@ -83,7 +83,7 @@ class OpRegistry {
}
};
enum ActType {
enum class ActType {
None,
Relu,
Sigmoid,
@ -100,26 +100,19 @@ class OperatorNode : public Object {
// vector<WRef<Operator>> successors;
public:
OperatorNode(TensorVec inputs, TensorVec outputs)
: inputs(inputs), outputs(outputs) {}
OperatorNode(OpType opType, TensorVec inputs, TensorVec outputs)
: type(opType), inputs(inputs), outputs(outputs) {}
virtual vector<Shape> computeShape() const = 0;
public: // check Op type
bool isLinearOp() const { return type >= 100 && type < 200; }
bool isElementWiseOp() const { return type >= 200 && type < 300; }
bool isSplitOp() const { return type == Split; }
bool isConcatOp() const { return type == Concat; }
bool isComputeOp() const {
return type == Conv || type == Matmul || type == ConvTrans ||
type == G2BMM || type == GBMML;
}
bool isTransposeOp() const { return type == Transpose; }
bool isReshapeOp() const { return type == Reshape; }
bool isMemBoundOp() const {
return type == MemBound || type == Activation || type == Transpose;
}
bool isLinearOp() const;
bool isElementWiseOp() const;
bool isSplitOp() const;
bool isConcatOp() const;
bool isComputeOp() const;
bool isTransposeOp() const;
bool isReshapeOp() const;
bool isMemBoundOp() const;
public: // getter and setter
// TensorVec getInputs() { return inputs; }
@ -131,6 +124,7 @@ class OperatorNode : public Object {
IT_ASSERT(outputs.size() == 1, "Unimplemented");
return outputs[0];
}
OpType getOpType() const { return type; }
virtual int numInputs() const = 0;
virtual int numOutputs() const = 0;
@ -152,7 +146,8 @@ class MatmulNode : public OperatorNode {
public:
MatmulNode(Tensor A, Tensor B, Tensor C, bool transA = false,
bool transB = false, Tensor bias = nullptr, ActType act = None);
bool transB = false, Tensor bias = nullptr,
ActType act = ActType::None);
std::string toString() const override;
vector<Shape> computeShape() const override;

24
include/core/run_enigne.h Normal file
View File

@ -0,0 +1,24 @@
#include "core/graph.h"
#include "core/kernel.h"
namespace it {
class RunEngine {
public:
RunEngine(Device device) : device(device) {}
~RunEngine() {}
void run(Graph graph) const {
const auto &kernelRegistry = KernelRegistry::getInstance();
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
Kernel *kernel = kernelRegistry.getKernel(device, op->getOpType(),
DataType::Int32);
kernel->compute(op);
}
}
private:
Device device;
};
} // namespace it

View File

@ -11,22 +11,21 @@ class TensorNode : public TensorBaseNode {
Shape shape;
public:
TensorNode(const Shape &shape, DataType dtype = DataType::Float32);
TensorNode(const Shape &shape, DataType dtype);
virtual ~TensorNode() {}
string toString() const override;
int size();
void dataMalloc(size_t size) {
IT_ASSERT(data == nullptr);
data = make_ref<vector<VType>>(size);
}
size_t size() const;
void dataMalloc();
Shape getDims() const { return shape; }
size_t getOffset(const Shape &ds) const;
using TensorBaseNode::getData;
VType getData(const Shape &pos) const;
void copyData(VType *dptr);
void printData() const;
bool equalData(const Tensor &rhs) const;
// void setDims(const Dim &dms) { dims = dms; }
// bool dataRand(int seed = 0) {
@ -47,17 +46,6 @@ class TensorNode : public TensorBaseNode {
// return true;
// }
// bool setData(VType *dptr) {
// if (dptr == nullptr)
// return false;
// auto sz = size();
// #pragma omp parallel for
// for (size_t i = 0; i < sz; ++i)
// data[i] = dptr[i];
// computed = ComputedFull;
// return true;
// }
// bool setScalar(VType val) {
// if (data == nullptr || !dims.empty())
// return false;
@ -137,60 +125,9 @@ class TensorNode : public TensorBaseNode {
// }
// }
// size_t size() const {
// size_t sz = 1;
// auto dm = dims.size();
// while (dm > 0)
// sz *= dims[--dm];
// return sz;
// }
// TensorType getType() const { return type; }
// void setType(TensorType ty) { type = ty; }
// void print() {
// if (type == Invalid) {
// std::cout << "Invalid tensor" << std::endl;
// return;
// }
// if (data == nullptr || dims.size() == 0) {
// std::cout << "Empty tensor" << std::endl;
// return;
// }
// // TODO: can be uncommented after tensor's compute type is
// correctly set if (computed == NotComputed) {
// std::cout << "Uncomputed tensor" << std::endl;
// return;
// }
// std::cout << "Tensor: " << guid << std::endl;
// auto numDims = dims.size();
// auto dimSzVec = std::vector<int>(numDims, 1);
// dimSzVec[numDims - 1] = dims[numDims - 1];
// for (int i = numDims - 1; i != 0; --i)
// dimSzVec[i - 1] = dimSzVec[i] * dims[i - 1];
// for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
// for (size_t j = 0; j < numDims; ++j) {
// if (i % dimSzVec[j] == 0) {
// std::cout << "[";
// }
// }
// std::cout << data[i];
// for (size_t j = 0; j < numDims; ++j) {
// if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) {
// std::cout << "]";
// }
// }
// if (i != size() - 1)
// std::cout << ", ";
// if ((int)i % dimSzVec[numDims - 1] == dimSzVec[numDims - 1] -
// 1)
// std::cout << std::endl;
// }
// }
// static inline void initFastrand() {
// assert(omp_get_max_threads() <= 256);
// // srand(0); // constant seed for test

View File

@ -20,13 +20,13 @@ using OpVec = vector<Operator>;
using VType = uint32_t;
enum class DataType {
Float32,
Int32,
};
class TensorBaseNode : public Object {
public:
enum DataType {
Float32,
Int32,
};
// enum TensorType {
// Input,
// Weight,
@ -49,7 +49,8 @@ class TensorBaseNode : public Object {
DataType dtype;
vector<WRef<TensorBaseNode>> inputOf;
WRef<TensorBaseNode> outputOf;
Ref<vector<VType>> data;
// TODO: use a blob instead of vector
Ref<VType[]> data;
// ComputeState computed;
// static int random_seed[256 * 16];
// static bool random_inited;
@ -58,7 +59,7 @@ class TensorBaseNode : public Object {
TensorBaseNode(int dim, DataType dtype);
virtual ~TensorBaseNode() {}
// Ref<vector<VType>> getDataPtr() const { return data; }
Ref<VType[]> getDataPtr() const { return data; }
VType getData(size_t offset) const;
DataType getDType() const { return dtype; }
@ -77,12 +78,6 @@ class TensorBaseNode : public Object {
// Operator *getOutputOf() { return outputOf; }
// std::pair<Operator *, int> getOutputOfWithIndex();
// bool dataMalloc() {
// if (data == nullptr)
// data = new VType[size()];
// return data != nullptr;
// }
// const Dim &getDims() const { return dims; }
// void setDims(const Dim &dms) { dims = dms; }
@ -104,17 +99,6 @@ class TensorBaseNode : public Object {
// return true;
// }
// bool setData(VType *dptr) {
// if (dptr == nullptr)
// return false;
// auto sz = size();
// #pragma omp parallel for
// for (size_t i = 0; i < sz; ++i)
// data[i] = dptr[i];
// computed = ComputedFull;
// return true;
// }
// bool setScalar(VType val) {
// if (data == nullptr || !dims.empty())
// return false;
@ -234,49 +218,6 @@ class TensorBaseNode : public Object {
// TensorType getType() const { return type; }
// void setType(TensorType ty) { type = ty; }
// void print() {
// if (type == Invalid) {
// std::cout << "Invalid tensor" << std::endl;
// return;
// }
// if (data == nullptr || dims.size() == 0) {
// std::cout << "Empty tensor" << std::endl;
// return;
// }
// // TODO: can be uncommented after tensor's compute type is
// correctly set if (computed == NotComputed) {
// std::cout << "Uncomputed tensor" << std::endl;
// return;
// }
// std::cout << "Tensor: " << guid << std::endl;
// auto numDims = dims.size();
// auto dimSzVec = std::vector<int>(numDims, 1);
// dimSzVec[numDims - 1] = dims[numDims - 1];
// for (int i = numDims - 1; i != 0; --i)
// dimSzVec[i - 1] = dimSzVec[i] * dims[i - 1];
// for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
// for (size_t j = 0; j < numDims; ++j) {
// if (i % dimSzVec[j] == 0) {
// std::cout << "[";
// }
// }
// std::cout << data[i];
// for (size_t j = 0; j < numDims; ++j) {
// if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) {
// std::cout << "]";
// }
// }
// if (i != size() - 1)
// std::cout << ", ";
// if ((int)i % dimSzVec[numDims - 1] == dimSzVec[numDims - 1] -
// 1)
// std::cout << std::endl;
// }
// }
// static inline void initFastrand() {
// assert(omp_get_max_threads() <= 256);
// // srand(0); // constant seed for test

View File

@ -12,4 +12,9 @@ string GraphNode::toString() const {
return oss.str();
}
void GraphNode::dataMalloc() {
for (auto &tensor : tensors)
tensor->dataMalloc();
}
} // namespace it

View File

@ -2,6 +2,33 @@
namespace it {
bool OperatorNode::isLinearOp() const {
return enum_to_underlying(type) >= 100 && enum_to_underlying(type) < 200;
}
bool OperatorNode::isElementWiseOp() const {
return enum_to_underlying(type) >= 200 && enum_to_underlying(type) < 300;
}
bool OperatorNode::isSplitOp() const { return type == OpType::Split; }
bool OperatorNode::isConcatOp() const { return type == OpType::Concat; }
bool OperatorNode::isComputeOp() const {
return type == OpType::Conv || type == OpType::Matmul ||
type == OpType::ConvTrans || type == OpType::G2BMM ||
type == OpType::GBMML;
}
bool OperatorNode::isTransposeOp() const { return type == OpType::Transpose; }
bool OperatorNode::isReshapeOp() const { return type == OpType::Reshape; }
bool OperatorNode::isMemBoundOp() const {
return type == OpType::MemBound || type == OpType::Activation ||
type == OpType::Transpose;
}
vector<Shape> MatmulNode::computeShape() const {
Shape ret{args.b, args.m, args.n};
return {ret};
@ -9,16 +36,14 @@ vector<Shape> MatmulNode::computeShape() const {
MatmulNode::MatmulNode(Tensor A, Tensor B, Tensor C, bool transA, bool transB,
Tensor bias, ActType act)
: OperatorNode({A, B, bias}, {C}), args{.b = A->getDims()[0],
.m = transA ? A->getDims()[2]
: A->getDims()[1],
.n = transB ? B->getDims()[1]
: B->getDims()[2],
.k = transA ? A->getDims()[1]
: A->getDims()[2],
.transA = transA,
.transB = transB,
.act = act} {
: OperatorNode(OpType::Matmul, {A, B, bias}, {C}),
args{.b = A->getDims()[0],
.m = transA ? A->getDims()[2] : A->getDims()[1],
.n = transB ? B->getDims()[1] : B->getDims()[2],
.k = transA ? A->getDims()[1] : A->getDims()[2],
.transA = transA,
.transB = transB,
.act = act} {
IT_ASSERT(checkValid(inputs));
}

View File

@ -4,6 +4,12 @@ namespace it {
TensorNode::TensorNode(const Shape &shape, DataType dtype)
: TensorBaseNode(shape.size(), dtype), shape(shape) {}
void TensorNode::dataMalloc() {
IT_ASSERT(data == nullptr);
// initialized to zero
data.reset(reinterpret_cast<VType *>(calloc(size(), sizeof(VType))));
}
VType TensorNode::getData(const Shape &pos) const {
return getData(getOffset(pos));
}
@ -26,4 +32,59 @@ size_t TensorNode::getOffset(const Shape &pos) const {
return idx;
}
size_t TensorNode::size() const {
size_t ret = 1;
for (const auto &d : shape)
ret *= d;
return ret;
}
void TensorNode::copyData(VType *dptr) {
IT_ASSERT(data != nullptr);
size_t sz = size();
#pragma omp parallel for
for (size_t i = 0; i < sz; ++i) {
data[i] = dptr[i];
}
}
void TensorNode::printData() const {
IT_ASSERT(data != nullptr);
std::cout << "Tensor: " << guid << std::endl;
auto numDims = shape.size();
auto dimSzVec = std::vector<int>(numDims, 1);
dimSzVec[numDims - 1] = shape[numDims - 1];
for (int i = numDims - 1; i != 0; --i)
dimSzVec[i - 1] = dimSzVec[i] * shape[i - 1];
for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
for (size_t j = 0; j < numDims; ++j) {
if (i % dimSzVec[j] == 0) {
std::cout << "[";
}
}
std::cout << data[i];
for (size_t j = 0; j < numDims; ++j) {
if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) {
std::cout << "]";
}
}
if (i != size() - 1)
std::cout << ", ";
if ((int)i % dimSzVec[numDims - 1] == dimSzVec[numDims - 1] - 1)
std::cout << std::endl;
}
}
bool TensorNode::equalData(const Tensor &rhs) const {
IT_ASSERT(data != nullptr);
IT_ASSERT(rhs->data != nullptr);
if (shape != rhs->getDims())
return false;
size_t sz = size();
for (size_t i = 0; i < sz; ++i)
if (data[i] != rhs->data[i])
return false;
return true;
}
}; // namespace it

View File

@ -4,6 +4,6 @@ namespace it {
TensorBaseNode::TensorBaseNode(int dim, DataType dtype)
: dim(dim), dtype(dtype) {}
VType TensorBaseNode::getData(size_t offset) const { return data->at(offset); }
VType TensorBaseNode::getData(size_t offset) const { return data[offset]; }
}; // namespace it

30
src/kerels/cpu/matmul.cc Normal file
View File

@ -0,0 +1,30 @@
#include "core/kernel.h"
namespace it {
template <typename T> class NaiveMatmul : public Kernel {
void compute(const Operator &_op) const override {
auto op = as<MatmulNode>(_op);
T *A = reinterpret_cast<T *>(op->getInputs(0)->getDataPtr().get());
T *B = reinterpret_cast<T *>(op->getInputs(1)->getDataPtr().get());
T *C = reinterpret_cast<T *>(op->getOutput()->getDataPtr().get());
const auto args = op->getArgs();
IT_ASSERT(args.transA == false && args.transB == false);
IT_ASSERT(args.act == ActType::None);
const int M = args.m, N = args.n, K = args.k;
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
for (int k = 0; k < K; k++) {
C[i * N + j] += A[i * K + k] * B[k * N + j];
}
}
}
}
};
REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Int32,
NaiveMatmul<uint32_t>);
REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Float32,
NaiveMatmul<float>);
} // namespace it

View File

@ -1,15 +1,24 @@
#include "core/graph.h"
#include "core/run_enigne.h"
#include "test.h"
namespace it {
TEST(Graph, build) {
Graph g = make_ref<GraphNode>();
Tensor i0 = g->addTensor({1, 2, 3});
Tensor w0 = g->addTensor({1, 3, 4});
Tensor o0 = g->addTensor({1, 2, 4});
Tensor i0 = g->addTensor({1, 2, 3}, DataType::Int32);
Tensor w0 = g->addTensor({1, 3, 4}, DataType::Int32);
Tensor o0 = g->addTensor({1, 2, 4}, DataType::Int32);
g->dataMalloc();
i0->copyData(vector<VType>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}.data());
w0->copyData(vector<VType>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}.data());
g->addOp(make_ref<MatmulNode>(i0, w0, o0));
g->print();
RunEngine(Device::CPU).run(g);
// check answer
auto ans = make_ref<TensorNode>(Shape{1, 2, 4}, DataType::Int32);
ans->dataMalloc();
ans->copyData(vector<VType>{38, 44, 50, 56, 83, 98, 113, 128}.data());
EXPECT_TRUE(o0->equalData(ans));
}
} // namespace it