diff --git a/src/kernels/bang/activation.cc b/src/kernels/bang/activation.cc
index 4105b168..7d8f102a 100644
--- a/src/kernels/bang/activation.cc
+++ b/src/kernels/bang/activation.cc
@@ -241,8 +241,50 @@ class HardSigmoidCnnl : public UnaryCnnl {
     float getScale() const override { return 0.5f; }
 };
 
+class LeakyReluCnnl : public BangKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<LeakyReluObj>(_op);
+        auto context = dynamic_cast<const BangRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+
+        cnnlTensorDescriptor_t aDesc, cDesc;
+        auto aDim = op->getInputs(0)->getDims();
+        auto cDim = op->getOutput()->getDims();
+        auto coef = op->getAlpha();
+
+        checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
+            aDim.size(), aDim.data()));
+        checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
+            cDim.size(), cDim.data()));
+        cnnlActivationDescriptor_t opDesc;
+        checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));
+        checkCnnlError(cnnlSetActivationDescriptor_v5(
+            opDesc, CNNL_ACTIVATION_LEAKYRELU, CNNL_ACTIVATION_HIGH_PRECISION,
+            CNNL_NOT_PROPAGATE_NAN, coef, 0.0, 0.0, 0.0, true));
+
+        float alpha = 1.f, beta = 0.f;
+        cnnlStatus_t stat =
+            cnnlActivationForward(context->cnnlHandle(), opDesc, &alpha, aDesc,
+                                  aData, &beta, cDesc, cData);
+        if (stat != CNNL_STATUS_SUCCESS)
+            return;
+        checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
+        checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
+        checkCnnlError(cnnlDestroyActivationDescriptor(opDesc));
+    }
+};
+
 REGISTER_KERNEL(Device::BANG, OpType::Relu, ReluCnnl, "Relu_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::PRelu, PReluCnnl, "PRelu_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::LeakyRelu, LeakyReluCnnl,
+                "LeakyRelu_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
                 "Sigmoid_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Round, RoundCnnl, "Round_cnnl_BANG");
diff --git a/src/kernels/bang/batchnorm.cc b/src/kernels/bang/batchnorm.cc
index 633f0d88..8906fe7a 100644
--- a/src/kernels/bang/batchnorm.cc
+++ b/src/kernels/bang/batchnorm.cc
@@ -16,10 +16,14 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
         void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
         void *const output = (op->getOutput()->getRawDataPtr<void *>());
 
-        auto dims = op->getInputs(0)->getDims();
-        auto outDims = op->getOutput()->getDims();
-        if (dims.size() != 4)
-            IT_TODO_HALT();
+        auto padDims = [](Shape shape) {
+            for (size_t i = shape.size(); i < 4; ++i) {
+                shape.push_back(1);
+            }
+            return shape;
+        };
+        auto dims = padDims(op->getInputs(0)->getDims());
+        auto outDims = padDims(op->getOutput()->getDims());
 
         int dimsTrans[4] = {dims[0], dims[2], dims[3], dims[1]};
         int dimsOutTrans[4] = {outDims[0], outDims[2], outDims[3], outDims[1]};
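
Note (not part of the patch): a minimal sketch of how the new LeakyRelu kernel might be exercised end to end, modeled on the existing BANG unary-operator tests in this repository. The LeakyReluObj constructor signature (input, output, alpha) and the 0.01 slope below are assumptions for illustration, not taken from this diff.

#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/runtime.h"
#include "operators/unary.h"
#include "test.h"

namespace infini {

TEST(BANG_LeakyRelu, run) {
    // Build the input on the CPU runtime first.
    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    auto bangRuntime = make_ref<BangRuntimeObj>();

    Graph gCpu = make_ref<GraphObj>(cpuRuntime);
    Tensor input = gCpu->addTensor({1, 3, 2, 2}, DataType::Float32);
    gCpu->dataMalloc();
    input->setData(IncrementalGenerator());

    // Clone the input to the BANG device and run the kernel there.
    Graph gBang = make_ref<GraphObj>(bangRuntime);
    Tensor inputBang = gBang->cloneTensor(input);
    // Assumed op signature: (input, output, alpha).
    auto op = gBang->addOp<LeakyReluObj>(inputBang, nullptr, 0.01);
    gBang->dataMalloc();
    inputBang->setData(IncrementalGenerator());
    bangRuntime->run(gBang);

    // Copy the result back; IncrementalGenerator produces only
    // non-negative values, so LeakyRelu should act as the identity here.
    auto outputCpu = op->getOutput()->clone(cpuRuntime);
    EXPECT_TRUE(outputCpu->equalData(input));
}

} // namespace infini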
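
Note (not part of the patch): the batchnorm.cc hunk replaces the hard 4-D requirement (IT_TODO_HALT on anything else) with right-padding of the shape by singleton dimensions, so a 2-D input is treated as NCHW with H = W = 1; shapes that already have four or more dimensions pass through unchanged, and the 4-element transpose arrays below would silently ignore any extra dimensions. A standalone sketch of that helper, assuming Shape is std::vector<int> as in this codebase:

#include <cstdio>
#include <vector>

using Shape = std::vector<int>;

// Copy of the padDims lambda from the hunk: append 1s until rank 4.
static Shape padDims(Shape shape) {
    for (size_t i = shape.size(); i < 4; ++i)
        shape.push_back(1);
    return shape;
}

int main() {
    for (int d : padDims({2, 3}))
        std::printf("%d ", d); // prints: 2 3 1 1
    return 0;
}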