feat: kunlun 上添加LeakyRelu,修复BatchNorm中维度为4的限制,跑通bgan

This commit is contained in:
weijie01 2024-04-28 10:42:04 +08:00 committed by zhangyunze
parent 23b1612192
commit 36baae7615
2 changed files with 31 additions and 6 deletions

View File

@ -19,13 +19,17 @@ class BatchNormXdnn : public KUNLUNKernelWithoutConfig {
auto dims = op->getInputs(0)->getDims();
if (dims.size() != 4)
IT_TODO_HALT();
int n, c, h, w;
if (dims.size() != 4){
h = 1;
w = 1;
}
w = dims[3];
h = dims[2];
c = dims[1];
n = dims[0];
int w = dims[3];
int h = dims[2];
int c = dims[1];
int n = dims[0];
auto ret = xdnn::batch_norm_infer<float>(
context->KUNLUNHandle(), (float *)input, (float *)output, n, c, h,
w, op->getEps(), (float *)scale, (float *)bias, (float *)mean,

21
src/kernels/kunlun/unary.cc Normal file → Executable file
View File

@ -21,6 +21,26 @@ class ReluXdnn : public KUNLUNKernelWithoutConfig {
}
};
class LeakyReluXdnn : public KUNLUNKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<LeakyReluObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
auto len = op->getInputs(0)->size();
auto alpha = op->getAlpha();
auto ret = xdnn::leaky_relu<float>(context->KUNLUNHandle(),
(float *const)aData, (float *)cData,
len, alpha);
assert(ret == 0);
return;
}
};
class SigmoidXdnn : public KUNLUNKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
@ -552,6 +572,7 @@ class ATanhXdnn : public KUNLUNKernelWithoutConfig {
};
REGISTER_KERNEL(Device::KUNLUN, OpType::Relu, ReluXdnn, "Relu_xdnn_KUNLUN");
REGISTER_KERNEL(Device::KUNLUN, OpType::LeakyRelu, LeakyReluXdnn, "LeakyRelu_xdnn_KUNLUN");
REGISTER_KERNEL(Device::KUNLUN, OpType::Sigmoid, SigmoidXdnn,
"Sigmoid_xdnn_KUNLUN");
REGISTER_KERNEL(Device::KUNLUN, OpType::Tanh, TanhXdnn, "Tanh_xdnn_KUNLUN");