fix: add LeakyRelu on mlu, remove the rank-4 restriction in BatchNorm, get BGAN running

Zhang Bolun 2024-04-25 17:05:06 +08:00 committed by zhangyunze
parent 77fd137dcb
commit 23b1612192
2 changed files with 50 additions and 4 deletions

@@ -241,8 +241,50 @@ class HardSigmoidCnnl : public UnaryCnnl {
     float getScale() const override { return 0.5f; }
 };

+class LeakyReluCnnl : public BangKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<LeakyReluObj>(_op);
+        auto context = dynamic_cast<const BangRuntimeObj *>(_context);
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+
+        cnnlTensorDescriptor_t aDesc, cDesc;
+        auto aDim = op->getInputs(0)->getDims();
+        auto cDim = op->getOutput()->getDims();
+        auto coef = op->getAlpha();
+
+        checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
+            aDim.size(), aDim.data()));
+        checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
+            cDim.size(), cDim.data()));
+
+        cnnlActivationDescriptor_t opDesc;
+        checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));
+        checkCnnlError(cnnlSetActivationDescriptor_v5(
+            opDesc, CNNL_ACTIVATION_LEAKYRELU, CNNL_ACTIVATION_HIGH_PRECISION,
+            CNNL_NOT_PROPAGATE_NAN, coef, 0.0, 0.0, 0.0, true));
+
+        float alpha = 1.f, beta = 0.f;
+        cnnlStatus_t stat =
+            cnnlActivationForward(context->cnnlHandle(), opDesc, &alpha, aDesc,
+                                  aData, &beta, cDesc, cData);
+        // Destroy the descriptors before checking the status so that a
+        // failed launch does not leak them.
+        checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
+        checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
+        checkCnnlError(cnnlDestroyActivationDescriptor(opDesc));
+        if (stat != CNNL_STATUS_SUCCESS)
+            return;
+    }
+};
+
 REGISTER_KERNEL(Device::BANG, OpType::Relu, ReluCnnl, "Relu_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::PRelu, PReluCnnl, "PRelu_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::LeakyRelu, LeakyReluCnnl,
+                "LeakyRelu_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
                 "Sigmoid_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Round, RoundCnnl, "Round_cnnl_BANG");
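The kernel registered above maps OpType::LeakyRelu onto CNNL_ACTIVATION_LEAKYRELU, which computes y = x for x >= 0 and y = coef * x otherwise. As a sanity reference, here is a minimal host-side sketch of those numerics in plain C++, independent of CNNL; leakyReluRef and the sample values are illustrative, not from the repository:

#include <cstddef>
#include <cstdio>

// Reference semantics the BANG kernel is expected to match:
// y = x for x >= 0, y = coef * x for x < 0.
static void leakyReluRef(const float *x, float *y, size_t n, float coef) {
    for (size_t i = 0; i < n; ++i)
        y[i] = x[i] >= 0.f ? x[i] : coef * x[i];
}

int main() {
    const float x[4] = {-2.f, -0.5f, 0.f, 3.f};
    float y[4];
    leakyReluRef(x, y, 4, 0.01f); // 0.01 is the ONNX default alpha
    for (float v : y)
        std::printf("%g ", v); // prints: -0.02 -0.005 0 3
    std::printf("\n");
    return 0;
}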

@@ -16,10 +16,14 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
         void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
         void *const output = (op->getOutput()->getRawDataPtr<void *>());
-        auto dims = op->getInputs(0)->getDims();
-        auto outDims = op->getOutput()->getDims();
-        if (dims.size() != 4)
-            IT_TODO_HALT();
+        auto padDims = [](Shape shape) {
+            for (size_t i = shape.size(); i < 4; ++i) {
+                shape.push_back(1);
+            }
+            return shape;
+        };
+        auto dims = padDims(op->getInputs(0)->getDims());
+        auto outDims = padDims(op->getOutput()->getDims());
         int dimsTrans[4] = {dims[0], dims[2], dims[3], dims[1]};
         int dimsOutTrans[4] = {outDims[0], outDims[2], outDims[3], outDims[1]};
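Instead of halting on non-4-D inputs, the new padDims helper right-pads lower-rank shapes with 1s (e.g. {N, C} becomes {N, C, 1, 1}), so the rank-4 NCHW assumption of the CNNL batch-norm path and the NCHW-to-NHWC permutation in dimsTrans apply uniformly. A standalone sketch of the padding plus permutation in plain C++; padTo4D and the sample shape are illustrative, not from the repository:

#include <cstddef>
#include <cstdio>
#include <vector>

using Shape = std::vector<int>;

// Right-pad a shape with 1s up to rank 4, mirroring the padDims lambda above.
static Shape padTo4D(Shape shape) {
    for (size_t i = shape.size(); i < 4; ++i)
        shape.push_back(1);
    return shape;
}

int main() {
    Shape dims = padTo4D({8, 64}); // a 2-D (N, C) input -> {8, 64, 1, 1}
    int dimsTrans[4] = {dims[0], dims[2], dims[3], dims[1]}; // NCHW -> NHWC
    std::printf("%d %d %d %d\n",
                dimsTrans[0], dimsTrans[1], dimsTrans[2], dimsTrans[3]);
    // prints: 8 1 1 64
    return 0;
}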