From 0707fb6aff670923f49f73dc67bfba352bc9d9fd Mon Sep 17 00:00:00 2001
From: wanghailu
Date: Mon, 26 Dec 2022 03:06:34 +0000
Subject: [PATCH] Add MSELoss operation

---
 include/core/operator.h                |  2 +
 include/operators/element_wise.h       | 17 ++++++++
 src/kernels/bang/element_wise.cc       | 60 ++++++++++++++++++++++++++
 src/operators/element_wise.cc          | 44 +++++++++++++++++++
 test/kernels/bang/test_bang_mseloss.cc | 55 +++++++++++++++++++++++
 5 files changed, 178 insertions(+)
 create mode 100644 test/kernels/bang/test_bang_mseloss.cc

diff --git a/include/core/operator.h b/include/core/operator.h
index ba9fe142..c2b75807 100644
--- a/include/core/operator.h
+++ b/include/core/operator.h
@@ -72,6 +72,7 @@ enum class OpType {
     L2Loss,
     Maximum,
     Minimum,
+    MSELoss,
     //
     MemBound = 300,
 };
@@ -152,6 +153,7 @@ class OpRegistry {
         FOP(L2Loss);
         FOP(Maximum);
         FOP(Minimum);
+        FOP(MSELoss);
         //
             FOP(MemBound);
         default:
diff --git a/include/operators/element_wise.h b/include/operators/element_wise.h
index 9036e3ba..be8457d5 100644
--- a/include/operators/element_wise.h
+++ b/include/operators/element_wise.h
@@ -17,6 +17,23 @@ class ElementWiseObj : public OperatorObj {
     vector<int> getOpAttrVector() const override;
 };
 
+class MSELossObj : public OperatorObj {
+  public:
+    enum Reduction { None = 0, Sum, Mean };
+    MSELossObj(GraphObj *graph, Tensor input0, Tensor input1, Reduction reduction, Tensor output);
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+
+    Reduction getReduction() const { return reductionMode; }
+    std::string toString() const override;
+    int numInputs() const override { return 2; }
+    int numOutputs() const override { return 1; }
+
+  private:
+    Reduction reductionMode;
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+
 #define DEFINE_ELEMENT_WISE_OBJ(prefix, type)                                  \
     class prefix##Obj : public ElementWiseObj {                                \
       public:                                                                  \
diff --git a/src/kernels/bang/element_wise.cc b/src/kernels/bang/element_wise.cc
index fdee1dfb..8dd84bf8 100644
--- a/src/kernels/bang/element_wise.cc
+++ b/src/kernels/bang/element_wise.cc
@@ -262,6 +262,64 @@ class MinimumCnnl : public BangKernelWithoutConfig {
     }
 };
 
+class MSELossCnnl : public BangKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<MSELossObj>(_op);
+        auto context = dynamic_cast<const BangRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+        MSELossObj::Reduction reduction = op->getReduction();
+        cnnlTensorDescriptor_t aDesc, bDesc, cDesc;
+        auto dim = op->getInputs(0)->getDims();
+        if (dim.size() != 4)
+            IT_TODO_HALT();
+
+        int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
+        int dim_out[4] = {1, 1, 1, 1};
+        // get inputs
+        checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
+                                               CNNL_DTYPE_FLOAT, 4, dim_array));
+
+        checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
+                                               CNNL_DTYPE_FLOAT, 4, dim_array));
+
+        // get outputs: None keeps the input shape, Sum/Mean reduce to a scalar
+        checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
+        if (reduction == MSELossObj::None) {
+            checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
+                                               CNNL_DTYPE_FLOAT, 4, dim_array));
+        } else {
+            checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
+                                               CNNL_DTYPE_FLOAT, 4, dim_out));
+        }
+        cnnlStatus_t stat;
+        if (reduction == MSELossObj::None) {
+            stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_NONE,
+                               aDesc, aData, bDesc, bData, cDesc, cData);
+        } else if (reduction == MSELossObj::Sum) {
+            stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_SUM,
+                               aDesc, aData, bDesc, bData, cDesc, cData);
+        } else {
+            stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_MEAN,
+                               aDesc, aData, bDesc, bData, cDesc, cData);
+        }
+
+        if (stat != CNNL_STATUS_SUCCESS)
+            return;
+
+        // Destroying descriptors in BANG does not require a sync, but cnnl
+        // does not state whether a sync is required before destruction.
+        checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
+        checkCnnlError(cnnlDestroyTensorDescriptor(bDesc));
+        checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
+    }
+};
+
 class AddCnnl : public ElementWiseCnnl {
     cnnlOpTensorDesc_t getOpType() const override { return CNNL_OP_TENSOR_ADD; }
 };
@@ -301,6 +359,8 @@ REGISTER_KERNEL(Device::BANG, OpType::Maximum, DataType::Float32, MaximumCnnl,
                 "Maximum_cnnl_BANG_Float32");
 REGISTER_KERNEL(Device::BANG, OpType::Minimum, DataType::Float32, MinimumCnnl,
                 "Minimum_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::MSELoss, DataType::Float32, MSELossCnnl,
+                "MSELoss_cnnl_BANG_Float32");
 // REGISTER_KERNEL(Device::BANG, OpType::Pow, DataType::Float32,
 //                 ElementWiseBang,
 //                 "Pow_Bang_Float32");
diff --git a/src/operators/element_wise.cc b/src/operators/element_wise.cc
index bb13586a..6cfd8d1d 100644
--- a/src/operators/element_wise.cc
+++ b/src/operators/element_wise.cc
@@ -54,4 +54,48 @@ vector<int> ElementWiseObj::getOpAttrVector() const {
     return {enum_to_underlying(type)};
 }
 
+
+MSELossObj::MSELossObj(GraphObj *graph, Tensor input0, Tensor input1, Reduction reduction, Tensor output)
+    : OperatorObj(OpType::MSELoss, {input0, input1}, {output}), reductionMode(reduction) {
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>>
+MSELossObj::inferShape(const TensorVec &inputs) const {
+    const auto A = inputs[0], B = inputs[1];
+    // The two inputs must have identical shapes.
+    if (A->getDims() != B->getDims())
+        return {};
+
+    if (reductionMode == None) {
+        return {{A->getDims()}};
+    } else {
+        Shape temp = {1};
+        return {{temp}};
+    }
+}
+
+std::string MSELossObj::toString() const {
+    std::ostringstream os;
+    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << vecToString(inputs[1]->getDims()) << ",";
+    os << "input0=" << inputs[0]->getGuid() << ",";
+    os << "input1=" << inputs[1]->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+// TODO: should the workload use the output dims or the input dims?
+vector<int> MSELossObj::getWorkloadVector() const {
+    vector<int> ret = outputs[0]->getDims();
+    ret.emplace(ret.begin(), enum_to_underlying(type));
+    return ret;
+}
+
+vector<int> MSELossObj::getOpAttrVector() const {
+    return {enum_to_underlying(type)};
+}
+
 }; // namespace infini
diff --git a/test/kernels/bang/test_bang_mseloss.cc b/test/kernels/bang/test_bang_mseloss.cc
new file mode 100644
index 00000000..5d474176
--- /dev/null
+++ b/test/kernels/bang/test_bang_mseloss.cc
@@ -0,0 +1,55 @@
+#include "bang/bang_runtime.h"
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "operators/element_wise.h"
+
+#include "test.h"
+
+namespace infini {
+
+template <class T>
+void testMSELoss(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    const Shape &shape) {
+    // Runtime
+    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    auto bangRuntime = make_ref<BangRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor inputCpu1 =
+        make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
+    inputCpu1->dataMalloc();
+    inputCpu1->setData(generator);
+    Tensor inputCpu2 =
+        make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
+    inputCpu2->dataMalloc();
+    inputCpu2->setData(generator);
+
+    // GPU
+    Graph bangGraph = make_ref<GraphObj>(bangRuntime);
+    auto inputGpu1 = bangGraph->cloneTensor(inputCpu1);
+    auto inputGpu2 = bangGraph->cloneTensor(inputCpu2);
+    auto gpuOp1 = bangGraph->addOp<T>(inputGpu1, inputGpu2, MSELossObj::None, nullptr);
+    auto gpuOp2 = bangGraph->addOp<T>(inputGpu1, inputGpu2, MSELossObj::Sum, nullptr);
+    auto gpuOp3 = bangGraph->addOp<T>(inputGpu1, inputGpu2, MSELossObj::Mean, nullptr);
+    bangGraph->dataMalloc();
+    bangRuntime->run(bangGraph);
+    auto outputGpu1 = gpuOp1->getOutput();
+    auto outputGpu2 = gpuOp2->getOutput();
+    auto outputGpu3 = gpuOp3->getOutput();
+    auto outputCpu1 = outputGpu1->clone(cpuRuntime);
+    auto outputCpu2 = outputGpu2->clone(cpuRuntime);
+    auto outputCpu3 = outputGpu3->clone(cpuRuntime);
+    // Check: no golden data yet, so print the results for manual inspection
+    outputCpu1->printData();
+    outputCpu2->printData();
+    outputCpu3->printData();
+    EXPECT_TRUE(1);
+}
+
+TEST(cnnl_MSELoss, run) {
+    testMSELoss<MSELossObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
+}
+
+} // namespace infini
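A note on verification: the test above only prints the device results and asserts a constant, so it cannot catch a wrong reduction mode. A small host-side reference could back a real comparison. The sketch below is not part of the patch; the names computeMSELossRef and RefReduction are hypothetical, and it assumes two float buffers of identical size with the same reduction semantics as MSELossObj (None keeps the element-wise squared error, Sum and Mean reduce to a single value).

// Hypothetical host-side reference for MSELoss, not part of the patch.
#include <cassert>
#include <cstddef>
#include <vector>

enum class RefReduction { None, Sum, Mean };

// Returns (a[i] - b[i])^2 per element for None, or a single-element vector
// holding the sum / mean of the squared errors for Sum / Mean.
std::vector<float> computeMSELossRef(const std::vector<float> &a,
                                     const std::vector<float> &b,
                                     RefReduction reduction) {
    assert(a.size() == b.size() && !a.empty());
    std::vector<float> out(a.size());
    float sum = 0.0f;
    for (size_t i = 0; i < a.size(); ++i) {
        const float diff = a[i] - b[i];
        out[i] = diff * diff;
        sum += out[i];
    }
    if (reduction == RefReduction::Sum)
        return {sum};
    if (reduction == RefReduction::Mean)
        return {sum / static_cast<float>(a.size())};
    return out; // RefReduction::None
}

Since both test inputs are filled by the same IncrementalGenerator, the device outputs should be all zeros for every reduction mode, which such a reference would reproduce and which already makes a cheap equality check possible instead of EXPECT_TRUE(1).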