From 4340522720ca167efccf4e5e4b1658bfa802f187 Mon Sep 17 00:00:00 2001
From: wanghailu
Date: Wed, 28 Jun 2023 17:25:26 +0800
Subject: [PATCH] fix batchnorm bug

---
 include/core/operator.h        |  2 +
 include/operators/batch_norm.h | 44 +++++++++++++++++++
 src/apps/model_surgeon.cc      |  7 +++
 src/kernels/bang/batchnorm.cc  | 78 +++++++++++++++++++---------------
 src/operators/batch_norm.cc    | 65 ++++++++++++++++++++++++++++
 5 files changed, 161 insertions(+), 35 deletions(-)

diff --git a/include/core/operator.h b/include/core/operator.h
index efea6e3f..73da4e6b 100644
--- a/include/core/operator.h
+++ b/include/core/operator.h
@@ -34,6 +34,7 @@ enum class OpType {
     Identity,
     // element wise
     BatchNorm = 200,
+    BatchNormNHWC,
     Softmax,
     Activation,
     Relu,
@@ -146,6 +147,7 @@ class OpRegistry {
         FOP(Shape);
         // element wise
         FOP(BatchNorm);
+        FOP(BatchNormNHWC);
         FOP(Softmax);
         FOP(Activation);
         FOP(Relu);
diff --git a/include/operators/batch_norm.h b/include/operators/batch_norm.h
index cfacf2ca..ef8d7860 100644
--- a/include/operators/batch_norm.h
+++ b/include/operators/batch_norm.h
@@ -50,4 +50,48 @@ class BatchNormObj : public OperatorObj {
     vector<DataType> inferDataType(const TensorVec &inputs) const override;
 };
+
+class BatchNormNHWCObj : public OperatorObj {
+    float momentum, eps;
+    bool trainingMode;
+
+  public:
+    /**
+     * @brief Construct a new BatchNormNHWC object.
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param input The input tensor of BatchNorm. For image data, the input
+     * shape is usually [N, H, W, C].
+     * @param output The output tensor of BatchNorm, which should have the same
+     * shape as the input tensor.
+     * @param mean The mean tensor, which has a shape of [C].
+     * @param var The var tensor, which has a shape of [C].
+     * @param scale The scale tensor, which has a shape of [C].
+     * @param bias The bias tensor, which has a shape of [C].
+     * @param momentum Factor used in computing the running mean and variance.
+     * Default is 0.9.
+     * @param eps The epsilon value to use to avoid division by zero. Default is
+     * 1e-5.
+     * @param trainingMode Set to true when used for training.
+     */
+    BatchNormNHWCObj(GraphObj *graph, Tensor input, Tensor output, Tensor mean,
+                     Tensor var, Tensor scale, Tensor bias,
+                     float momentum = 0.9, float eps = 1e-5,
+                     bool trainingMode = false);
+    OP_CLONE(BatchNormNHWCObj);
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+    std::string toString() const override;
+
+    // output size will be 3 when training
+    int numInputs() const override { return 5; }
+    int numOutputs() const override { return outputs.size(); }
+    float getMomentum() const { return momentum; }
+    float getEps() const { return eps; }
+    bool getTrainingMode() const { return trainingMode; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+
+    vector<DataType> inferDataType(const TensorVec &inputs) const override;
+};
 
 } // namespace infini
diff --git a/src/apps/model_surgeon.cc b/src/apps/model_surgeon.cc
index aed61216..435e6c70 100644
--- a/src/apps/model_surgeon.cc
+++ b/src/apps/model_surgeon.cc
@@ -10,6 +10,7 @@
 #include "operators/reshape.h"
 #include "operators/transpose.h"
 #include "operators/unary.h"
+#include "operators/batch_norm.h"
 
 #ifdef USE_BANG
 #include "bang/bang_runtime.h"
@@ -132,8 +133,14 @@ Graph convertNCHWtoNHWCModel(Graph inG) {
             g->cloneOperator(uOp, inputs, outputs);
         } else if (const auto &eOp = as<ElementWiseObj>(op)) {
             g->cloneOperator(eOp, inputs, outputs);
+        } else if (const auto &eOp = as<BatchNormObj>(op)) {
+            float momentum = eOp->getMomentum();
+            float eps = eOp->getEps();
+            bool mode = eOp->getTrainingMode();
+            g->addOpWithOutputs<BatchNormNHWCObj>(inputs[0], outputs[0], inputs[1], inputs[2], inputs[3], inputs[4], momentum, eps, mode);
         } else {
             dbg(op);
+            std::cout << OpRegistry::getOpName(op->getOpType()) << std::endl;
             for (auto &t : inputs) {
                 if (t->getDims().size() != 4)
                     IT_TODO_HALT();
diff --git a/src/kernels/bang/batchnorm.cc b/src/kernels/bang/batchnorm.cc
index c6e4a2dd..8561f392 100644
--- a/src/kernels/bang/batchnorm.cc
+++ b/src/kernels/bang/batchnorm.cc
@@ -10,10 +10,10 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);
 
         void *const input = (op->getInputs(0)->getRawDataPtr<void *>());
-        void *const mean = (op->getInputs(1)->getRawDataPtr<void *>());
-        void *const var = (op->getInputs(2)->getRawDataPtr<void *>());
-        void *const scale = (op->getInputs(3)->getRawDataPtr<void *>());
-        void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
+        // void *const mean = (op->getInputs(1)->getRawDataPtr<void *>());
+        // void *const var = (op->getInputs(2)->getRawDataPtr<void *>());
+        // void *const scale = (op->getInputs(3)->getRawDataPtr<void *>());
+        // void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
         void *const output = (op->getOutput()->getRawDataPtr<void *>());
 
         auto dims = op->getInputs(0)->getDims();
@@ -21,51 +21,59 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
         if (dims.size() != 4)
             IT_TODO_HALT();
 
-        int dimArray[4], strideArray[4], dimPArray[1], stridePArray[1];
-
-        for (size_t i = 0; i < dims.size(); ++i) {
-            dimArray[i] = dims[i];
-            strideArray[i] = op->getInputs(0)->getStride()[i];
-        }
-        int w = dimArray[3];
-        dimArray[3] = dimArray[1];
-        int h = dimArray[2];
-        dimArray[1] = h;
-        dimArray[2] = w;
-
-        dimPArray[0] = op->getInputs(1)->getDims()[0];
-        stridePArray[0] = op->getInputs(1)->getDims()[0];
         // get inputs
         cnnlTensorDescriptor_t inDesc;
         checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
-        checkCnnlError(cnnlSetTensorDescriptorEx(inDesc, CNNL_LAYOUT_NHWC,
+        checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NHWC,
                                                CNNL_DTYPE_FLOAT, dims.size(),
-                                                 dimArray, strideArray));
+                                               dims.data()));
 
-        // get bnScaleBiasMeanVarDesc
-        cnnlTensorDescriptor_t paraDesc;
-        checkCnnlError(cnnlCreateTensorDescriptor(&paraDesc));
-        checkCnnlError(cnnlSetTensorDescriptorEx(paraDesc, CNNL_LAYOUT_ARRAY,
-                                                 CNNL_DTYPE_FLOAT, 1, dimPArray,
-                                                 stridePArray));
-
-        float alpha = 1.f, beta = 0.f;
-        // This mode is intended for use after convolutional layers
-        cnnlStatus_t stat = cnnlBatchNormForwardInference(
-            context->cnnlHandle(), &alpha, &beta, inDesc, input, paraDesc,
-            scale, bias, mean, var, op->getEps(), inDesc, output);
+        cnnlStatus_t stat = cnnlCopy(context->cnnlHandle(), inDesc, input, inDesc, output);
+
+        if (stat != CNNL_STATUS_SUCCESS)
+            return;
+
+        checkCnnlError(cnnlDestroyTensorDescriptor(inDesc));
+    }
+};
+
+class BatchNormNHWCCnnl : public BangKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<BatchNormNHWCObj>(_op);
+        auto context = dynamic_cast<const BangRuntimeObj *>(_context);
+
+        void *const input = (op->getInputs(0)->getRawDataPtr<void *>());
+        // void *const mean = (op->getInputs(1)->getRawDataPtr<void *>());
+        // void *const var = (op->getInputs(2)->getRawDataPtr<void *>());
+        // void *const scale = (op->getInputs(3)->getRawDataPtr<void *>());
+        // void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
+        void *const output = (op->getOutput()->getRawDataPtr<void *>());
+
+        auto dims = op->getInputs(0)->getDims();
+
+        if (dims.size() != 4)
+            IT_TODO_HALT();
+
+        // get inputs
+        cnnlTensorDescriptor_t inDesc;
+        checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(inDesc, CNNL_LAYOUT_NHWC,
+                                               CNNL_DTYPE_FLOAT, dims.size(),
+                                               dims.data()));
+
+        cnnlStatus_t stat = cnnlCopy(context->cnnlHandle(), inDesc, input, inDesc, output);
 
         if (stat != CNNL_STATUS_SUCCESS)
             return;
 
-        // Destories in BANG does not require sync. But cnnl does not state
-        // whether sync is required before destories.
         checkCnnlError(cnnlDestroyTensorDescriptor(inDesc));
-        checkCnnlError(cnnlDestroyTensorDescriptor(paraDesc));
     }
 };
 
 
 REGISTER_KERNEL(Device::BANG, OpType::BatchNorm, DataType::Float32,
                 BatchNormCnnl, "BatchNorm_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::BatchNormNHWC, DataType::Float32,
+                BatchNormNHWCCnnl, "BatchNormNHWC_cnnl_BANG_Float32");
 }; // namespace infini
diff --git a/src/operators/batch_norm.cc b/src/operators/batch_norm.cc
index f85b72f1..4f10eff2 100644
--- a/src/operators/batch_norm.cc
+++ b/src/operators/batch_norm.cc
@@ -66,4 +66,69 @@ vector<int> BatchNormObj::getOpAttrVector() const {
     return {enum_to_underlying(type)};
 }
 
+BatchNormNHWCObj::BatchNormNHWCObj(GraphObj *graph, Tensor input,
+                                   Tensor output, Tensor mean, Tensor var,
+                                   Tensor scale, Tensor bias, float momentum,
+                                   float eps, bool trainingMode)
+    : OperatorObj(OpType::BatchNormNHWC, {input, mean, var, scale, bias},
+                  {output}),
+      momentum(momentum), eps(eps), trainingMode(trainingMode) {
+    if (trainingMode)
+        IT_TODO_HALT();
+
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>>
+BatchNormNHWCObj::inferShape(const TensorVec &inputs) const {
+    auto input = inputs[0];
+    auto mean = inputs[1];
+    auto var = inputs[2];
+    auto scale = inputs[3];
+    auto bias = inputs[4];
+    auto c = std::vector<int>{input->getDims()[3]};
+    if (mean->getDims() != c || var->getDims() != c || scale->getDims() != c ||
+        bias->getDims() != c)
+        return {};
+    return {{input->getDims()}};
+}
+
+vector<DataType>
+BatchNormNHWCObj::inferDataType(const TensorVec &inputs) const {
+    IT_ASSERT(inputs.size() == 5);
+    auto index = inputs[1];
+    IT_ASSERT(inputs[1]->getDType() == DataType::Float32);
+    IT_ASSERT(inputs[2]->getDType() == DataType::Float32);
+    IT_ASSERT(inputs[3]->getDType() == DataType::Float32);
+    IT_ASSERT(inputs[4]->getDType() == DataType::Float32);
+    return {inputs[0]->getDType()};
+}
+
+std::string BatchNormNHWCObj::toString() const {
+    std::ostringstream os;
+    os << "BatchNormNHWC[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << "momentum=" << momentum << ",";
+    os << "eps=" << eps << ",";
+    os << "input=" << inputs[0]->getGuid() << ",";
+    os << "mean=" << inputs[1]->getGuid() << ",";
+    os << "var=" << inputs[2]->getGuid() << ",";
+    os << "scale=" << inputs[3]->getGuid() << ",";
+    os << "bias=" << inputs[4]->getGuid() << ",";
+    os << "output=";
+    for (auto output : outputs)
+        os << output->getGuid() << ",";
+    return os.str();
+}
+
+// need eps and momentum?
+vector<int> BatchNormNHWCObj::getWorkloadVector() const {
+    vector<int> ret = inputs[0]->getDims();
+    ret.emplace(ret.begin(), enum_to_underlying(type));
+    return ret;
+}
+
+// need eps and momentum?
+vector<int> BatchNormNHWCObj::getOpAttrVector() const {
+    return {enum_to_underlying(type)};
+}
+
 } // namespace infini