add reduce_mean and gather on kunlun (#169)

* add reduce_mean and gather

* fix format

* fix gather

* fix

* fix xpu, add where operation, fix element-wise operation

* fix format

---------

Co-authored-by: wanghailu <wanghailu0717@163.com>
Co-authored-by: wanghailu <wanghailu@qiyuanlab.com>
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
Author: Hardy (committed by GitHub)
Date: 2023-11-10 17:52:09 +08:00
Commit: 1ea450882b (parent: d3e7543291)
6 changed files with 196 additions and 16 deletions


@@ -15,6 +15,12 @@ class AddXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::broadcast_add<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -35,6 +41,12 @@ class SubXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::broadcast_sub<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -55,6 +67,12 @@ class MulXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::broadcast_mul<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -75,6 +93,12 @@ class DivXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::broadcast_div<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -95,6 +119,13 @@ class PowXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::broadcast_pow<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -115,6 +146,12 @@ class MaxXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::broadcast_max<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -135,6 +172,12 @@ class MinXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::broadcast_min<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -157,6 +200,12 @@ class EqualXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::broadcast_equal<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
@@ -181,6 +230,12 @@ class GreaterEqualXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::broadcast_greater_equal<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
@@ -205,6 +260,12 @@ class GreaterThanXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::broadcast_greater_than<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
@@ -229,6 +290,12 @@ class LessEqualXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::broadcast_less_equal<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
@@ -253,6 +320,12 @@ class LessThanXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::broadcast_less_than<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
@@ -275,6 +348,12 @@ class FloorDivXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::broadcast_floordiv<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -317,6 +396,12 @@ class AndXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::logical_and<bool>(
             context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
             (bool *)wsData, len);
@@ -341,6 +426,12 @@ class OrXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::logical_or<bool>(
             context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
             (bool *)wsData, len);
@@ -365,6 +456,12 @@ class XorXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
+        if (aDim.size() == 0) {
+            aDim.push_back(1);
+        }
+        if (bDim.size() == 0) {
+            bDim.push_back(1);
+        }
         auto ret = baidu::xpu::api::logical_xor<bool>(
             context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
             (bool *)wsData, len);
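The hunks above repeat the same guard in every binary kernel: xdnn's broadcast primitives expect a non-empty shape vector, so a rank-0 (scalar) tensor is presented to them as a one-element 1-D tensor. A minimal host-side sketch of that normalization, factored into a helper (hypothetical, not part of this commit):

    #include <vector>

    using Shape = std::vector<int>;

    // xdnn broadcast kernels reject an empty dims vector, so a rank-0 (scalar)
    // tensor is reshaped to {1} before the call. Tensors that already have at
    // least one dimension are passed through unchanged.
    inline Shape normalizeScalarShape(Shape dims) {
        if (dims.empty())
            dims.push_back(1);
        return dims;
    }

    // Usage (hypothetical): aDim = normalizeScalarShape(op->getInputs(0)->getDims());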


@@ -0,0 +1,29 @@
+#include "operators/gather.h"
+#include "kunlun/kunlun_kernel_without_config.h"
+#include "kunlun/kunlun_runtime.h"
+
+namespace infini {
+class GatherXdnn : public KUNLUNKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<GatherObj>(_op);
+        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+
+        auto shape = op->getInputs(0)->getDims();
+        auto index = op->getInputs(1)->getDims();
+        auto axis = op->getAxis();
+        auto ret = baidu::xpu::api::gather<float, int>(
+            context->KUNLUNHandle(), (float *)aData, (int *)bData,
+            (float *)cData, shape, index.size(), axis);
+        assert(ret == 0);
+        return;
+    }
+};
+
+REGISTER_KERNEL(Device::KUNLUN, OpType::Gather, DataType::Float32, GatherXdnn,
+                "Gather_xdnn_KUNLUN_Float32");
+}; // namespace infini
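For reference, the semantics the xdnn gather call offloads: slices of the data tensor along `axis` are picked out by the integer index tensor. A host-side sketch for the axis = 0 case (hypothetical helper, not part of this commit):

    #include <cassert>
    #include <vector>

    // data has shape {rows, cols}; indices selects rows; the result has shape
    // {indices.size(), cols}.
    std::vector<float> gatherAxis0(const std::vector<float> &data, int rows,
                                   int cols, const std::vector<int> &indices) {
        std::vector<float> out;
        out.reserve(indices.size() * cols);
        for (int idx : indices) {
            assert(idx >= 0 && idx < rows);
            for (int c = 0; c < cols; ++c)
                out.push_back(data[idx * cols + c]);
        }
        return out;
    }

    // Example: data = {1,2, 3,4, 5,6} with shape {3,2} and indices = {0,2}
    // -> out = {1,2, 5,6} with shape {2,2}.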


@@ -13,21 +13,16 @@ class MatmulXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         bool transA = op->getTransA();
         bool transB = op->getTransB();
-        if (op->getInputs(0)->getDims().size() != 2 ||
-            op->getInputs(1)->getDims().size() != 2) {
-            IT_TODO_HALT();
-        }
-        auto m = transA ? op->getInputs(0)->getDims()[1]
-                        : op->getInputs(0)->getDims()[0];
-        auto n = transB ? op->getInputs(1)->getDims()[0]
-                        : op->getInputs(1)->getDims()[1];
-        auto k = transA ? op->getInputs(0)->getDims()[0]
-                        : op->getInputs(0)->getDims()[1];
-        auto ret = baidu::xpu::api::fc<float, float, float, int>(
-            context->KUNLUNHandle(), (float *)aData, (float *)bData,
-            (float *)cData, m, n, k, transA, transB, nullptr, nullptr, nullptr);
+        auto b = op->getB();
+        auto m = op->getM();
+        auto n = op->getN();
+        auto k = op->getK();
+        auto ret = baidu::xpu::api::fc_batched<float, float, float, float>(
+            context->KUNLUNHandle(), b, transA, transB, m, n, k, 1.0,
+            (float *)aData, m * k, (float *)bData, n * k, 0.0, (float *)cData,
+            m * n, nullptr, nullptr);
         assert(ret == 0);
         return;
     }
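The kernel now takes b/m/n/k from the operator's own attributes instead of re-deriving them from 2-D shapes, and issues one batched call with per-batch strides of m*k, n*k and m*n, alpha = 1.0 and beta = 0.0. A host-side sketch of what that is expected to compute for row-major, non-transposed inputs (hypothetical reference, transA/transB handling omitted; the exact xdnn fc_batched signature is not reproduced here):

    // C[bi] = A[bi] * B[bi] for each of the b batches; beta = 0.0 means the
    // output is overwritten rather than accumulated.
    void batchedMatmulRef(const float *A, const float *B, float *C, int b,
                          int m, int n, int k) {
        for (int bi = 0; bi < b; ++bi) {
            const float *a = A + bi * m * k; // batch stride m*k
            const float *w = B + bi * k * n; // batch stride k*n
            float *c = C + bi * m * n;       // batch stride m*n
            for (int i = 0; i < m; ++i)
                for (int j = 0; j < n; ++j) {
                    float acc = 0.0f;
                    for (int t = 0; t < k; ++t)
                        acc += a[i * k + t] * w[t * n + j]; // alpha = 1.0
                    c[i * n + j] = acc;
                }
        }
    }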


@@ -0,0 +1,30 @@
+#include "operators/reduce_mean.h"
+#include "kunlun/kunlun_kernel_without_config.h"
+#include "kunlun/kunlun_runtime.h"
+
+namespace infini {
+class ReduceMeanXdnn : public KUNLUNKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<ReduceMeanObj>(_op);
+        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
+
+        auto axes_set = op->getAxes();
+        std::vector<int> axes;
+        axes.assign(axes_set.begin(), axes_set.end());
+        auto shape = op->getInputs(0)->getDims();
+
+        auto ret = baidu::xpu::api::reduce_mean<float>(
+            context->KUNLUNHandle(), (float *)aData, (float *)cData, shape,
+            axes);
+        assert(ret == 0);
+        return;
+    }
+};
+
+REGISTER_KERNEL(Device::KUNLUN, OpType::ReduceMean, DataType::Float32,
+                ReduceMeanXdnn, "ReduceMean_xdnn_KUNLUN_Float32");
+}; // namespace infini
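ReduceMean averages the input over the axes reported by getAxes(); the kernel only converts that set into the vector form xdnn expects. A host-side sketch of the reduction for a 2-D input and axes = {1} (hypothetical reference, not part of this commit):

    #include <vector>

    // x has shape {rows, cols}; reducing over axis 1 yields one mean per row.
    std::vector<float> reduceMeanAxis1(const std::vector<float> &x, int rows,
                                       int cols) {
        std::vector<float> out(rows, 0.0f);
        for (int r = 0; r < rows; ++r) {
            float sum = 0.0f;
            for (int c = 0; c < cols; ++c)
                sum += x[r * cols + c];
            out[r] = sum / static_cast<float>(cols);
        }
        return out;
    }

    // Example: x = {1,2,3, 4,5,6} with shape {2,3} and axes = {1}
    // -> out = {2, 5}.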


@@ -0,0 +1,32 @@
+#include "kunlun/kunlun_kernel_without_config.h"
+#include "kunlun/kunlun_runtime.h"
+#include "operators/where.h"
+
+namespace infini {
+class WhereXdnn : public KUNLUNKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<WhereObj>(_op);
+        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
+
+        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
+        void *const cData = (op->getInputs(2)->getRawDataPtr<void *>());
+        void *const dData = (op->getOutput()->getRawDataPtr<void *>());
+
+        auto aDim = op->getInputs(0)->getDims();
+        auto bDim = op->getInputs(1)->getDims();
+        auto cDim = op->getInputs(2)->getDims();
+        auto dDim = op->getOutput()->getDims();
+
+        auto ret = baidu::xpu::api::select<float>(
+            context->KUNLUNHandle(), (bool *)cData, (float *)aData,
+            (float *)bData, (float *)dData, cDim, aDim);
+        assert(ret == 0);
+        return;
+    }
+};
+
+REGISTER_KERNEL(Device::KUNLUN, OpType::Where, DataType::Float32, WhereXdnn,
+                "Where_xdnn_KUNLUN_Float32");
+}; // namespace infini
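Note the argument order in the kernel above: inputs 0 and 1 are the two value tensors (aData, bData) and input 2 is the boolean condition (cData), which xdnn's select takes first. A host-side sketch of the element-wise Where semantics, ignoring broadcasting between the condition and value tensors (hypothetical reference, not part of this commit):

    #include <vector>

    // out[i] = cond[i] ? x[i] : y[i]; all three inputs share the same length.
    std::vector<float> whereRef(const std::vector<bool> &cond,
                                const std::vector<float> &x,
                                const std::vector<float> &y) {
        std::vector<float> out(cond.size());
        for (size_t i = 0; i < cond.size(); ++i)
            out[i] = cond[i] ? x[i] : y[i];
        return out;
    }

    // Example: cond = {true, false, true}, x = {1,2,3}, y = {9,8,7} -> {1,8,3}.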


@@ -22,9 +22,6 @@ class SplitXdnn : public KUNLUNKernelWithoutConfig {
         std::vector<int> splitList;
         for (int i = 0; i < num; ++i) {
             auto dim = op->getOutput(i)->getDims();
-            if (dim.size() != 4) {
-                IT_TODO_HALT();
-            }
             splitList.push_back(dim[axis]);
         }
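With the rank-4 restriction removed, splitList is simply the size of each output along the split axis, whatever the input rank. A small sketch of that construction (hypothetical helper and shapes, not part of this commit):

    #include <vector>

    // Collect per-output sizes along the split axis; works for any rank now
    // that the dim.size() != 4 check is gone.
    std::vector<int> buildSplitList(const std::vector<std::vector<int>> &outDims,
                                    int axis) {
        std::vector<int> splitList;
        for (const auto &dim : outDims)
            splitList.push_back(dim[axis]);
        return splitList;
    }

    // Example: a {6, 5} input split into three equal parts along axis 0 gives
    // outputs of shape {2, 5}, so splitList = {2, 2, 2}.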