From 48847958d01e75521d3323bf561762ada48201f1 Mon Sep 17 00:00:00 2001 From: constroy Li Date: Fri, 18 Aug 2023 12:17:47 +0800 Subject: [PATCH] impl sqrt on CUDA (#109) * impl sqrt on CUDA fix parser of Gather and ReduceMean * fix test_gather * fix test_cuda_gather * impl sqrt cpu and add sqrt to test_cuda_unary * cuda_unary supports arbitary shapes * fix SplitOp with dim=-1 * fix SplitOp with dim=-1 --- Makefile | 4 ++-- include/core/graph_handler.h | 1 + include/cuda/cuda_unary.h | 17 ++++++++++------- pyinfinitensor/src/pyinfinitensor/onnx.py | 10 ++++++++-- src/core/graph_handler.cc | 1 + src/ffi/ffi_infinitensor.cc | 1 + src/kernels/cpu/unary.cc | 6 ++++++ src/kernels/cuda/unary.cc | 2 ++ src/kernels/cuda/unary.cu | 13 +++++++++++++ src/operators/gather.cc | 2 +- src/operators/split.cc | 10 +++------- test/kernels/cuda/test_cuda_gather.cc | 18 +++++++++--------- test/kernels/cuda/test_cuda_unary.cc | 5 +++++ test/operators/test_gather.cc | 4 ++-- test/operators/test_split.cc | 14 ++++++++++++++ 15 files changed, 78 insertions(+), 30 deletions(-) diff --git a/Makefile b/Makefile index 59842e9e..0fcc9070 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .PHONY : build clean format install-python test-cpp test-onnx -TYPE ?= release +TYPE ?= Release CUDA ?= OFF BANG ?= OFF INTELCPU ?= off @@ -30,7 +30,7 @@ format: install-python: build cp build/$(TYPE)/backend*.so pyinfinitensor/src/pyinfinitensor - pip install pyinfinitensor/ + pip install -e pyinfinitensor/ test-cpp: @echo diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 48c79a6e..fa704442 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -47,6 +47,7 @@ class GraphHandlerObj { Tensor tanh(Tensor x, Tensor y); Tensor softmax(Tensor x, Tensor y, int axis); Tensor abs(Tensor x, Tensor y); + Tensor sqrt(Tensor x, Tensor y); Tensor shape(Tensor x, Tensor y); Tensor identity(Tensor x, Tensor y); Tensor flatten(Tensor s, Tensor y, int axis); diff --git a/include/cuda/cuda_unary.h b/include/cuda/cuda_unary.h index c11912dc..99f73009 100644 --- a/include/cuda/cuda_unary.h +++ b/include/cuda/cuda_unary.h @@ -3,29 +3,32 @@ #include "operators/unary.h" namespace infini { +// TODO(constroy): num should be size_t. void softmax_kernel(float *input, float *output, int num); void relu_kernel(float *input, float *output, int num); void sigmoid_kernel(float *input, float *output, int num); void tanh_kernel(float *input, float *output, int num); void abs_kernel(float *input, float *output, int num); +void sqrt_kernel(float *input, float *output, int num); void unary_kernel(const Operator &_op) { auto op = as(_op); float *const inputData = (op->getInputs(0)->getRawDataPtr()); float *const outputData = (op->getOutput()->getRawDataPtr()); - auto dim = op->getInputs(0)->getDims(); - int n = dim[0], c = dim[1], h = dim[2], w = dim[3]; + size_t num = op->getOutput()->size(); if (op->getOpType() == OpType::Softmax) - softmax_kernel(inputData, outputData, n * c * h * w); + softmax_kernel(inputData, outputData, num); else if (op->getOpType() == OpType::Relu) - relu_kernel(inputData, outputData, n * c * h * w); + relu_kernel(inputData, outputData, num); else if (op->getOpType() == OpType::Sigmoid) - sigmoid_kernel(inputData, outputData, n * c * h * w); + sigmoid_kernel(inputData, outputData, num); else if (op->getOpType() == OpType::Tanh) - tanh_kernel(inputData, outputData, n * c * h * w); + tanh_kernel(inputData, outputData, num); else if (op->getOpType() == OpType::Abs) - abs_kernel(inputData, outputData, n * c * h * w); + abs_kernel(inputData, outputData, num); + else if (op->getOpType() == OpType::Sqrt) + sqrt_kernel(inputData, outputData, num); else IT_TODO_HALT(); } diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index bc057d9b..1a52a95b 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -377,6 +377,11 @@ class OnnxStub: tensors[node.input[0]], tensors.get(node.output[0]), ) + elif node.op_type == "Sqrt": + tensors[node.output[0]] = self.handler.sqrt( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) elif node.op_type == "Shape": tensors[node.output[0]] = self.handler.shape( tensors[node.input[0]], @@ -500,7 +505,7 @@ class OnnxStub: tensors[node.input[1]], tensors.get(node.output[0]), next( - (attr.i for attr in node.attribute if attr.name == "axis") + (attr.i for attr in node.attribute if attr.name == "axis"), 0 ), ) elif node.op_type == "ReduceMean": @@ -521,7 +526,8 @@ class OnnxStub: attr.i for attr in node.attribute if attr.name == "keepdims" - ) + ), + 1 ) != 0, ) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 2f59fd98..be7d5578 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -151,6 +151,7 @@ DEFINE_UNARY_METHOD(relu, Relu) DEFINE_UNARY_METHOD(sigmoid, Sigmoid) DEFINE_UNARY_METHOD(tanh, Tanh) DEFINE_UNARY_METHOD(abs, Abs) +DEFINE_UNARY_METHOD(sqrt, Sqrt) DEFINE_UNARY_METHOD(shape, Shape) // see operators/reshape.h diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index f5315e63..9289829f 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -344,6 +344,7 @@ void init_graph_builder(py::module &m) { .def("tanh", &Handler::tanh, policy::move) .def("softmax", &Handler::softmax, policy::move) .def("abs", &Handler::abs, policy::move) + .def("sqrt", &Handler::sqrt, policy::move) .def("shape", &Handler::shape, policy::move) .def("identity", &Handler::identity, policy::move) .def("flatten", &Handler::flatten, policy::move) diff --git a/src/kernels/cpu/unary.cc b/src/kernels/cpu/unary.cc index c32e5652..755e0a93 100644 --- a/src/kernels/cpu/unary.cc +++ b/src/kernels/cpu/unary.cc @@ -56,6 +56,10 @@ template class NaiveAbs : public NativeUnary { T doCompute(T val) const override { return val < 0 ? -val : val; } }; +template class NaiveSqrt : public NativeUnary { + T doCompute(T val) const override { return std::sqrt(val); } +}; + template class Clip : public CpuKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *context) const override { @@ -91,6 +95,8 @@ REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::UInt32, NaiveAbs, "absNaive_CPU_uint32"); REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs, "absNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt, + "sqrtNaive_CPU_float32"); REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32, NaiveSoftmax, "softmaxNaive_CPU_uint32"); REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32, diff --git a/src/kernels/cuda/unary.cc b/src/kernels/cuda/unary.cc index b4ac496f..317f45b8 100644 --- a/src/kernels/cuda/unary.cc +++ b/src/kernels/cuda/unary.cc @@ -132,6 +132,8 @@ REGISTER_KERNEL(Device::CUDA, OpType::Tanh, DataType::Float32, TanhCudnn, "Tanh_CUDA_Float32"); REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda, "Abs_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, DataType::Float32, UnaryCuda, + "Sqrt_CUDA_Float32"); // REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, UnaryCuda, // "Softmax_CUDA_Float32"); diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu index b81d8c63..5a1fd272 100644 --- a/src/kernels/cuda/unary.cu +++ b/src/kernels/cuda/unary.cu @@ -58,6 +58,14 @@ __global__ void _abs_kernel(float *input, float *output, int n) { } } +__global__ void _sqrt_kernel(float *input, float *output, int n) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (int i = index; i < n; i += stride) { + output[i] = sqrt(input[i]); + } +} + namespace infini { void softmax_kernel(float *input, float *output, int num) { @@ -90,5 +98,10 @@ void abs_kernel(float *input, float *output, int num) { int gridsize = (num + block_work_size() - 1) / block_work_size(); _abs_kernel<<>>(input, output, num); } +void sqrt_kernel(float *input, float *output, int num) { + int blocksize = block_work_size(); + int gridsize = (num + block_work_size() - 1) / block_work_size(); + _sqrt_kernel<<>>(input, output, num); +} }; // namespace infini diff --git a/src/operators/gather.cc b/src/operators/gather.cc index 0441b6ba..aa9ef79d 100644 --- a/src/operators/gather.cc +++ b/src/operators/gather.cc @@ -25,7 +25,7 @@ optional> GatherObj::inferShape(const TensorVec &inputs) const { vector GatherObj::inferDataType(const TensorVec &inputs) const { IT_ASSERT(inputs.size() == 2); auto index = inputs[1]; - IT_ASSERT(index->getDType() == DataType::UInt32); + IT_ASSERT(index->getDType() == DataType::Int32); return {inputs[0]->getDType()}; } diff --git a/src/operators/split.cc b/src/operators/split.cc index 45eb1804..be541326 100644 --- a/src/operators/split.cc +++ b/src/operators/split.cc @@ -7,10 +7,8 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input, std::optional outputs, int dim, int num) : OperatorObj(OpType::Split, {input}, ((!outputs) ? TensorVec(num, nullptr) : std::move(*outputs))), - dim(dim), num(num), ratio({}) { - int rank = input->getRank(); - dim = get_real_axis(dim, rank); - int dimSize = input->getDims().at(dim); + dim(get_real_axis(dim, input->getRank())), num(num), ratio({}) { + int dimSize = input->getDims().at(this->dim); int pieceSize = dimSize / num; int lastSize = dimSize - pieceSize * num; @@ -28,9 +26,7 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input, const vector &ratio) : OperatorObj(OpType::Split, {input}, ((!outputs) ? TensorVec{nullptr} : (*outputs))), - dim(dim), num(-1), ratio(ratio) { - int rank = input->getRank(); - dim = get_real_axis(dim, rank); + dim(get_real_axis(dim, input->getRank())), num(-1), ratio(ratio) { num = ratio.size(); if (!outputs) { TensorVec tmp(num, nullptr); diff --git a/test/kernels/cuda/test_cuda_gather.cc b/test/kernels/cuda/test_cuda_gather.cc index 9dc987ba..33863406 100644 --- a/test/kernels/cuda/test_cuda_gather.cc +++ b/test/kernels/cuda/test_cuda_gather.cc @@ -179,10 +179,10 @@ TEST(Gather, Cuda) { Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({3, 2}, DataType::Float32); - auto index = gCpu->addTensor({2, 2}, DataType::UInt32); + auto index = gCpu->addTensor({2, 2}, DataType::Int32); gCpu->dataMalloc(); input->copyin(vector{1, 2, 3, 4, 5, 6}); - index->copyin(vector{0, 1, 1, 2}); + index->copyin(vector{0, 1, 1, 2}); auto cudaRuntime = make_ref(); Graph gCuda = make_ref(cudaRuntime); @@ -191,7 +191,7 @@ TEST(Gather, Cuda) { auto op = gCuda->addOp(inputCuda, indexCuda, nullptr, 0); gCuda->dataMalloc(); inputCuda->copyin(vector{1, 2, 3, 4, 5, 6}); - indexCuda->copyin(vector{0, 1, 1, 2}); + indexCuda->copyin(vector{0, 1, 1, 2}); cudaRuntime->run(gCuda); // cudaPrintTensor(op->getOutput()); @@ -203,10 +203,10 @@ TEST(Gather, Cuda) { Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({3, 3}, DataType::Float32); - auto index = gCpu->addTensor({1, 2}, DataType::UInt32); + auto index = gCpu->addTensor({1, 2}, DataType::Int32); gCpu->dataMalloc(); input->setData(IncrementalGenerator()); - index->copyin(vector{0, 2}); + index->copyin(vector{0, 2}); auto cudaRuntime = make_ref(); Graph gCuda = make_ref(cudaRuntime); @@ -215,7 +215,7 @@ TEST(Gather, Cuda) { auto op = gCuda->addOp(inputCuda, indexCuda, nullptr, 1); gCuda->dataMalloc(); inputCuda->setData(IncrementalGenerator()); - indexCuda->copyin(vector{0, 2}); + indexCuda->copyin(vector{0, 2}); cudaRuntime->run(gCuda); // cudaPrintTensor(op->getOutput()); @@ -227,10 +227,10 @@ TEST(Gather, Cuda) { Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(runtime); auto input = gCpu->addTensor({2, 4, 2}, DataType::Float32); - auto index = gCpu->addTensor({3, 1}, DataType::UInt32); + auto index = gCpu->addTensor({3, 1}, DataType::Int32); gCpu->dataMalloc(); input->setData(IncrementalGenerator()); - index->copyin(vector{0, 3, 1}); + index->copyin(vector{0, 3, 1}); auto cudaRuntime = make_ref(); Graph gCuda = make_ref(cudaRuntime); @@ -239,7 +239,7 @@ TEST(Gather, Cuda) { auto op = gCuda->addOp(inputCuda, indexCuda, nullptr, 1); gCuda->dataMalloc(); inputCuda->setData(IncrementalGenerator()); - indexCuda->copyin(vector{0, 3, 1}); + indexCuda->copyin(vector{0, 3, 1}); cudaRuntime->run(gCuda); // cudaPrintTensor(op->getOutput()); diff --git a/test/kernels/cuda/test_cuda_unary.cc b/test/kernels/cuda/test_cuda_unary.cc index 22fed565..78eb95aa 100644 --- a/test/kernels/cuda/test_cuda_unary.cc +++ b/test/kernels/cuda/test_cuda_unary.cc @@ -45,6 +45,11 @@ TEST(cuDNN_Unary, run) { testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + // more shapes + testUnary(IncrementalGenerator(), Shape{13}); + testUnary(IncrementalGenerator(), Shape{4, 3}); + testUnary(IncrementalGenerator(), Shape{2, 3, 4, 5, 6}); } } // namespace infini diff --git a/test/operators/test_gather.cc b/test/operators/test_gather.cc index 6d900d6a..f3b9190c 100644 --- a/test/operators/test_gather.cc +++ b/test/operators/test_gather.cc @@ -11,8 +11,8 @@ TEST(Gather, ShapeInference) { Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); - Tensor i = g->addTensor({1, 3, 4, 4}, DataType::UInt32); - Tensor index = g->addTensor({2, 1, 2}, DataType::UInt32); + Tensor i = g->addTensor({1, 3, 4, 4}, DataType::Int32); + Tensor index = g->addTensor({2, 1, 2}, DataType::Int32); auto op = g->addOp(i, index, nullptr, 1); EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4})); } diff --git a/test/operators/test_split.cc b/test/operators/test_split.cc index 9914e37f..bd99a8f5 100644 --- a/test/operators/test_split.cc +++ b/test/operators/test_split.cc @@ -20,6 +20,20 @@ TEST(Split, ShapeInfer) { EXPECT_EQ(op->getOutput(3)->getDims(), (Shape{1, 3, 2, 6})); } + { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32); + + auto op = g->addOp(input, std::nullopt, -1, 4); + EXPECT_EQ(op->numOutputs(), 4); + EXPECT_EQ(op->getOutputs().size(), (size_t)4); + EXPECT_EQ(op->getOutput(0)->getDims(), (Shape{1, 3, 2, 3})); + EXPECT_EQ(op->getOutput(1)->getDims(), (Shape{1, 3, 2, 3})); + EXPECT_EQ(op->getOutput(2)->getDims(), (Shape{1, 3, 2, 3})); + EXPECT_EQ(op->getOutput(3)->getDims(), (Shape{1, 3, 2, 6})); + } + { Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime);