forked from jiuyuan/InfiniTensor
feat: kunlun 上添加LeakyRelu,修复BatchNorm中维度为4的限制,跑通bgan
This commit is contained in:
parent
23b1612192
commit
36baae7615
|
@ -19,13 +19,17 @@ class BatchNormXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto dims = op->getInputs(0)->getDims();
|
||||
|
||||
if (dims.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
int n, c, h, w;
|
||||
if (dims.size() != 4){
|
||||
h = 1;
|
||||
w = 1;
|
||||
}
|
||||
|
||||
w = dims[3];
|
||||
h = dims[2];
|
||||
c = dims[1];
|
||||
n = dims[0];
|
||||
|
||||
int w = dims[3];
|
||||
int h = dims[2];
|
||||
int c = dims[1];
|
||||
int n = dims[0];
|
||||
auto ret = xdnn::batch_norm_infer<float>(
|
||||
context->KUNLUNHandle(), (float *)input, (float *)output, n, c, h,
|
||||
w, op->getEps(), (float *)scale, (float *)bias, (float *)mean,
|
||||
|
|
|
@ -21,6 +21,26 @@ class ReluXdnn : public KUNLUNKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
class LeakyReluXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<LeakyReluObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto len = op->getInputs(0)->size();
|
||||
auto alpha = op->getAlpha();
|
||||
|
||||
auto ret = xdnn::leaky_relu<float>(context->KUNLUNHandle(),
|
||||
(float *const)aData, (float *)cData,
|
||||
len, alpha);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
class SigmoidXdnn : public KUNLUNKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
|
@ -552,6 +572,7 @@ class ATanhXdnn : public KUNLUNKernelWithoutConfig {
|
|||
};
|
||||
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Relu, ReluXdnn, "Relu_xdnn_KUNLUN");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::LeakyRelu, LeakyReluXdnn, "LeakyRelu_xdnn_KUNLUN");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Sigmoid, SigmoidXdnn,
|
||||
"Sigmoid_xdnn_KUNLUN");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Tanh, TanhXdnn, "Tanh_xdnn_KUNLUN");
|
||||
|
|
Loading…
Reference in New Issue