From a889527aa5202085483e166fdfcd7621a9ff6fe8 Mon Sep 17 00:00:00 2001
From: zhangyue
Date: Sat, 11 May 2024 16:24:42 +0800
Subject: [PATCH] add kunlun layernorm

---
 src/kernels/kunlun/layer_norm.cc | 45 ++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)
 create mode 100644 src/kernels/kunlun/layer_norm.cc

diff --git a/src/kernels/kunlun/layer_norm.cc b/src/kernels/kunlun/layer_norm.cc
new file mode 100644
index 00000000..adb03639
--- /dev/null
+++ b/src/kernels/kunlun/layer_norm.cc
@@ -0,0 +1,45 @@
+#include "operators/layer_norm.h"
+#include "kunlun/kunlun_kernel_without_config.h"
+#include "kunlun/kunlun_runtime.h"
+
+namespace infini {
+class LayerNormXdnn : public KUNLUNKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<LayerNormObj>(_op);
+        auto context = static_cast<const KUNLUNRuntimeObj *>(_context);
+
+        void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const scaleData = (op->getInputs(1)->getRawDataPtr<void *>());
+        void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
+
+        float eps = op->getEps();
+        // int axis = op->getAxis();
+
+        const auto &opInputShape = op->getInputs(0)->getDims();
+        IT_ASSERT(op->getOutput()->getDims() == opInputShape);
+        IT_ASSERT(opInputShape.size() == 2);
+
+        int ret;
+        if (op->numInputs() == 3) {
+            // with bias
+            void *const biasData = op->getInputs(2)->getRawDataPtr<void *>();
+            ret = xdnn::layer_norm<float>(
+                context->KUNLUNHandle(), (float const *)inputData,
+                (float *)outputData, opInputShape[0], opInputShape[1], eps,
+                (float *)scaleData, (float *)biasData, nullptr, nullptr);
+        } else {
+            // without bias
+            ret = xdnn::layer_norm<float>(
+                context->KUNLUNHandle(), (float const *)inputData,
+                (float *)outputData, opInputShape[0], opInputShape[1], eps,
+                (float *)scaleData, nullptr, nullptr, nullptr);
+        }
+        IT_ASSERT(ret == 0);
+    }
+};
+
+REGISTER_KERNEL(Device::KUNLUN, OpType::LayerNormalization, LayerNormXdnn,
+                "LayerNorm_xdnn_KUNLUN");
+
+} // namespace infini
\ No newline at end of file
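
Review note: the kernel asserts rank-2 input (`IT_ASSERT(opInputShape.size() == 2)`) and leaves `op->getAxis()` commented out, so only tensors already shaped (M, N) are accepted. Below is a minimal sketch, not part of the patch, of how higher-rank inputs could be folded into the (M, N) layout that the rank-2 `xdnn::layer_norm` call expects. `foldToMxN` is a hypothetical helper; `Shape` is assumed to be `std::vector<int>` as elsewhere in the codebase.

    #include <cassert>
    #include <utility>
    #include <vector>

    using Shape = std::vector<int>; // assumption: matches infini::Shape

    // Hypothetical helper: collapse all dims before `axis` into M (rows)
    // and all dims from `axis` onward into N (the normalized size), so an
    // N-D tensor can be viewed as the 2-D (M, N) buffer the kernel needs.
    static std::pair<int, int> foldToMxN(const Shape &dims, int axis) {
        assert(axis >= 0 && axis < (int)dims.size());
        int m = 1, n = 1;
        for (int i = 0; i < (int)dims.size(); ++i)
            (i < axis ? m : n) *= dims[i];
        return {m, n};
    }

    // Example: a [2, 3, 4] tensor normalized over axis 2 folds to M=6, N=4.

With such a fold computed from `op->getAxis()`, the rank-2 assertion could be relaxed without changing the `xdnn::layer_norm` call itself, since the data layout of a contiguous N-D tensor already matches its folded (M, N) view.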