From c19d6e6bb0379158e81e97d43dde8bae3f5bd3cf Mon Sep 17 00:00:00 2001
From: wanghailu
Date: Wed, 4 Jan 2023 09:24:52 +0000
Subject: [PATCH] add det operation

---
[Review note: this patch was recovered from a whitespace-mangled copy in
which every <...> token -- all C++ template arguments and the author's
email address on the From: line -- had been stripped. The missing pieces
were reconstructed by hand from the surrounding code and the project's
conventions; verify against the original commit before applying.]

 include/core/operator.h            |  2 ++
 include/operators/det.h            | 24 ++++++++++++++
 src/kernels/bang/det.cc            | 63 ++++++++++++++++++++++++++++++++
 src/operators/det.cc               | 45 ++++++++++++++++++++++
 test/kernels/bang/test_bang_det.cc | 41 +++++++++++++++++++
 5 files changed, 175 insertions(+)
 create mode 100644 include/operators/det.h
 create mode 100644 src/kernels/bang/det.cc
 create mode 100644 src/operators/det.cc
 create mode 100644 test/kernels/bang/test_bang_det.cc

diff --git a/include/core/operator.h b/include/core/operator.h
index f16d059b..d5ac39c0 100644
--- a/include/core/operator.h
+++ b/include/core/operator.h
@@ -88,6 +88,7 @@ enum class OpType {
     FloorModTrunc,
     Cumsum,
     Cumprod,
+    Det,
     //
     MemBound = 300,
 };
@@ -184,6 +185,7 @@ class OpRegistry {
         FOP(FloorModTrunc);
         FOP(Cumsum);
         FOP(Cumprod);
+        FOP(Det);
         //
         FOP(MemBound);
     default:
diff --git a/include/operators/det.h b/include/operators/det.h
new file mode 100644
index 00000000..ba02e03a
--- /dev/null
+++ b/include/operators/det.h
@@ -0,0 +1,24 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+// Determinant operator: maps (batches of) square matrices to their
+// determinant, or to its logarithm when Mode is LogDet.
+class DetObj : public OperatorObj {
+  public:
+    // NormalDet computes det(A); LogDet computes log(det(A)).
+    enum Mode { NormalDet = 0, LogDet };
+    DetObj(GraphObj *graph, Tensor input, Tensor output, Mode mode);
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+
+    std::string toString() const override;
+    int numInputs() const override { return 1; }
+    int numOutputs() const override { return 1; }
+    Mode getMode() const { return modeValue; }
+
+  private:
+    Mode modeValue;
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+}; // namespace infini
diff --git a/src/kernels/bang/det.cc b/src/kernels/bang/det.cc
new file mode 100644
index 00000000..ce821f56
--- /dev/null
+++ b/src/kernels/bang/det.cc
@@ -0,0 +1,63 @@
+#include "bang/bang_kernel_without_config.h"
+#include "bang/bang_runtime.h"
+#include "operators/det.h"
+
+namespace infini {
+// CNNL (BANG) kernel for OpType::Det: computes the determinant -- or its
+// logarithm, depending on DetObj::Mode -- of the matrices in a 4-D input.
+class DetCnnl : public BangKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<DetObj>(_op);
+        auto context = dynamic_cast<const BangRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+        DetObj::Mode mode = op->getMode();
+        cnnlDetMode_t nlMode;
+        if (mode == DetObj::LogDet) {
+            nlMode = CNNL_DET_MODE_LOGDET;
+        } else {
+            nlMode = CNNL_DET_MODE_DET;
+        }
+        cnnlTensorDescriptor_t aDesc, cDesc;
+        auto dimin = op->getInputs(0)->getDims();
+        auto dimout = op->getOutput()->getDims();
+        // Only the 4-D input / 2-D output case is wired up so far.
+        if (dimin.size() != 4 || dimout.size() != 2)
+            IT_TODO_HALT();
+
+        int dimin_array[4] = {dimin[0], dimin[1], dimin[2], dimin[3]};
+        int dimout_array[2] = {dimout[0], dimout[1]};
+        // get inputs
+        checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
+                                               CNNL_DTYPE_FLOAT, 4,
+                                               dimin_array));
+
+        // get outputs
+        checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
+                                               CNNL_DTYPE_FLOAT, 2,
+                                               dimout_array));
+
+        cnnlStatus_t stat =
+            cnnlDet(context->cnnlHandle(), nlMode, aDesc, aData, cDesc, cData);
+        if (stat != CNNL_STATUS_SUCCESS) {
+            // Release the descriptors before bailing out so a failed op
+            // does not leak them (the original early return skipped this).
+            checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
+            checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
+            return;
+        }
+
+        // Destruction in BANG does not require sync. But cnnl does not state
+        // whether sync is required before destruction.
+        checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
+        checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
+    }
+};
+
+REGISTER_KERNEL(Device::BANG, OpType::Det, DataType::Float32, DetCnnl,
+                "Det_cnnl_BANG_Float32");
+}; // namespace infini
diff --git a/src/operators/det.cc b/src/operators/det.cc
new file mode 100644
index 00000000..ef2ed0c2
--- /dev/null
+++ b/src/operators/det.cc
@@ -0,0 +1,45 @@
+#include "operators/det.h"
+
+namespace infini {
+DetObj::DetObj(GraphObj *graph, Tensor input, Tensor output, Mode mode)
+    : OperatorObj(OpType::Det, {input}, {output}), modeValue(mode) {
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>> DetObj::inferShape(const TensorVec &inputs) const {
+    const auto A = inputs[0];
+    auto input = A->getDims();
+    int length = input.size();
+    if (length == 2) {
+        // A single matrix: det() collapses it to one scalar.
+        std::vector<int> output = {1};
+        return {{output}};
+    } else {
+        // Batched input: keep the batch dims, drop the two matrix dims.
+        std::vector<int> output(input.begin(), input.end() - 2);
+        return {{output}};
+    }
+}
+
+std::string DetObj::toString() const {
+    std::ostringstream os;
+    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << "input=" << inputs[0]->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> DetObj::getWorkloadVector() const {
+    vector<int> ret{enum_to_underlying(type)};
+    const Shape shape = outputs[0]->getDims();
+    ret.insert(ret.end(), shape.begin(), shape.end());
+    return ret;
+}
+
+vector<int> DetObj::getOpAttrVector() const {
+    return {enum_to_underlying(type)};
+}
+
+}; // namespace infini
diff --git a/test/kernels/bang/test_bang_det.cc b/test/kernels/bang/test_bang_det.cc
new file mode 100644
index 00000000..e7165baf
--- /dev/null
+++ b/test/kernels/bang/test_bang_det.cc
@@ -0,0 +1,41 @@
+#include "bang/bang_runtime.h"
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "operators/det.h"
+
+#include "test.h"
+
+namespace infini {
+
+template <class T>
+void testDet(const std::function<void(void *, size_t, DataType)> &generator,
+             const Shape &shape) {
+    // Runtime
+    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    auto bangRuntime = make_ref<BangRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
+    inputCpu->dataMalloc();
+    inputCpu->setData(generator);
+
+    // GPU
+    Graph bangGraph = make_ref<GraphObj>(bangRuntime);
+    auto inputGpu = bangGraph->cloneTensor(inputCpu);
+    auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr, DetObj::NormalDet);
+    bangGraph->dataMalloc();
+    bangRuntime->run(bangGraph);
+    auto outputGpu = gpuOp->getOutput();
+    auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
+    // Check
+    inputCpu->printData();
+    outputGpu2Cpu->printData();
+    EXPECT_TRUE(1); // smoke test only: passes if the kernel ran; output is
+                    // printed for manual inspection, not compared.
+}
+
+TEST(cnnl_Det, run) {
+    testDet<DetObj>(IncrementalGenerator(), Shape{1, 1, 2, 2});
+}
+
+} // namespace infini