diff --git a/include/core/tensor.h b/include/core/tensor.h
index 41de4168..07d4e3b4 100644
--- a/include/core/tensor.h
+++ b/include/core/tensor.h
@@ -71,6 +71,8 @@ class TensorObj : public TensorBaseObj {
         return equalDataImpl(getRawDataPtr<T *>(), dataVector.data(), size());
     }
 
+    size_t getOffsetByBroadcastOffset(size_t bcOffset, Shape bcShape) const;
+
   private:
     void printDataFloat() const;
     void printDataUint32_t() const;
@@ -92,6 +94,10 @@ class TensorObj : public TensorBaseObj {
         }
         return true;
     }
+
+    Shape getPosByOffset(size_t offset, Shape dim) const;
+    size_t getOffsetByPos(Shape pos, Shape dim) const;
+
     // void setDims(const Dim &dms) { dims = dms; }
 
     // bool dataRand(int seed = 0) {
diff --git a/include/cuda/cuda_element_wise.h b/include/cuda/cuda_element_wise.h
new file mode 100644
index 00000000..b6dede22
--- /dev/null
+++ b/include/cuda/cuda_element_wise.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include "operators/element_wise.h"
+
+namespace infini {
+void div_kernel(float *a, float *b, float *c, int num);
+void pow_kernel(float *a, float *b, float *c, int num);
+
+void element_wise_kernel(const Operator &_op) {
+    auto op = as<ElementWiseObj>(_op);
+    float *const aData = (op->getInputs(0)->getRawDataPtr<float *>());
+    float *const bData = (op->getInputs(1)->getRawDataPtr<float *>());
+    float *const cData = (op->getOutput()->getRawDataPtr<float *>());
+
+    auto dim = op->getInputs(0)->getDims();
+    int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
+    if (op->getOpType() == OpType::Div)
+        div_kernel(aData, bData, cData, n * c * h * w);
+    else if (op->getOpType() == OpType::Pow)
+        pow_kernel(aData, bData, cData, n * c * h * w);
+    else
+        IT_TODO_HALT();
+}
+
+}; // namespace infini
\ No newline at end of file
diff --git a/include/operators/element_wise.h b/include/operators/element_wise.h
new file mode 100644
index 00000000..7111d50a
--- /dev/null
+++ b/include/operators/element_wise.h
@@ -0,0 +1,33 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+class ElementWiseObj : public OperatorObj {
+  public:
+    ElementWiseObj(OpType type, GraphObj *graph, Tensor input0, Tensor input1,
+                   Tensor output);
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+
+    std::string toString() const override;
+    int numInputs() const override { return 2; }
+    int numOutputs() const override { return 1; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+
+#define DEFINE_ELEMENT_WISE_OBJ(prefix, type)                                  \
+    class prefix##Obj : public ElementWiseObj {                                \
+      public:                                                                  \
+        prefix##Obj(GraphObj *graph, Tensor input0, Tensor input1,             \
+                    Tensor output)                                             \
+            : ElementWiseObj(type, graph, input0, input1, output) {}           \
+    };
+
+DEFINE_ELEMENT_WISE_OBJ(Add, OpType::Add)
+DEFINE_ELEMENT_WISE_OBJ(Sub, OpType::Sub)
+DEFINE_ELEMENT_WISE_OBJ(Mul, OpType::Mul)
+DEFINE_ELEMENT_WISE_OBJ(Div, OpType::Div)
+DEFINE_ELEMENT_WISE_OBJ(Pow, OpType::Pow)
+}; // namespace infini
\ No newline at end of file
diff --git a/src/core/tensor.cc b/src/core/tensor.cc
index 18fbdf3b..3b9dbb47 100644
--- a/src/core/tensor.cc
+++ b/src/core/tensor.cc
@@ -141,4 +141,35 @@ void TensorObj::copyData(const TensorObj *src) {
     runtime->copyBlob(this, src);
 }
 
+Shape TensorObj::getPosByOffset(size_t offset, Shape dim) const {
+    Shape pos = dim;
+    for (int i = dim.size() - 1; i >= 0; i--) {
+        pos[i] = offset % dim.at(i);
+        offset = (offset - pos[i]) / dim.at(i);
+    }
+    return pos;
+}
+
+size_t TensorObj::getOffsetByPos(Shape pos, Shape dim) const {
+    int n = dim.size();
+    size_t offset = pos.at(0);
+    for (auto i = 1; i < n; i++) {
+        offset = offset * dim.at(i) + pos.at(i);
+    }
+    return offset;
+}
+
+size_t TensorObj::getOffsetByBroadcastOffset(size_t bcOffset,
+                                             Shape bcDim) const {
+    Shape bcPos = getPosByOffset(bcOffset, bcDim);
+
+    Shape pos = bcPos;
+    int n = shape.size();
+    for (auto i = 0; i < n; i++) {
+        if (shape.at(i) == 1)
+            pos[i] = 0;
+    }
+    return getOffsetByPos(pos, shape);
+}
+
 }; // namespace infini
\ No newline at end of file
diff --git a/src/cuda/cuda_utility.cu b/src/cuda/cuda_utility.cu
index ca0a2d01..cfbdcb9f 100644
--- a/src/cuda/cuda_utility.cu
+++ b/src/cuda/cuda_utility.cu
@@ -17,4 +17,5 @@ void cudaPrintFloat(float *x, int len) {
     cudaPrintFloatImpl<<<1, 1>>>(x, len);
     cudaDeviceSynchronize();
 }
+
 } // namespace infini
\ No newline at end of file
diff --git a/src/kernels/cpu/element_wise.cc b/src/kernels/cpu/element_wise.cc
new file mode 100644
index 00000000..1e79d269
--- /dev/null
+++ b/src/kernels/cpu/element_wise.cc
@@ -0,0 +1,68 @@
+#include "operators/element_wise.h"
+#include "core/kernel.h"
+
+namespace infini {
+template <typename T> class NativeElementWise : public Kernel {
+    virtual T doCompute(T val0, T val1) const = 0;
+    void compute(const Operator &_op, const PerfRecord &record,
+                 const RuntimeObj *context) const override {
+        auto op = as<ElementWiseObj>(_op);
+        T *inptr0 = op->getInputs(0)->getRawDataPtr<T *>();
+        T *inptr1 = op->getInputs(1)->getRawDataPtr<T *>();
+        T *outptr = op->getOutput()->getRawDataPtr<T *>();
+
+        auto outDim = op->getOutput()->getDims();
+        auto n = op->getOutput()->size();
+        for (size_t offset = 0; offset < n; offset++) {
+            // For now, we only handle inputs with identical dims here;
+            // broadcast will be considered in the opt layer.
+            /*auto offset0 =
+                op->getInputs(0)->getOffsetByBroadcastOffset(offset, outDim);
+            auto offset1 =
+                op->getInputs(1)->getOffsetByBroadcastOffset(offset, outDim);
+            outptr[offset] = doCompute(inptr0[offset0], inptr1[offset1]);*/
+            outptr[offset] = doCompute(inptr0[offset], inptr1[offset]);
+        }
+    }
+
+    void compute(const Operator &op, const RuntimeObj *context) const override {
+        compute(op, {}, context);
+    }
+
+    PerfRecord tune(const Operator &op,
+                    const RuntimeObj *context) const override {
+        PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
+        return perfrcd;
+    }
+};
+
+template <typename T> class NaiveAdd : public NativeElementWise<T> {
+    T doCompute(T val0, T val1) const override { return val0 + val1; }
+};
+template <typename T> class NaiveSub : public NativeElementWise<T> {
+    T doCompute(T val0, T val1) const override { return val0 - val1; }
+};
+template <typename T> class NaiveMul : public NativeElementWise<T> {
+    T doCompute(T val0, T val1) const override { return val0 * val1; }
+};
+template <typename T> class NaiveDiv : public NativeElementWise<T> {
+    T doCompute(T val0, T val1) const override { return (T)(val0 / val1); }
+};
+
+REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::UInt32, NaiveAdd<uint32_t>,
+                "addNaive_CPU_uint32");
+REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::Float32, NaiveAdd<float>,
+                "addNaive_CPU_float32");
+REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::UInt32, NaiveSub<uint32_t>,
+                "subNaive_CPU_uint32");
+REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::Float32, NaiveSub<float>,
+                "subNaive_CPU_float32");
+REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::UInt32, NaiveMul<uint32_t>,
+                "mulNaive_CPU_uint32");
+REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::Float32, NaiveMul<float>,
+                "mulNaive_CPU_float32");
+REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::UInt32, NaiveDiv<uint32_t>,
+                "divNaive_CPU_uint32");
+REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::Float32, NaiveDiv<float>,
+                "divNaive_CPU_float32");
+}; // namespace infini
\ No newline at end of file
diff --git a/src/kernels/cuda/element_wise.cc b/src/kernels/cuda/element_wise.cc
new file mode 100644
index 00000000..7666fb10
--- /dev/null
+++ b/src/kernels/cuda/element_wise.cc
@@ -0,0 +1,124 @@
+#include "operators/element_wise.h"
+#include "core/kernel.h"
+#include "cuda/cuda_element_wise.h"
+#include "cuda/cuda_runtime.h"
+
+namespace infini {
+class ElementWiseCudnn : public Kernel {
+    virtual cudnnOpTensorOp_t getOpType() const = 0;
+    virtual tuple<float, float, float> getAlphBeta() const {
+        return {1.f, 1.f, 0.f};
+    }
+    void compute(const Operator &_op, const PerfRecord &record,
+                 const RuntimeObj *_context) const override {
+        auto op = as<ElementWiseObj>(_op);
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+
+        cudnnTensorDescriptor_t aDesc, bDesc, cDesc;
+        auto dim = op->getInputs(0)->getDims();
+        if (dim.size() != 4)
+            IT_TODO_HALT();
+        int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
+
+        // get inputs
+        checkCudnnError(cudnnCreateTensorDescriptor(&aDesc));
+        checkCudnnError(cudnnSetTensor4dDescriptor(
+            aDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
+
+        checkCudnnError(cudnnCreateTensorDescriptor(&bDesc));
+        checkCudnnError(cudnnSetTensor4dDescriptor(
+            bDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
+
+        // get outputs
+        checkCudnnError(cudnnCreateTensorDescriptor(&cDesc));
+        checkCudnnError(cudnnSetTensor4dDescriptor(
+            cDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
+
+        // get op descriptor
+        cudnnOpTensorDescriptor_t opDesc;
+        checkCudnnError(cudnnCreateOpTensorDescriptor(&opDesc));
+        checkCudnnError(cudnnSetOpTensorDescriptor(
+            opDesc, getOpType(), CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN));
+
+        auto [aAlpha, bAlpha, beta] = getAlphBeta();
+        cudnnStatus_t stat =
+            cudnnOpTensor(context->cudnnHandle(), opDesc, &aAlpha, aDesc, aData,
+                          &bAlpha, bDesc, bData, &beta, cDesc, cData);
+        if (stat != CUDNN_STATUS_SUCCESS)
+            return;
+
+        // Destroy calls in CUDA do not require a sync. But cuDNN does not
+        // state whether a sync is required before destroying descriptors.
+        checkCudnnError(cudnnDestroyTensorDescriptor(aDesc));
+        checkCudnnError(cudnnDestroyTensorDescriptor(bDesc));
+        checkCudnnError(cudnnDestroyTensorDescriptor(cDesc));
+        checkCudnnError(cudnnDestroyOpTensorDescriptor(opDesc));
+    }
+
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        compute(_op, {}, _context);
+    }
+    // Premise: op must be idempotent, since tune() runs it multiple times.
+    PerfRecord tune(const Operator &_op,
+                    const RuntimeObj *_context) const override {
+        PerfRecord ret;
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        ret.time = timeit([&]() { compute(_op, _context); },
+                          [&]() { context->sync(); });
+        return ret;
+    }
+};
+
+class AddCudnn : public ElementWiseCudnn {
+    cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_ADD; }
+};
+
+class SubCudnn : public ElementWiseCudnn {
+    cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_ADD; }
+    tuple<float, float, float> getAlphBeta() const override {
+        return {1.f, -1.f, 0.f};
+    }
+};
+
+class MulCudnn : public ElementWiseCudnn {
+    cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_MUL; }
+};
+
+class ElementWiseCuda : public Kernel {
+    void compute(const Operator &_op, const PerfRecord &record,
+                 const RuntimeObj *_context) const override {
+        element_wise_kernel(_op);
+    }
+
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        compute(_op, {}, _context);
+    }
+    // Premise: op must be idempotent, since tune() runs it multiple times.
+    PerfRecord tune(const Operator &_op,
+                    const RuntimeObj *_context) const override {
+        PerfRecord ret;
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        ret.time = timeit([&]() { compute(_op, _context); },
+                          [&]() { context->sync(); });
+        return ret;
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Float32, AddCudnn,
+                "Add_cuDNN_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Sub, DataType::Float32, SubCudnn,
+                "Sub_cuDNN_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Mul, DataType::Float32, MulCudnn,
+                "Mul_cuDNN_CUDA_Float32");
+
+REGISTER_KERNEL(Device::CUDA, OpType::Div, DataType::Float32, ElementWiseCuda,
+                "Div_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Pow, DataType::Float32, ElementWiseCuda,
+                "Pow__CUDA_Float32");
+}; // namespace infini
\ No newline at end of file
diff --git a/src/kernels/cuda/element_wise.cu b/src/kernels/cuda/element_wise.cu
new file mode 100644
index 00000000..afe429f4
--- /dev/null
+++ b/src/kernels/cuda/element_wise.cu
@@ -0,0 +1,38 @@
+#include "cuda/cuda_common.h"
+#include <math.h>
+
+constexpr unsigned int num_threads() { return 32 * 4; }
+constexpr int thread_work_size() { return 4; }
+constexpr int block_work_size() { return thread_work_size() * num_threads(); }
+
+__global__ void _div_kernel(float *x, float *y, float *z, int n) {
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    int stride = blockDim.x * gridDim.x;
+    for (int i = index; i < n; i += stride) {
+        z[i] = x[i] / y[i];
+    }
+}
+
+__global__ void _pow_kernel(float *x, float *y, float *z, int n) {
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    int stride = blockDim.x * gridDim.x;
+    for (int i = index; i < n; i += stride) {
+        z[i] = pow(x[i], y[i]);
+    }
+}
+
+namespace infini {
+void div_kernel(float *a, float *b, float *c, int num) {
+
+    int blocksize = block_work_size();
+    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    _div_kernel<<<gridsize, blocksize>>>(a, b, c, num);
+}
+void pow_kernel(float *a, float *b, float *c, int num) {
+
+    int blocksize = block_work_size();
+    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    _pow_kernel<<<gridsize, blocksize>>>(a, b, c, num);
+}
+
+}; // namespace infini
\ No newline at end of file
diff --git a/src/operators/element_wise.cc b/src/operators/element_wise.cc
new file mode 100644
index 00000000..ae87758d
--- /dev/null
+++ b/src/operators/element_wise.cc
@@ -0,0 +1,57 @@
+#include "operators/element_wise.h"
+
+namespace infini {
+ElementWiseObj::ElementWiseObj(OpType type, GraphObj *graph, Tensor input0,
+                               Tensor input1, Tensor output)
+    : OperatorObj(type, {input0, input1}, {output}) {
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>>
+ElementWiseObj::inferShape(const TensorVec &inputs) const {
+    // For now, we only handle inputs with identical dims here; broadcast
+    // will be considered in the opt layer.
+    const auto A = inputs[0], B = inputs[1];
+    if (A->getDims().size() != B->getDims().size() ||
+        A->getDims() != B->getDims())
+        return {};
+
+    return {{A->getDims()}};
+    /*
+    int n = A->getDims().size();
+    Shape shape;
+    for (int i = 0; i < n; i++) {
+        auto dimA = A->getDims().at(i);
+        auto dimB = B->getDims().at(i);
+        if (!(dimA == dimB || dimA == 1 || dimB == 1))
+            return {};
+        auto dimI = dimA > dimB ? dimA : dimB;
+        shape.emplace_back(dimI);
+    }
+    return {{shape}};*/
+}
+
+std::string ElementWiseObj::toString() const {
+    std::ostringstream os;
+    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << vecToString(inputs[1]->getDims()) << ",";
+    os << "input0=" << inputs[0]->getGuid() << ",";
+    os << "input1=" << inputs[1]->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+// use output dim or inputs dim?
+vector<int> ElementWiseObj::getWorkloadVector() const {
+    vector<int> ret = outputs[0]->getDims();
+    ret.emplace(ret.begin(), enum_to_underlying(type));
+    return ret;
+}
+
+vector<int> ElementWiseObj::getOpAttrVector() const {
+    return {enum_to_underlying(type)};
+}
+
+}; // namespace infini
\ No newline at end of file
diff --git a/src/operators/pooling.cc b/src/operators/pooling.cc
index 0fcc5416..6e87cc94 100644
--- a/src/operators/pooling.cc
+++ b/src/operators/pooling.cc
@@ -28,7 +28,7 @@ optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) const {
 
 std::string PoolingObj::toString() const {
     std::ostringstream os;
-    os << "Maxpool[" << getGuid() << "]";
+    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
     os << "(";
     os << "k=[" << kh << "," << kw << "],";
     os << "p=[" << ph << "," << pw << "],";
diff --git a/test/operators/test_element_wise.cc b/test/operators/test_element_wise.cc
new file mode 100644
index 00000000..1bdc5f4e
--- /dev/null
+++ b/test/operators/test_element_wise.cc
@@ -0,0 +1,122 @@
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/element_wise.h"
+
+#include "test.h"
+
+namespace infini {
+
+using ExpectOutput = vector<float>;
+TEST(ElementWise, ShapeInference) {
+    Runtime runtime = CpuRuntimeObj::getInstance();
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        Tensor i0 = g->addTensor({2, 3, 3, 4}, DataType::UInt32);
+        Tensor i1 = g->addTensor({2, 3, 3, 4}, DataType::UInt32);
+        auto op = g->addOp<AddObj>(i0, i1, nullptr);
+        EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
+    }
+}
+/*
+template <class T>
+void test_element_wise(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    const vector<uint32_t> &ans) {
+    Runtime runtime = CpuRuntimeObj::getInstance();
+    Graph g = make_ref<GraphObj>(runtime);
+    Tensor i0 = g->addTensor({1, 3, 2, 2}, DataType::UInt32);
+    Tensor i1 = g->addTensor({2, 3, 1, 2}, DataType::UInt32);
+    auto op = g->addOp<T>(i0, i1, nullptr);
+
+    g->dataMalloc();
+    i0->setData(generator);
+    i1->setData(generator);
+    runtime->run(g, true, true);
+    // check answer
+    EXPECT_TRUE(op->getOutput()->equalData(ans));
+}
+
+TEST(ElementWise, NaiveCPU) {
+    test_element_wise<AddObj>(IncrementalGenerator(),
+                              vector<uint32_t>{0, 2, 2, 4, 6, 8, 8, 10,
+                                               12, 14, 14, 16, 6, 8, 8, 10,
+                                               12, 14, 14, 16, 18, 20, 20, 22});
+    test_element_wise<SubObj>(
+        IncrementalGenerator(),
+        vector<uint32_t>{0,          0,          2,          2,
+                         2,          2,          4,          4,
+                         4,          4,          6,          6,
+                         4294967290, 4294967290, 4294967292, 4294967292,
+                         4294967292, 4294967292, 4294967294, 4294967294,
+                         4294967294, 4294967294, 0,          0});
+    test_element_wise<MulObj>(
+        IncrementalGenerator(),
+        vector<uint32_t>{0, 1, 0, 3, 8, 15, 12, 21, 32, 45, 40, 55,
+                         0, 7, 12, 21, 32, 45, 48, 63, 80, 99, 100, 121});
+    test_element_wise<DivObj>(OneGenerator(),
+                              vector<uint32_t>{
+                                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                              });
+}
+*/
+
+template <class T>
+void testElementWiseCudnn(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    const Shape &shape, const ExpectOutput &ansVec) {
+    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor acpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
+    acpu->dataMalloc();
+    acpu->setData(generator);
+
+    Tensor bcpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
+    bcpu->dataMalloc();
+    bcpu->setData(generator);
+
+    // Build CUDA graph
+    Graph g = make_ref<GraphObj>(cudaRuntime);
+    auto a = g->cloneTensor(acpu);
+    auto b = g->cloneTensor(bcpu);
+    auto op = g->addOp<T>(a, b, nullptr);
+
+    // allocate CUDA memory
+    g->dataMalloc();
+
+    // Execute on CUDA
+    cudaRuntime->run(g);
+
+    // clone CUDA output to CPU
+    auto c = op->getOutput();
+    auto ccpu = c->clone(cpuRuntime);
+    // cudaPrintTensor(c);
+    // check results on CPU
+    EXPECT_TRUE(ccpu->equalData(ansVec));
+}
+
+TEST(ElementWise, CuDNN) {
+    testElementWiseCudnn<AddObj>(
+        IncrementalGenerator(), Shape{1, 2, 2, 3},
+        ExpectOutput{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22});
+    testElementWiseCudnn<SubObj>(
+        IncrementalGenerator(), Shape{1, 2, 2, 3},
+        ExpectOutput{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+    testElementWiseCudnn<MulObj>(
+        IncrementalGenerator(), Shape{1, 2, 2, 3},
+        ExpectOutput{0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
+
+    testElementWiseCudnn<DivObj>(
+        OneGenerator(), Shape{1, 2, 2, 3},
+        ExpectOutput{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
+
+    testElementWiseCudnn<PowObj>(IncrementalGenerator(), Shape{1, 2, 2, 1},
+                                 ExpectOutput{1, 1, 4, 27});
+}
+
+} // namespace infini
\ No newline at end of file
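
A note on the broadcast helpers added to TensorObj: getOffsetByBroadcastOffset maps an offset in the broadcast (output) shape back to an offset in the tensor's own shape by decomposing the offset into a position (getPosByOffset), pinning every size-1 dimension to index 0, and re-linearizing against the tensor's own dims (getOffsetByPos). The naive CPU kernel does not call it yet (that path is commented out). Below is a minimal standalone sketch of the same index math for sanity-checking; it assumes Shape is a row-major vector<int>, and the names posByOffset/offsetByPos are illustrative only, not part of this patch.

// Illustrative sketch of the broadcast-offset mapping; not part of the patch.
#include <cassert>
#include <cstddef>
#include <vector>

using Shape = std::vector<int>;

// Decompose a flat row-major offset into a multi-dimensional position.
Shape posByOffset(size_t offset, const Shape &dim) {
    Shape pos(dim.size());
    for (int i = (int)dim.size() - 1; i >= 0; i--) {
        pos[i] = offset % dim[i];
        offset /= dim[i];
    }
    return pos;
}

// Re-linearize a position against a (possibly smaller) shape.
size_t offsetByPos(const Shape &pos, const Shape &dim) {
    size_t offset = pos[0];
    for (size_t i = 1; i < dim.size(); i++)
        offset = offset * dim[i] + pos[i];
    return offset;
}

int main() {
    // A {2, 3, 1, 2} tensor read through a {2, 3, 2, 2} broadcast output:
    // size-1 dims are pinned to index 0 before re-linearizing.
    Shape self{2, 3, 1, 2}, bcShape{2, 3, 2, 2};
    size_t bcOffset = 13; // position {1, 0, 0, 1} in bcShape
    Shape pos = posByOffset(bcOffset, bcShape);
    for (size_t i = 0; i < self.size(); i++)
        if (self[i] == 1)
            pos[i] = 0;                     // collapse broadcast dims
    assert(offsetByPos(pos, self) == 7);    // {1, 0, 0, 1} in {2, 3, 1, 2}
    return 0;
}

With the {2, 3, 1, 2} tensor above, every output position along the broadcast h axis reads the same source element, which is what the commented-out offset0/offset1 path in the naive CPU kernel would compute once broadcast support is enabled.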