forked from jiuyuan/InfiniTensor
fix: add LeakyRelu on MLU, lift the rank-4 input restriction in BatchNorm, and get BGAN running end to end
parent 77fd137dcb
commit 23b1612192
@@ -241,8 +241,50 @@ class HardSigmoidCnnl : public UnaryCnnl {
     float getScale() const override { return 0.5f; }
 };
 
+class LeakyReluCnnl : public BangKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<LeakyReluObj>(_op);
+        auto context = dynamic_cast<const BangRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+
+        cnnlTensorDescriptor_t aDesc, cDesc;
+        auto aDim = op->getInputs(0)->getDims();
+        auto cDim = op->getOutput()->getDims();
+        auto coef = op->getAlpha();
+
+        checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            aDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
+            aDim.size(), aDim.data()));
+        checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            cDesc, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(op->getDType()),
+            cDim.size(), cDim.data()));
+        cnnlActivationDescriptor_t opDesc;
+        checkCnnlError(cnnlCreateActivationDescriptor(&opDesc));
+        checkCnnlError(cnnlSetActivationDescriptor_v5(
+            opDesc, CNNL_ACTIVATION_LEAKYRELU, CNNL_ACTIVATION_HIGH_PRECISION,
+            CNNL_NOT_PROPAGATE_NAN, coef, 0.0, 0.0, 0.0, true));
+
+        float alpha = 1.f, beta = 0.f;
+        cnnlStatus_t stat =
+            cnnlActivationForward(context->cnnlHandle(), opDesc, &alpha, aDesc,
+                                  aData, &beta, cDesc, cData);
+        if (stat != CNNL_STATUS_SUCCESS)
+            return;
+        checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
+        checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
+        checkCnnlError(cnnlDestroyActivationDescriptor(opDesc));
+    }
+};
+
 REGISTER_KERNEL(Device::BANG, OpType::Relu, ReluCnnl, "Relu_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::PRelu, PReluCnnl, "PRelu_cnnl_BANG");
+REGISTER_KERNEL(Device::BANG, OpType::LeakyRelu, LeakyReluCnnl,
+                "LeakyRelu_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
                 "Sigmoid_cnnl_BANG");
 REGISTER_KERNEL(Device::BANG, OpType::Round, RoundCnnl, "Round_cnnl_BANG");
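For reference, CNNL_ACTIVATION_LEAKYRELU with coefficient coef evaluates f(x) = x for x >= 0 and f(x) = coef * x otherwise, which is what the new kernel delegates to cnnlActivationForward; alpha = 1 and beta = 0 appear to be the standard scale/blend factors for this call, so the result simply overwrites cData. Below is a minimal host-side sketch of the same element-wise semantics, useful for cross-checking MLU output; applyLeakyRelu is a hypothetical helper for illustration only, not part of InfiniTensor or CNNL.

#include <cstddef>
#include <vector>

// Reference LeakyRelu semantics, matching what the CNNL kernel above
// computes on the MLU: y = x if x >= 0, else coef * x.
// Hypothetical helper, used only to sanity-check device results on the host.
std::vector<float> applyLeakyRelu(const std::vector<float> &x, float coef) {
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        y[i] = x[i] >= 0.f ? x[i] : coef * x[i];
    return y;
}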
@@ -16,10 +16,14 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
         void *const bias = (op->getInputs(4)->getRawDataPtr<void *>());
         void *const output = (op->getOutput()->getRawDataPtr<void *>());
 
-        auto dims = op->getInputs(0)->getDims();
-        auto outDims = op->getOutput()->getDims();
-        if (dims.size() != 4)
-            IT_TODO_HALT();
+        auto padDims = [](Shape shape) {
+            for (size_t i = shape.size(); i < 4; ++i) {
+                shape.push_back(1);
+            }
+            return shape;
+        };
+        auto dims = padDims(op->getInputs(0)->getDims());
+        auto outDims = padDims(op->getOutput()->getDims());
 
         int dimsTrans[4] = {dims[0], dims[2], dims[3], dims[1]};
         int dimsOutTrans[4] = {outDims[0], outDims[2], outDims[3], outDims[1]};
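The padDims lambda is what lifts the old rank-4 restriction: lower-rank BatchNorm inputs, such as the (N, C) activations produced by BGAN's fully-connected layers, are padded with trailing 1s to (N, C, 1, 1) before the NCHW-to-NHWC reordering in dimsTrans. A standalone sketch of that padding and reordering follows; it assumes Shape is a std::vector<int> as in InfiniTensor, and the concrete sizes are made up for illustration.

#include <cassert>
#include <cstddef>
#include <vector>

using Shape = std::vector<int>; // assumed alias, mirroring InfiniTensor's Shape

// Pad trailing dimensions with 1 until the shape has rank 4,
// mirroring the padDims lambda added to BatchNormCnnl above.
Shape padDims(Shape shape) {
    for (std::size_t i = shape.size(); i < 4; ++i)
        shape.push_back(1);
    return shape;
}

int main() {
    Shape nc = {32, 128};       // e.g. a (batch, channels) input from a Linear layer
    Shape padded = padDims(nc); // {32, 128, 1, 1}
    assert(padded.size() == 4);
    // The kernel then reorders NCHW -> NHWC, as in dimsTrans above:
    Shape nhwc = {padded[0], padded[2], padded[3], padded[1]};
    assert((nhwc == Shape{32, 1, 1, 128}));
    return 0;
}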