forked from jiuyuan/InfiniTensor
feat: kunlun 上添加LeakyRelu,修复BatchNorm中维度为4的限制,跑通bgan
This commit is contained in:
parent
23b1612192
commit
36baae7615
|
@ -19,13 +19,17 @@ class BatchNormXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto dims = op->getInputs(0)->getDims();
|
auto dims = op->getInputs(0)->getDims();
|
||||||
|
|
||||||
if (dims.size() != 4)
|
int n, c, h, w;
|
||||||
IT_TODO_HALT();
|
if (dims.size() != 4){
|
||||||
|
h = 1;
|
||||||
|
w = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
w = dims[3];
|
||||||
|
h = dims[2];
|
||||||
|
c = dims[1];
|
||||||
|
n = dims[0];
|
||||||
|
|
||||||
int w = dims[3];
|
|
||||||
int h = dims[2];
|
|
||||||
int c = dims[1];
|
|
||||||
int n = dims[0];
|
|
||||||
auto ret = xdnn::batch_norm_infer<float>(
|
auto ret = xdnn::batch_norm_infer<float>(
|
||||||
context->KUNLUNHandle(), (float *)input, (float *)output, n, c, h,
|
context->KUNLUNHandle(), (float *)input, (float *)output, n, c, h,
|
||||||
w, op->getEps(), (float *)scale, (float *)bias, (float *)mean,
|
w, op->getEps(), (float *)scale, (float *)bias, (float *)mean,
|
||||||
|
|
|
@ -21,6 +21,26 @@ class ReluXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class LeakyReluXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<LeakyReluObj>(_op);
|
||||||
|
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||||
|
auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
|
||||||
|
|
||||||
|
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||||
|
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
|
auto len = op->getInputs(0)->size();
|
||||||
|
auto alpha = op->getAlpha();
|
||||||
|
|
||||||
|
auto ret = xdnn::leaky_relu<float>(context->KUNLUNHandle(),
|
||||||
|
(float *const)aData, (float *)cData,
|
||||||
|
len, alpha);
|
||||||
|
assert(ret == 0);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class SigmoidXdnn : public KUNLUNKernelWithoutConfig {
|
class SigmoidXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
void compute(const Operator &_op,
|
void compute(const Operator &_op,
|
||||||
const RuntimeObj *_context) const override {
|
const RuntimeObj *_context) const override {
|
||||||
|
@ -552,6 +572,7 @@ class ATanhXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Relu, ReluXdnn, "Relu_xdnn_KUNLUN");
|
REGISTER_KERNEL(Device::KUNLUN, OpType::Relu, ReluXdnn, "Relu_xdnn_KUNLUN");
|
||||||
|
// Register LeakyReluXdnn for OpType::LeakyRelu on Device::KUNLUN.
REGISTER_KERNEL(Device::KUNLUN, OpType::LeakyRelu, LeakyReluXdnn,
                "LeakyRelu_xdnn_KUNLUN");
|
||||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Sigmoid, SigmoidXdnn,
|
REGISTER_KERNEL(Device::KUNLUN, OpType::Sigmoid, SigmoidXdnn,
|
||||||
"Sigmoid_xdnn_KUNLUN");
|
"Sigmoid_xdnn_KUNLUN");
|
||||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Tanh, TanhXdnn, "Tanh_xdnn_KUNLUN");
|
REGISTER_KERNEL(Device::KUNLUN, OpType::Tanh, TanhXdnn, "Tanh_xdnn_KUNLUN");
|
||||||
|
|
Loading…
Reference in New Issue