From 50862df765a55786d4392a4494af739a7b87f5d0 Mon Sep 17 00:00:00 2001
From: Hardy <100662313+wanghailu0717@users.noreply.github.com>
Date: Fri, 10 Nov 2023 17:58:26 +0800
Subject: [PATCH] [Kunlun & CUDA & BANG] add depth2space operator (#178)

* add depth2space operator

* fix format

* add depth2space on cambricon bang

* add depth2space on gpu

---------

Co-authored-by: wanghailu
Co-authored-by: wanghailu
Co-authored-by: Haojie Wang
---
 include/core/graph_handler.h              |  2 +
 include/operators/transpose.h             | 45 +++++++++++++++
 pyinfinitensor/src/pyinfinitensor/onnx.py | 17 ++++++
 src/core/graph_handler.cc                 | 15 +++++
 src/ffi/ffi_infinitensor.cc               | 12 +++-
 src/kernels/bang/transpose.cc             | 57 +++++++++++++++++++
 src/kernels/cuda/transpose.cc             | 48 ++++++++++++++++
 src/kernels/kunlun/transpose.cc           | 28 +++++++++
 src/operators/transpose.cc                | 69 +++++++++++++++++++++++
 9 files changed, 292 insertions(+), 1 deletion(-)

diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 87e909f8..61826893 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -86,6 +86,8 @@ class GraphHandlerObj {
     Tensor allReduceAvg(Tensor input, Tensor output);
     TensorVec allGather(Tensor input, std::optional<TensorVec> outputs, int n);
     Tensor broadcast(Tensor input, Tensor output, int root);
+    Tensor depthToSpace(Tensor input, Tensor output, int blocksize,
+                        std::string mode);
 
     //------ modifiers
 
diff --git a/include/operators/transpose.h b/include/operators/transpose.h
index c20d0a08..9fcd1617 100644
--- a/include/operators/transpose.h
+++ b/include/operators/transpose.h
@@ -19,4 +19,49 @@ class TransposeObj : public OperatorObj {
     vector<int> getWorkloadVector() const override;
     vector<int> getOpAttrVector() const override;
 };
+
+/**
+ * @brief Moves data from the channel (depth) dimension of a 4-D input into
+ * blocks of its two trailing dimensions.
+ *
+ * inferShape() caches the 6-D reshape and transpose shapes that the device
+ * kernels use to run the operator as a single transpose.
+ */
+class DepthToSpaceObj : public OperatorObj {
+  public:
+    /**
+     * @param graph Graph used to validate the operator.
+     * @param input The input tensor; inferShape() requires it to be 4-D.
+     * @param output The output tensor.
+     * @param blocksize Edge length of the square blocks being moved.
+     * @param mode "CRD" selects CRD ordering; any other value means "DCR".
+     */
+    DepthToSpaceObj(GraphObj *graph, Tensor input, Tensor output, int blocksize,
+                    std::string mode);
+    OP_CLONE(DepthToSpaceObj);
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+
+    std::string toString() const override;
+    int numInputs() const override { return 1; }
+    int numOutputs() const override { return 1; }
+    int getBlockSize() const { return blockSize; }
+    /// @return 0 for "DCR", 1 for "CRD".
+    int getMode() const { return D2SMode; }
+    auto getModeString() const { return D2SModeString; }
+    auto getReshapeDim() const { return reshapeDim; }
+    auto getTransposeDim() const { return transposeDim; }
+    auto getOutDim() const { return outDim; }
+
+  private:
+    int blockSize;
+    int D2SMode;
+    std::string D2SModeString;
+    // Shape caches written by the const inferShape(), hence mutable.
+    mutable std::vector<int> reshapeDim = {1, 1, 1, 1, 1, 1};
+    mutable std::vector<int> transposeDim = {1, 1, 1, 1, 1, 1};
+    mutable std::vector<int> outDim = {1, 1, 1, 1};
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+
 } // namespace infini
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 6d0da9f8..cc5498f9 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -491,6 +491,23 @@ class OnnxStub:
                     tensors.get(node.output[0]),
                     perm,
                 )
+            elif node.op_type == "DepthToSpace":
+                blocksize = next(
+                    (attr.i for attr in node.attribute if attr.name == "blocksize"),
+                    None,
+                )
+                # `mode` is optional in ONNX and defaults to "DCR"; attr.s is
+                # bytes, so decode it before handing it to the C++ handler.
+                mode = next(
+                    (attr.s for attr in node.attribute if attr.name == "mode"),
+                    b"DCR",
+                ).decode("utf-8")
+                tensors[node.output[0]] = self.handler.depthToSpace(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                    blocksize,
+                    mode,
+                )
             elif node.op_type == "Reshape":
                 dims = _search_shape(model, node.input[0])
                 size = reduce(lambda acc, x: acc * x, dims)
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index 77fbcf2d..ddf53884 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -425,6 +425,21 @@ Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition,
     }
 }
 
+// Adds a DepthToSpace operator to the graph, writing into `output` when the
+// caller supplies one and otherwise returning the new operator's output.
+Tensor GraphHandlerObj::depthToSpace(Tensor input, Tensor output, int blocksize,
+                                     std::string mode) {
+    if (output) {
+        g->addOpWithOutputs<DepthToSpaceObj>(std::move(input), output,
+                                             blocksize, mode);
+        return output;
+    } else {
+        return g
+            ->addOp<DepthToSpaceObj>(std::move(input), output, blocksize, mode)
+            ->getOutput();
+    }
+}
+
 static CastType inferCastType(Tensor input, int to) {
     auto iType = input->getDType();
     auto oType = DataType(to);
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index e1a726c3..3612269e 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -111,6 +111,7 @@ void export_values(py::module &m) {
         .VALUE(OpType, Expand)
         .VALUE(OpType, Erf)
         .VALUE(OpType, Where)
+        .VALUE(OpType, DepthToSpace)
         .export_values();
 
 #undef VALUE
@@ -286,6 +287,13 @@ static int cast_to_of(Operator op) {
     return castOutputDtype.getIndex();
 }
 
+static std::tuple<int, std::string> depth_to_space_attrs_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::DepthToSpace);
+    auto depth_to_space = dynamic_cast<const DepthToSpaceObj *>(op.get());
+    return std::make_tuple(depth_to_space->getBlockSize(),
+                           depth_to_space->getModeString());
+}
+
 void export_functions(py::module &m) {
 #define FUNCTION(NAME) def(#NAME, &NAME)
     m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance)
@@ -321,7 +329,8 @@ void export_functions(py::module &m) {
         .FUNCTION(split_axis_of)
         .FUNCTION(gather_axis_of)
         .FUNCTION(flatten_axis_of)
-        .FUNCTION(cast_to_of);
+        .FUNCTION(cast_to_of)
+        .FUNCTION(depth_to_space_attrs_of);
 #undef FUNCTION
 }
 
@@ -477,6 +486,7 @@ void init_graph_builder(py::module &m) {
         .def("pRelu", &Handler::pRelu, policy::move)
         .def("clip", &Handler::clip, policy::move)
         .def("transpose", &Handler::transpose, policy::move)
+        .def("depthToSpace", &Handler::depthToSpace, policy::move)
         .def("reshape", &Handler::reshape, policy::move)
         .def("concat", &Handler::concat, policy::move)
         .def("split", &Handler::split, policy::move)
diff --git a/src/kernels/bang/transpose.cc b/src/kernels/bang/transpose.cc
index c87c4c28..ff2783b5 100644
--- a/src/kernels/bang/transpose.cc
+++ b/src/kernels/bang/transpose.cc
@@ -48,6 +48,63 @@ class TransposeCnnl : public BangKernelWithoutConfig {
     }
 };
 
+// DepthToSpace on BANG: a single CNNL transpose of the 6-D reshaped input.
+class DepthToSpaceCnnl : public BangKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<DepthToSpaceObj>(_op);
+        auto context = dynamic_cast<const BangRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+        auto reshape = op->getReshapeDim();
+        auto transpose = op->getTransposeDim();
+        auto mode = op->getMode();
+
+        // Axis order of the transpose: mode 0 is DCR, anything else is CRD.
+        std::vector<int> permute;
+        if (mode == 0) {
+            permute = {0, 3, 4, 1, 5, 2};
+        } else {
+            permute = {0, 1, 4, 2, 5, 3};
+        }
+
+        cnnlTensorDescriptor_t aDesc, cDesc;
+        checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
+                                               CNNL_DTYPE_FLOAT, reshape.size(),
+                                               reshape.data()));
+        checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
+        checkCnnlError(
+            cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT,
+                                    transpose.size(), transpose.data()));
+
+        cnnlTransposeDescriptor_t opDesc;
+        checkCnnlError(cnnlCreateTransposeDescriptor(&opDesc));
+        checkCnnlError(
+            cnnlSetTransposeDescriptor(opDesc, permute.size(), permute.data()));
+
+        size_t wsSize = 0;
+        checkCnnlError(cnnlGetTransposeWorkspaceSize(context->cnnlHandle(),
+                                                     aDesc, opDesc, &wsSize));
+        BangPtr wsData = context->getWorkspace(wsSize);
+
+        cnnlStatus_t stat =
+            cnnlTranspose_v2(context->cnnlHandle(), opDesc, aDesc, aData, cDesc,
+                             cData, wsData, wsSize);
+
+        // Release the descriptors before checking the status so that a failed
+        // transpose is reported without leaking them.
+        checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
+        checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
+        checkCnnlError(cnnlDestroyTransposeDescriptor(opDesc));
+        checkCnnlError(stat);
+    }
+};
+
 REGISTER_KERNEL(Device::BANG, OpType::Transpose, DataType::Float32,
                 TransposeCnnl, "Transpose_cnnl_BANG_Float32");
+
+REGISTER_KERNEL(Device::BANG, OpType::DepthToSpace, DataType::Float32,
+                DepthToSpaceCnnl, "DepthToSpace_cnnl_BANG_Float32");
 }; // namespace infini
diff --git a/src/kernels/cuda/transpose.cc b/src/kernels/cuda/transpose.cc
index 37f97cd9..774cb37f 100644
--- a/src/kernels/cuda/transpose.cc
+++ b/src/kernels/cuda/transpose.cc
@@ -43,7 +43,55 @@ class TransposeCuda : public CudaKernelWithoutConfig {
     }
 };
 
+// DepthToSpace on CUDA: reuses transpose_kernel on the 6-D reshaped input.
+class DepthToSpaceCuda : public CudaKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<DepthToSpaceObj>(_op);
+
+        auto input = op->getInputs(0);
+        auto output = op->getOutput();
+        void *const inputData = input->getRawDataPtr<void *>();
+        void *const outputData = output->getRawDataPtr<void *>();
+        const auto &reshape = op->getReshapeDim();
+        const auto &transpose = op->getTransposeDim();
+        auto mode = op->getMode();
+
+        std::vector<int> perm;
+        if (mode == 0) {
+            perm = {0, 3, 4, 1, 5, 2};
+        } else {
+            perm = {0, 1, 4, 2, 5, 3};
+        }
+
+        int size = input->size();
+        int nDims = reshape.size();
+
+        // Compute strides
+        SmallArray strides, buffer;
+        IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
+        int curStride = 1;
+        for (int i = nDims - 1; i >= 0; --i) {
+            buffer.data[i] = curStride;
+            curStride *= reshape[i];
+        }
+        for (int i = 0; i < nDims; ++i) {
+            strides.data[i] = buffer.data[perm[i]];
+        }
+
+        SmallArray outputDims;
+        for (int i = 0; i < nDims; ++i) {
+            outputDims.data[i] = transpose[i];
+        }
+
+        transpose_kernel((float *)inputData, (float *)outputData, nDims, size,
+                         strides, outputDims);
+    }
+};
+
 REGISTER_KERNEL(Device::CUDA, OpType::Transpose, DataType::Float32,
                 TransposeCuda, "Transpose_CUDA_Float32");
 
+REGISTER_KERNEL(Device::CUDA, OpType::DepthToSpace, DataType::Float32,
+                DepthToSpaceCuda, "DepthToSpace_CUDA_Float32");
 } // namespace infini
diff --git a/src/kernels/kunlun/transpose.cc b/src/kernels/kunlun/transpose.cc
index 443df8d9..817c32e2 100644
--- a/src/kernels/kunlun/transpose.cc
+++ b/src/kernels/kunlun/transpose.cc
@@ -27,6 +27,34 @@ class TransposeXdnn : public KUNLUNKernelWithoutConfig {
     }
 };
 
+// DepthToSpace on KUNLUN: a single xdnn transpose of the 6-D reshaped input.
+class DepthToSpaceXdnn : public KUNLUNKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<DepthToSpaceObj>(_op);
+        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+
+        auto reshape = op->getReshapeDim();
+        auto mode = op->getMode();
+        std::vector<int> permute;
+        if (mode == 0) {
+            permute = {0, 3, 4, 1, 5, 2};
+        } else {
+            permute = {0, 1, 4, 2, 5, 3};
+        }
+        auto ret = baidu::xpu::api::transpose<float>(
+            context->KUNLUNHandle(), (float *)aData, (float *)cData, reshape,
+            permute);
+        assert(ret == 0);
+        return;
+    }
+};
+
 REGISTER_KERNEL(Device::KUNLUN, OpType::Transpose, DataType::Float32,
                 TransposeXdnn, "Transpose_xdnn_KUNLUN_Float32");
+REGISTER_KERNEL(Device::KUNLUN, OpType::DepthToSpace, DataType::Float32,
+                DepthToSpaceXdnn, "DepthToSpace_xdnn_KUNLUN_Float32");
 }; // namespace infini
diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc
index 9a457647..f4c6a28d 100644
--- a/src/operators/transpose.cc
+++ b/src/operators/transpose.cc
@@ -53,4 +53,73 @@ vector<int> TransposeObj::getOpAttrVector() const {
     return {type.underlying()};
 }
 
+// Builds the operator; any mode other than "CRD" falls back to "DCR".
+DepthToSpaceObj::DepthToSpaceObj(GraphObj *graph, Tensor input, Tensor output,
+                                 int blocksize, std::string mode)
+    : OperatorObj(OpType::DepthToSpace, {input}, {output}) {
+    blockSize = blocksize;
+    D2SMode = 0;
+    D2SModeString = "DCR";
+    if (mode == "CRD") {
+        D2SMode = 1;
+        D2SModeString = "CRD";
+    }
+    IT_ASSERT(checkValid(graph));
+}
+
+// Computes the [N, C / b^2, H * b, W * b] output shape and caches the 6-D
+// reshape / transpose shapes consumed by the device kernels.
+optional<vector<Shape>>
+DepthToSpaceObj::inferShape(const TensorVec &inputs) const {
+    const auto A = inputs[0];
+    auto inputDim = A->getDims();
+    IT_ASSERT(inputDim.size() == 4);
+    // Guard the divisions below: a non-positive block size would divide by
+    // zero, and a channel count that is not a multiple of b^2 would be
+    // silently truncated.
+    IT_ASSERT(blockSize > 0);
+    const int blockArea = blockSize * blockSize;
+    IT_ASSERT(inputDim[1] % blockArea == 0);
+    const int outChannels = inputDim[1] / blockArea;
+
+    if (D2SMode == 0) {
+        // DCR: view the input as [N, b, b, C / b^2, H, W].
+        reshapeDim = {inputDim[0], blockSize, blockSize,
+                      outChannels, inputDim[2], inputDim[3]};
+        transposeDim = {reshapeDim[0], reshapeDim[3], reshapeDim[4],
+                        reshapeDim[1], reshapeDim[5], reshapeDim[2]};
+    } else {
+        // CRD: view the input as [N, C / b^2, b, b, H, W].
+        reshapeDim = {inputDim[0], outChannels, blockSize,
+                      blockSize, inputDim[2], inputDim[3]};
+        transposeDim = {reshapeDim[0], reshapeDim[1], reshapeDim[4],
+                        reshapeDim[2], reshapeDim[5], reshapeDim[3]};
+    }
+    outDim = {inputDim[0], outChannels, inputDim[2] * blockSize,
+              inputDim[3] * blockSize};
+
+    return {{outDim}};
+}
+
+std::string DepthToSpaceObj::toString() const {
+    std::ostringstream os;
+    os << type.toString() << "[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << "input=" << inputs[0]->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> DepthToSpaceObj::getWorkloadVector() const {
+    vector<int> ret{type.underlying()};
+    const Shape shape = outputs[0]->getDims();
+    ret.insert(ret.end(), shape.begin(), shape.end());
+    return ret;
+}
+
+vector<int> DepthToSpaceObj::getOpAttrVector() const {
+    return {type.underlying()};
+}
+
 }; // namespace infini