diff --git a/src/kernels/kunlun/element_wise.cc b/src/kernels/kunlun/element_wise.cc index f71c11bc..3370eb1a 100644 --- a/src/kernels/kunlun/element_wise.cc +++ b/src/kernels/kunlun/element_wise.cc @@ -15,6 +15,12 @@ class AddXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::broadcast_add( context->KUNLUNHandle(), (float *)aData, (float *)bData, (float *)cData, aDim, bDim); @@ -35,6 +41,12 @@ class SubXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::broadcast_sub( context->KUNLUNHandle(), (float *)aData, (float *)bData, (float *)cData, aDim, bDim); @@ -55,6 +67,12 @@ class MulXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::broadcast_mul( context->KUNLUNHandle(), (float *)aData, (float *)bData, (float *)cData, aDim, bDim); @@ -75,6 +93,12 @@ class DivXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::broadcast_div( context->KUNLUNHandle(), (float *)aData, (float *)bData, (float *)cData, aDim, bDim); @@ -95,6 +119,13 @@ class PowXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } + auto ret = baidu::xpu::api::broadcast_pow( context->KUNLUNHandle(), (float *)aData, (float *)bData, (float *)cData, aDim, bDim); @@ -115,6 +146,12 @@ class MaxXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::broadcast_max( context->KUNLUNHandle(), (float *)aData, (float *)bData, (float *)cData, aDim, bDim); @@ -135,6 +172,12 @@ class MinXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::broadcast_min( context->KUNLUNHandle(), (float *)aData, (float *)bData, (float *)cData, aDim, bDim); @@ -157,6 +200,12 @@ class EqualXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::broadcast_equal( context->KUNLUNHandle(), (float *)aData, (float *)bData, (bool *)wsData, aDim, bDim); @@ -181,6 +230,12 @@ class GreaterEqualXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::broadcast_greater_equal( context->KUNLUNHandle(), (float *)aData, (float *)bData, (bool *)wsData, aDim, bDim); @@ -205,6 +260,12 @@ class GreaterThanXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::broadcast_greater_than( context->KUNLUNHandle(), (float *)aData, (float *)bData, (bool *)wsData, aDim, bDim); @@ -229,6 +290,12 @@ class LessEqualXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::broadcast_less_equal( context->KUNLUNHandle(), (float *)aData, (float *)bData, (bool *)wsData, aDim, bDim); @@ -253,6 +320,12 @@ class LessThanXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::broadcast_less_than( context->KUNLUNHandle(), (float *)aData, (float *)bData, (bool *)wsData, aDim, bDim); @@ -275,6 +348,12 @@ class FloorDivXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::broadcast_floordiv( context->KUNLUNHandle(), (float *)aData, (float *)bData, (float *)cData, aDim, bDim); @@ -317,6 +396,12 @@ class AndXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::logical_and( context->KUNLUNHandle(), (bool *)aData, (bool *)bData, (bool *)wsData, len); @@ -341,6 +426,12 @@ class OrXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::logical_or( context->KUNLUNHandle(), (bool *)aData, (bool *)bData, (bool *)wsData, len); @@ -365,6 +456,12 @@ class XorXdnn : public KUNLUNKernelWithoutConfig { auto aDim = op->getInputs(0)->getDims(); auto bDim = op->getInputs(1)->getDims(); + if (aDim.size() == 0) { + aDim.push_back(1); + } + if (bDim.size() == 0) { + bDim.push_back(1); + } auto ret = baidu::xpu::api::logical_xor( context->KUNLUNHandle(), (bool *)aData, (bool *)bData, (bool *)wsData, len); diff --git a/src/kernels/kunlun/gather.cc b/src/kernels/kunlun/gather.cc new file mode 100644 index 00000000..f94d24fa --- /dev/null +++ b/src/kernels/kunlun/gather.cc @@ -0,0 +1,29 @@ +#include "operators/gather.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { +class GatherXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto shape = op->getInputs(0)->getDims(); + auto index = op->getInputs(1)->getDims(); + auto axis = op->getAxis(); + auto ret = baidu::xpu::api::gather( + context->KUNLUNHandle(), (float *)aData, (int *)bData, + (float *)cData, shape, index.size(), axis); + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::Gather, DataType::Float32, GatherXdnn, + "Gather_xdnn_KUNLUN_Float32"); +}; // namespace infini diff --git a/src/kernels/kunlun/matmul.cc b/src/kernels/kunlun/matmul.cc index 91240ce3..8506e812 100644 --- a/src/kernels/kunlun/matmul.cc +++ b/src/kernels/kunlun/matmul.cc @@ -13,21 +13,16 @@ class MatmulXdnn : public KUNLUNKernelWithoutConfig { void *const cData = (op->getOutput()->getRawDataPtr()); bool transA = op->getTransA(); bool transB = op->getTransB(); - if (op->getInputs(0)->getDims().size() != 2 || - op->getInputs(1)->getDims().size() != 2) { - IT_TODO_HALT(); - } - auto m = transA ? op->getInputs(0)->getDims()[1] - : op->getInputs(0)->getDims()[0]; - auto n = transB ? op->getInputs(1)->getDims()[0] - : op->getInputs(1)->getDims()[1]; - auto k = transA ? op->getInputs(0)->getDims()[0] - : op->getInputs(0)->getDims()[1]; + auto b = op->getB(); + auto m = op->getM(); + auto n = op->getN(); + auto k = op->getK(); - auto ret = baidu::xpu::api::fc( - context->KUNLUNHandle(), (float *)aData, (float *)bData, - (float *)cData, m, n, k, transA, transB, nullptr, nullptr, nullptr); + auto ret = baidu::xpu::api::fc_batched( + context->KUNLUNHandle(), b, transA, transB, m, n, k, 1.0, + (float *)aData, m * k, (float *)bData, n * k, 0.0, (float *)cData, + m * n, nullptr, nullptr); assert(ret == 0); return; } diff --git a/src/kernels/kunlun/reduce_mean.cc b/src/kernels/kunlun/reduce_mean.cc new file mode 100644 index 00000000..08a01fd6 --- /dev/null +++ b/src/kernels/kunlun/reduce_mean.cc @@ -0,0 +1,30 @@ +#include "operators/reduce_mean.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { +class ReduceMeanXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto axes_set = op->getAxes(); + std::vector axes; + axes.assign(axes_set.begin(), axes_set.end()); + auto shape = op->getInputs(0)->getDims(); + + auto ret = baidu::xpu::api::reduce_mean( + context->KUNLUNHandle(), (float *)aData, (float *)cData, shape, + axes); + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::ReduceMean, DataType::Float32, + ReduceMeanXdnn, "ReduceMean_xdnn_KUNLUN_Float32"); +}; // namespace infini diff --git a/src/kernels/kunlun/select.cc b/src/kernels/kunlun/select.cc new file mode 100644 index 00000000..d6318e46 --- /dev/null +++ b/src/kernels/kunlun/select.cc @@ -0,0 +1,32 @@ +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/where.h" + +namespace infini { +class WhereXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getInputs(2)->getRawDataPtr()); + void *const dData = (op->getOutput()->getRawDataPtr()); + + auto aDim = op->getInputs(0)->getDims(); + auto bDim = op->getInputs(1)->getDims(); + auto cDim = op->getInputs(2)->getDims(); + auto dDim = op->getOutput()->getDims(); + + auto ret = baidu::xpu::api::select( + context->KUNLUNHandle(), (bool *)cData, (float *)aData, + (float *)bData, (float *)dData, cDim, aDim); + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::Where, DataType::Float32, WhereXdnn, + "Where_xdnn_KUNLUN_Float32"); +}; // namespace infini diff --git a/src/kernels/kunlun/split.cc b/src/kernels/kunlun/split.cc index 301ef027..46276c85 100644 --- a/src/kernels/kunlun/split.cc +++ b/src/kernels/kunlun/split.cc @@ -22,9 +22,6 @@ class SplitXdnn : public KUNLUNKernelWithoutConfig { std::vector splitList; for (int i = 0; i < num; ++i) { auto dim = op->getOutput(i)->getDims(); - if (dim.size() != 4) { - IT_TODO_HALT(); - } splitList.push_back(dim[axis]); }