Add add/mul/sub/div/pow operators and CPU/CUDA kernels (#26)

Fix some issues.

Remove unused code.

Add div/pow kernels.

Add add/mul/sub operators.

Fix CPU kernel.

Add element-wise kernel for CUDA.

Add element-wise operator.
wendy12022 2022-09-09 13:43:59 +08:00 committed by GitHub
parent 0409eafb5f
commit 13b7a2604b
11 changed files with 506 additions and 1 deletion


@@ -71,6 +71,8 @@ class TensorObj : public TensorBaseObj {
return equalDataImpl(getRawDataPtr<T *>(), dataVector.data(), size());
}
size_t getOffsetByBroadcastOffset(size_t bcOffset, Shape bcShape) const;
private:
void printDataFloat() const;
void printDataUint32_t() const;
@@ -92,6 +94,10 @@ class TensorObj : public TensorBaseObj {
}
return true;
}
Shape getPosByOffset(size_t offset, Shape dim) const;
size_t getOffsetByPos(Shape pos, Shape dim) const;
// void setDims(const Dim &dms) { dims = dms; }
// bool dataRand(int seed = 0) {


@@ -0,0 +1,25 @@
#pragma once
#include "operators/element_wise.h"
namespace infini {
void div_kernel(float *a, float *b, float *c, int num);
void pow_kernel(float *a, float *b, float *c, int num);
void element_wise_kernel(const Operator &_op) {
auto op = as<ElementWiseObj>(_op);
float *const aData = (op->getInputs(0)->getRawDataPtr<float *>());
float *const bData = (op->getInputs(1)->getRawDataPtr<float *>());
float *const cData = (op->getOutput()->getRawDataPtr<float *>());
auto dim = op->getInputs(0)->getDims();
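// Assumes 4-D inputs with identical shapes; n * c * h * w below is the
// flat element count handed to the raw kernels.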
int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
if (op->getOpType() == OpType::Div)
div_kernel(aData, bData, cData, n * c * h * w);
else if (op->getOpType() == OpType::Pow)
pow_kernel(aData, bData, cData, n * c * h * w);
else
IT_TODO_HALT();
}
}; // namespace infini


@@ -0,0 +1,33 @@
#pragma once
#include "core/operator.h"
namespace infini {
class ElementWiseObj : public OperatorObj {
public:
ElementWiseObj(OpType type, GraphObj *graph, Tensor input0, Tensor input1,
Tensor output);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
int numInputs() const override { return 2; }
int numOutputs() const override { return 1; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
#define DEFINE_ELEMENT_WISE_OBJ(prefix, type) \
class prefix##Obj : public ElementWiseObj { \
public: \
prefix##Obj(GraphObj *graph, Tensor input0, Tensor input1, \
Tensor output) \
: ElementWiseObj(type, graph, input0, input1, output) {} \
};
DEFINE_ELEMENT_WISE_OBJ(Add, OpType::Add)
DEFINE_ELEMENT_WISE_OBJ(Sub, OpType::Sub)
DEFINE_ELEMENT_WISE_OBJ(Mul, OpType::Mul)
DEFINE_ELEMENT_WISE_OBJ(Div, OpType::Div)
DEFINE_ELEMENT_WISE_OBJ(Pow, OpType::Pow)
}; // namespace infini
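For reference, each DEFINE_ELEMENT_WISE_OBJ invocation above expands to a thin subclass that only pins down the OpType; hand-expanding the Add case gives:

class AddObj : public ElementWiseObj {
  public:
    AddObj(GraphObj *graph, Tensor input0, Tensor input1, Tensor output)
        : ElementWiseObj(OpType::Add, graph, input0, input1, output) {}
};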


@@ -141,4 +141,35 @@ void TensorObj::copyData(const TensorObj *src) {
runtime->copyBlob(this, src);
}
Shape TensorObj::getPosByOffset(size_t offset, Shape dim) const {
Shape pos = dim;
for (int i = dim.size() - 1; i >= 0; i--) {
pos[i] = offset % dim.at(i);
offset = (offset - pos[i]) / dim.at(i);
}
return pos;
}
size_t TensorObj::getOffsetByPos(Shape pos, Shape dim) const {
int n = dim.size();
size_t offset = pos.at(0);
for (auto i = 1; i < n; i++) {
offset = offset * dim.at(i) + pos.at(i);
}
return offset;
}
size_t TensorObj::getOffsetByBroadcastOffset(size_t bcOffset,
Shape bcDim) const {
Shape bcPos = getPosByOffset(bcOffset, bcDim);
Shape pos = bcPos;
int n = shape.size();
for (auto i = 0; i < n; i++) {
if (shape.at(i) == 1)
pos[i] = 0;
}
return getOffsetByPos(pos, shape);
}
}; // namespace infini
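A worked example of the broadcast mapping above (illustrative, not part of the commit): reading a source tensor of shape {2, 1, 3} against a broadcast output of shape {2, 4, 3}, output offset 7 decodes to position {0, 2, 1}; the size-1 axis is clamped to 0, giving {0, 0, 1}, which re-encodes to source offset 1.

// Hypothetical usage sketch, assuming a CPU runtime as in the tests below:
Tensor t = make_ref<TensorObj>(Shape{2, 1, 3}, DataType::Float32,
                               CpuRuntimeObj::getInstance());
size_t srcOffset = t->getOffsetByBroadcastOffset(7, Shape{2, 4, 3});
// srcOffset == 1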


@@ -17,4 +17,5 @@ void cudaPrintFloat(float *x, int len) {
cudaPrintFloatImpl<<<1, 1>>>(x, len);
cudaDeviceSynchronize();
}
} // namespace infini


@@ -0,0 +1,68 @@
#include "operators/element_wise.h"
#include "core/kernel.h"
namespace infini {
template <typename T> class NativeElementWise : public Kernel {
virtual T doCompute(T val0, T val1) const = 0;
void compute(const Operator &_op, const PerfRecord &record,
const RuntimeObj *context) const override {
auto op = as<ElementWiseObj>(_op);
T *inptr0 = op->getInputs(0)->getRawDataPtr<T *>();
T *inptr1 = op->getInputs(1)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
auto outDim = op->getOutput()->getDims();
auto n = op->getOutput()->size();
for (size_t offset = 0; offset < n; offset++) {
// For now, we only process identical shapes here; broadcast will be
// considered in the opt layer.
/*auto offset0 =
op->getInputs(0)->getOffsetByBroadcastOffset(offset, outDim);
auto offset1 =
op->getInputs(1)->getOffsetByBroadcastOffset(offset, outDim);
outptr[offset] = doCompute(inptr0[offset0], inptr1[offset1]);*/
outptr[offset] = doCompute(inptr0[offset], inptr1[offset]);
}
}
void compute(const Operator &op, const RuntimeObj *context) const override {
compute(op, {}, context);
}
PerfRecord tune(const Operator &op,
const RuntimeObj *context) const override {
PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
return perfrcd;
}
};
template <typename T> class NaiveAdd : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return val0 + val1; }
};
template <typename T> class NaiveSub : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return val0 - val1; }
};
template <typename T> class NaiveMul : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return val0 * val1; }
};
template <typename T> class NaiveDiv : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 / val1); }
};
REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::UInt32, NaiveAdd<uint32_t>,
"addNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::Float32, NaiveAdd<float>,
"addNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::UInt32, NaiveSub<uint32_t>,
"subNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::Float32, NaiveSub<float>,
"subNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::UInt32, NaiveMul<uint32_t>,
"mulNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::Float32, NaiveMul<float>,
"mulNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::UInt32, NaiveDiv<uint32_t>,
"divNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::Float32, NaiveDiv<float>,
"divNaive_CPU_float32");
}; // namespace infini
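Note that the file above registers CPU kernels for Add/Sub/Mul/Div but not Pow, so in this commit Pow only runs on CUDA. A hypothetical NaivePow in the same style (not part of the commit; it would additionally need <cmath>) could look like:

template <typename T> class NaivePow : public NativeElementWise<T> {
    T doCompute(T val0, T val1) const override {
        return static_cast<T>(std::pow(val0, val1));
    }
};
REGISTER_KERNEL(Device::CPU, OpType::Pow, DataType::Float32, NaivePow<float>,
                "powNaive_CPU_float32");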


@@ -0,0 +1,124 @@
#include "operators/element_wise.h"
#include "core/kernel.h"
#include "cuda/cuda_element_wise.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class ElementWiseCudnn : public Kernel {
virtual cudnnOpTensorOp_t getOpType() const = 0;
virtual tuple<float, float, float> getAlphBeta() const {
return {1.f, 1.f, 0.f};
}
void compute(const Operator &_op, const PerfRecord &record,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
cudnnTensorDescriptor_t aDesc, bDesc, cDesc;
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4)
IT_TODO_HALT();
int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
// get inputs
checkCudnnError(cudnnCreateTensorDescriptor(&aDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
aDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
checkCudnnError(cudnnCreateTensorDescriptor(&bDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
bDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
// get outputs
checkCudnnError(cudnnCreateTensorDescriptor(&cDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
cDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
// get op descriptor
cudnnOpTensorDescriptor_t opDesc;
checkCudnnError(cudnnCreateOpTensorDescriptor(&opDesc));
checkCudnnError(cudnnSetOpTensorDescriptor(
opDesc, getOpType(), CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN));
auto [aAlpha, bAlpha, beta] = getAlphBeta();
cudnnStatus_t stat =
cudnnOpTensor(context->cudnnHandle(), opDesc, &aAlpha, aDesc, aData,
&bAlpha, bDesc, bData, &beta, cDesc, cData);
if (stat != CUDNN_STATUS_SUCCESS)
return;
// Destroying descriptors in CUDA does not require a sync, but cuDNN does
// not state whether a sync is required before destruction.
checkCudnnError(cudnnDestroyTensorDescriptor(aDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(bDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(cDesc));
checkCudnnError(cudnnDestroyOpTensorDescriptor(opDesc));
}
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
compute(_op, {}, _context);
}
// Premise: op must be idempotent, since tuning executes it multiple times.
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
PerfRecord ret;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
ret.time = timeit([&]() { compute(_op, _context); },
[&]() { context->sync(); });
return ret;
}
};
class AddCudnn : public ElementWiseCudnn {
cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_ADD; }
};
class SubCudnn : public ElementWiseCudnn {
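// cuDNN has no dedicated subtract op: Sub reuses CUDNN_OP_TENSOR_ADD and
// computes 1*a + (-1)*b via the scaling factors below.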
cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_ADD; }
tuple<float, float, float> getAlphBeta() const override {
return {1.f, -1.f, 0.f};
}
};
class MulCudnn : public ElementWiseCudnn {
cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_MUL; }
};
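// cudnnOpTensor has no divide or pow op, so Div and Pow dispatch to the
// hand-written CUDA kernels declared in cuda_element_wise.h.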
class ElementWiseCuda : public Kernel {
void compute(const Operator &_op, const PerfRecord &record,
const RuntimeObj *_context) const override {
element_wise_kernel(_op);
}
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
compute(_op, {}, _context);
}
// Premise: op must be idempotent, since tuning executes it multiple times.
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
PerfRecord ret;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
ret.time = timeit([&]() { compute(_op, _context); },
[&]() { context->sync(); });
return ret;
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Float32, AddCudnn,
"Add_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Sub, DataType::Float32, SubCudnn,
"Sub_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Mul, DataType::Float32, MulCudnn,
"Mul_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Div, DataType::Float32, ElementWiseCuda,
"Div_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Pow, DataType::Float32, ElementWiseCuda,
"Pow__CUDA_Float32");
}; // namespace infini


@@ -0,0 +1,38 @@
#include "cuda/cuda_common.h"
#include <math.h>
constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
__global__ void _div_kernel(float *x, float *y, float *z, int n) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) {
z[i] = x[i] / y[i];
}
}
__global__ void _pow_kernel(float *x, float *y, float *z, int n) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) {
z[i] = pow(x[i], y[i]);
}
}
namespace infini {
void div_kernel(float *a, float *b, float *c, int num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_div_kernel<<<gridsize, blocksize>>>(a, b, c, num);
}
void pow_kernel(float *a, float *b, float *c, int num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_pow_kernel<<<gridsize, blocksize>>>(a, b, c, num);
}
}; // namespace infini
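A quick check of the launch arithmetic above: num_threads() is 32 * 4 = 128 and thread_work_size() is 4, so block_work_size() is 512. Since blocksize is taken directly from block_work_size(), each block launches 512 threads, and for num = 1000 the grid holds (1000 + 511) / 512 = 2 blocks; the grid-stride loop then hands each element to exactly one of the 1024 threads, with the surplus threads falling out of the loop immediately. In other words, thread_work_size() widens the block here rather than making each thread process four elements.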


@@ -0,0 +1,57 @@
#include "operators/element_wise.h"
namespace infini {
ElementWiseObj::ElementWiseObj(OpType type, GraphObj *graph, Tensor input0,
Tensor input1, Tensor output)
: OperatorObj(type, {input0, input1}, {output}) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>>
ElementWiseObj::inferShape(const TensorVec &inputs) const {
// For now, we only process identical shapes here; broadcast will be
// considered in the opt layer.
const auto A = inputs[0], B = inputs[1];
if (A->getDims().size() != B->getDims().size() ||
A->getDims() != B->getDims())
return {};
return {{A->getDims()}};
/*
int n = A->getDims().size();
Shape shape;
for (int i = 0; i < n; i++) {
auto dimA = A->getDims().at(i);
auto dimB = B->getDims().at(i);
if (!(dimA == dimB || dimA == 1 || dimB == 1))
return {};
auto dimI = dimA > dimB ? dimA : dimB;
shape.emplace_back(dimI);
}
return {{shape}};*/
}
std::string ElementWiseObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << vecToString(inputs[1]->getDims()) << ",";
os << "input0=" << inputs[0]->getGuid() << ",";
os << "input1=" << inputs[1]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
// Open question: use the output dims or the input dims?
vector<int> ElementWiseObj::getWorkloadVector() const {
vector<int> ret = outputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type));
return ret;
}
vector<int> ElementWiseObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
}; // namespace infini
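As an illustration (guids invented, and assuming vecToString renders dims as [1,2,2,3]), toString() on an Add over two {1, 2, 2, 3} tensors produces a line like:

Add[5]([1,2,2,3],[1,2,2,3],input0=3,input1=4,output=6)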


@@ -28,7 +28,7 @@ optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) const {
std::string PoolingObj::toString() const {
std::ostringstream os;
os << "Maxpool[" << getGuid() << "]";
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << "(";
os << "k=[" << kh << "," << kw << "],";
os << "p=[" << ph << "," << pw << "],";


@@ -0,0 +1,122 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/element_wise.h"
#include "test.h"
namespace infini {
using ExpectOutput = vector<float>;
TEST(ElementWise, ShapeInference) {
Runtime runtime = CpuRuntimeObj::getInstance();
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i0 = g->addTensor({2, 3, 3, 4}, DataType::UInt32);
Tensor i1 = g->addTensor({2, 3, 3, 4}, DataType::UInt32);
auto op = g->addOp<AddObj>(i0, i1, nullptr);
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
}
}
/*
template <typename T>
void test_element_wise(
const std::function<void(void *, size_t, DataType)> &generator,
const vector<uint32_t> &ans) {
Runtime runtime = CpuRuntimeObj::getInstance();
Graph g = make_ref<GraphObj>(runtime);
Tensor i0 = g->addTensor({1, 3, 2, 2}, DataType::UInt32);
Tensor i1 = g->addTensor({2, 3, 1, 2}, DataType::UInt32);
auto op = g->addOp<T>(i0, i1, nullptr);
g->dataMalloc();
i0->setData(generator);
i1->setData(generator);
runtime->run(g, true, true);
// check answer
EXPECT_TRUE(op->getOutput()->equalData(ans));
}
TEST(ElementWise, NaiveCPU) {
test_element_wise<AddObj>(IncrementalGenerator(),
vector<uint32_t>{0, 2, 2, 4, 6, 8, 8, 10,
12, 14, 14, 16, 6, 8, 8, 10,
12, 14, 14, 16, 18, 20, 20, 22});
test_element_wise<SubObj>(
IncrementalGenerator(),
vector<uint32_t>{0, 0, 2, 2,
2, 2, 4, 4,
4, 4, 6, 6,
4294967290, 4294967290, 4294967292, 4294967292,
4294967292, 4294967292, 4294967294, 4294967294,
4294967294, 4294967294, 0, 0});
test_element_wise<MulObj>(
IncrementalGenerator(),
vector<uint32_t>{0, 1, 0, 3, 8, 15, 12, 21, 32, 45, 40, 55,
0, 7, 12, 21, 32, 45, 48, 63, 80, 99, 100, 121});
test_element_wise<DivObj>(OneGenerator(),
vector<uint32_t>{
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
});
}
*/
template <class T>
void testElementWiseCudnn(
const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape, const ExpectOutput &ansVec) {
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
auto cudaRuntime = make_ref<CudaRuntimeObj>();
// Build input data on CPU
Tensor acpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
acpu->dataMalloc();
acpu->setData(generator);
Tensor bcpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
bcpu->dataMalloc();
bcpu->setData(generator);
// Build CUDA graph
Graph g = make_ref<GraphObj>(cudaRuntime);
auto a = g->cloneTensor(acpu);
auto b = g->cloneTensor(bcpu);
auto op = g->addOp<T>(a, b, nullptr);
// allocate CUDA memory
g->dataMalloc();
// Execute on CUDA
cudaRuntime->run(g);
// clone CUDA output to CPU
auto c = op->getOutput();
auto ccpu = c->clone(cpuRuntime);
// cudaPrintTensor(c);
// check results on CPU
EXPECT_TRUE(ccpu->equalData(ansVec));
}
TEST(ElementWise, CuDNN) {
testElementWiseCudnn<AddObj>(
IncrementalGenerator(), Shape{1, 2, 2, 3},
ExpectOutput{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22});
testElementWiseCudnn<SubObj>(
IncrementalGenerator(), Shape{1, 2, 2, 3},
ExpectOutput{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
testElementWiseCudnn<MulObj>(
IncrementalGenerator(), Shape{1, 2, 2, 3},
ExpectOutput{0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
testElementWiseCudnn<DivObj>(
OneGenerator(), Shape{1, 2, 2, 3},
ExpectOutput{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
testElementWiseCudnn<PowObj>(IncrementalGenerator(), Shape{1, 2, 2, 1},
ExpectOutput{1, 1, 4, 27});
}
} // namespace infini