forked from jiuyuan/InfiniTensor
add reduce_mean and gather on kunlun (#169)
* add reduce_mean and gather
* fix format
* fix gather
* fix
* fix xpu, add where operation, fix element-wise operation
* fix format

---------

Co-authored-by: wanghailu <wanghailu0717@163.com>
Co-authored-by: wanghailu <wanghailu@qiyuanlab.com>
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
Parent: d3e7543291 · Commit: 1ea450882b
@@ -15,6 +15,12 @@ class AddXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::broadcast_add<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)bData,
            (float *)cData, aDim, bDim);

@@ -35,6 +41,12 @@ class SubXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::broadcast_sub<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)bData,
            (float *)cData, aDim, bDim);

@@ -55,6 +67,12 @@ class MulXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::broadcast_mul<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)bData,
            (float *)cData, aDim, bDim);

@@ -75,6 +93,12 @@ class DivXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::broadcast_div<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)bData,
            (float *)cData, aDim, bDim);

@@ -95,6 +119,13 @@ class PowXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }

        auto ret = baidu::xpu::api::broadcast_pow<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)bData,
            (float *)cData, aDim, bDim);

@@ -115,6 +146,12 @@ class MaxXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::broadcast_max<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)bData,
            (float *)cData, aDim, bDim);

@@ -135,6 +172,12 @@ class MinXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::broadcast_min<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)bData,
            (float *)cData, aDim, bDim);

@@ -157,6 +200,12 @@ class EqualXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::broadcast_equal<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)bData,
            (bool *)wsData, aDim, bDim);

@@ -181,6 +230,12 @@ class GreaterEqualXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::broadcast_greater_equal<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)bData,
            (bool *)wsData, aDim, bDim);

@@ -205,6 +260,12 @@ class GreaterThanXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::broadcast_greater_than<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)bData,
            (bool *)wsData, aDim, bDim);

@@ -229,6 +290,12 @@ class LessEqualXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::broadcast_less_equal<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)bData,
            (bool *)wsData, aDim, bDim);

@@ -253,6 +320,12 @@ class LessThanXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::broadcast_less_than<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)bData,
            (bool *)wsData, aDim, bDim);

@@ -275,6 +348,12 @@ class FloorDivXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::broadcast_floordiv<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)bData,
            (float *)cData, aDim, bDim);

@@ -317,6 +396,12 @@ class AndXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::logical_and<bool>(
            context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
            (bool *)wsData, len);

@@ -341,6 +426,12 @@ class OrXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::logical_or<bool>(
            context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
            (bool *)wsData, len);

@@ -365,6 +456,12 @@ class XorXdnn : public KUNLUNKernelWithoutConfig {
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        if (aDim.size() == 0) {
            aDim.push_back(1);
        }
        if (bDim.size() == 0) {
            bDim.push_back(1);
        }
        auto ret = baidu::xpu::api::logical_xor<bool>(
            context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
            (bool *)wsData, len);
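
Every element-wise hunk above applies the same fix: a rank-0 (scalar) tensor reports an empty dims vector, which the xdnn broadcast_* calls cannot handle, so the kernels promote it to shape {1} before dispatching. Below is a minimal standalone sketch of that promotion, independent of the InfiniTensor and xdnn headers (helper name hypothetical):

    #include <cassert>
    #include <vector>

    // Mirrors the pattern used in the hunks above: a rank-0 tensor is
    // described by an empty shape, so promote it to {1} before handing
    // the shape to a broadcast-style API.
    static std::vector<int> promoteScalarShape(std::vector<int> dims) {
        if (dims.size() == 0) {
            dims.push_back(1); // treat a scalar as a 1-element, 1-D tensor
        }
        return dims;
    }

    int main() {
        std::vector<int> scalarShape = {};     // rank-0 input
        std::vector<int> vectorShape = {4, 3}; // ordinary input
        assert(promoteScalarShape(scalarShape) == (std::vector<int>{1}));
        assert(promoteScalarShape(vectorShape) == (std::vector<int>{4, 3}));
        return 0;
    }
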
@@ -0,0 +1,29 @@
#include "operators/gather.h"
#include "kunlun/kunlun_kernel_without_config.h"
#include "kunlun/kunlun_runtime.h"

namespace infini {
class GatherXdnn : public KUNLUNKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<GatherObj>(_op);
        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);

        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
        void *const cData = (op->getOutput()->getRawDataPtr<void *>());

        auto shape = op->getInputs(0)->getDims();
        auto index = op->getInputs(1)->getDims();
        auto axis = op->getAxis();
        auto ret = baidu::xpu::api::gather<float, int>(
            context->KUNLUNHandle(), (float *)aData, (int *)bData,
            (float *)cData, shape, index.size(), axis);
        assert(ret == 0);
        return;
    }
};

REGISTER_KERNEL(Device::KUNLUN, OpType::Gather, DataType::Float32, GatherXdnn,
                "Gather_xdnn_KUNLUN_Float32");
}; // namespace infini
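
The new GatherXdnn kernel delegates to xdnn's gather along a single axis. As a point of reference for what that call is expected to produce, here is a plain-CPU sketch of gather-along-axis semantics; it assumes row-major layout and int indices, matching the float/int template arguments above (hypothetical helper, not part of the commit):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Reference gather along `axis`: the output replaces dims[axis] with
    // indices.size(). Row-major layout is assumed in this sketch.
    std::vector<float> gatherRef(const std::vector<float> &in,
                                 const std::vector<int> &dims,
                                 const std::vector<int> &indices, int axis) {
        size_t outer = 1, inner = 1;
        for (int d = 0; d < axis; ++d)
            outer *= dims[d];
        for (size_t d = axis + 1; d < dims.size(); ++d)
            inner *= dims[d];
        const size_t axisLen = dims[axis];

        std::vector<float> out(outer * indices.size() * inner);
        for (size_t o = 0; o < outer; ++o)
            for (size_t j = 0; j < indices.size(); ++j)
                for (size_t i = 0; i < inner; ++i)
                    out[(o * indices.size() + j) * inner + i] =
                        in[(o * axisLen + indices[j]) * inner + i];
        return out;
    }

    int main() {
        // 2x3 input, gather columns {2, 0} along axis 1 -> 2x2 output.
        std::vector<float> in = {1, 2, 3, 4, 5, 6};
        auto out = gatherRef(in, {2, 3}, {2, 0}, 1);
        for (float v : out)
            std::cout << v << ' '; // prints: 3 1 6 4
        std::cout << '\n';
    }
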
@@ -13,21 +13,16 @@ class MatmulXdnn : public KUNLUNKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         bool transA = op->getTransA();
         bool transB = op->getTransB();
-        if (op->getInputs(0)->getDims().size() != 2 ||
-            op->getInputs(1)->getDims().size() != 2) {
-            IT_TODO_HALT();
-        }
 
-        auto m = transA ? op->getInputs(0)->getDims()[1]
-                        : op->getInputs(0)->getDims()[0];
-        auto n = transB ? op->getInputs(1)->getDims()[0]
-                        : op->getInputs(1)->getDims()[1];
-        auto k = transA ? op->getInputs(0)->getDims()[0]
-                        : op->getInputs(0)->getDims()[1];
+        auto b = op->getB();
+        auto m = op->getM();
+        auto n = op->getN();
+        auto k = op->getK();
 
-        auto ret = baidu::xpu::api::fc<float, float, float, int>(
-            context->KUNLUNHandle(), (float *)aData, (float *)bData,
-            (float *)cData, m, n, k, transA, transB, nullptr, nullptr, nullptr);
+        auto ret = baidu::xpu::api::fc_batched<float, float, float, float>(
+            context->KUNLUNHandle(), b, transA, transB, m, n, k, 1.0,
+            (float *)aData, m * k, (float *)bData, n * k, 0.0, (float *)cData,
+            m * n, nullptr, nullptr);
         assert(ret == 0);
         return;
     }
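
The replacement switches from a single 2-D fc call to fc_batched, driven by the operator's own b/m/n/k attributes, so stacked matmuls no longer hit the removed rank-2 restriction. A small CPU sketch of the semantics the batched call is expected to compute with alpha = 1 and beta = 0 (C[i] = A[i] x B[i] per batch; transposition flags omitted for brevity; hypothetical helper, not the xdnn API):

    #include <vector>

    // C[batch] = A[batch] (m x k) * B[batch] (k x n), row-major,
    // alpha = 1, beta = 0. Per-batch strides match the kernel above:
    // m*k for A, k*n for B, m*n for C.
    void batchedMatmulRef(const std::vector<float> &A,
                          const std::vector<float> &B, std::vector<float> &C,
                          int b, int m, int n, int k) {
        for (int bi = 0; bi < b; ++bi) {
            const float *a = A.data() + bi * m * k;
            const float *x = B.data() + bi * k * n;
            float *c = C.data() + bi * m * n;
            for (int i = 0; i < m; ++i)
                for (int j = 0; j < n; ++j) {
                    float acc = 0.f;
                    for (int p = 0; p < k; ++p)
                        acc += a[i * k + p] * x[p * n + j];
                    c[i * n + j] = acc;
                }
        }
    }

    int main() {
        // One batch: 2x2 times the 2x2 identity leaves A unchanged.
        std::vector<float> A = {1, 2, 3, 4}, B = {1, 0, 0, 1}, C(4);
        batchedMatmulRef(A, B, C, /*b=*/1, /*m=*/2, /*n=*/2, /*k=*/2);
        return C == A ? 0 : 1;
    }
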
@@ -0,0 +1,30 @@
#include "operators/reduce_mean.h"
#include "kunlun/kunlun_kernel_without_config.h"
#include "kunlun/kunlun_runtime.h"

namespace infini {
class ReduceMeanXdnn : public KUNLUNKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<ReduceMeanObj>(_op);
        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);

        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const cData = (op->getOutput()->getRawDataPtr<void *>());

        auto axes_set = op->getAxes();
        std::vector<int> axes;
        axes.assign(axes_set.begin(), axes_set.end());
        auto shape = op->getInputs(0)->getDims();

        auto ret = baidu::xpu::api::reduce_mean<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)cData, shape,
            axes);
        assert(ret == 0);
        return;
    }
};

REGISTER_KERNEL(Device::KUNLUN, OpType::ReduceMean, DataType::Float32,
                ReduceMeanXdnn, "ReduceMean_xdnn_KUNLUN_Float32");
}; // namespace infini
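
ReduceMeanXdnn converts the operator's axis set into a vector and passes the full input shape plus the reduction axes to xdnn's reduce_mean. For reference, a compact CPU sketch of the same semantics for the common single-axis, 2-D case (hypothetical helper; keepdims behaviour is not addressed here):

    #include <vector>

    // Mean over one axis of a 2-D row-major tensor (rows x cols).
    // axis == 0 averages down each column, axis == 1 across each row.
    std::vector<float> reduceMean2D(const std::vector<float> &in, int rows,
                                    int cols, int axis) {
        const int outLen = (axis == 0) ? cols : rows;
        const int reduced = (axis == 0) ? rows : cols;
        std::vector<float> out(outLen, 0.f);
        for (int r = 0; r < rows; ++r)
            for (int c = 0; c < cols; ++c)
                out[axis == 0 ? c : r] += in[r * cols + c];
        for (float &v : out)
            v /= reduced; // divide the accumulated sum by the reduced extent
        return out;
    }

    int main() {
        // 2x3 input; mean over axis 1 gives {2, 5}.
        std::vector<float> in = {1, 2, 3, 4, 5, 6};
        auto out = reduceMean2D(in, 2, 3, /*axis=*/1);
        return (out[0] == 2.f && out[1] == 5.f) ? 0 : 1;
    }
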
@@ -0,0 +1,32 @@
#include "kunlun/kunlun_kernel_without_config.h"
#include "kunlun/kunlun_runtime.h"
#include "operators/where.h"

namespace infini {
class WhereXdnn : public KUNLUNKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<WhereObj>(_op);
        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);

        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
        void *const cData = (op->getInputs(2)->getRawDataPtr<void *>());
        void *const dData = (op->getOutput()->getRawDataPtr<void *>());

        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        auto cDim = op->getInputs(2)->getDims();
        auto dDim = op->getOutput()->getDims();

        auto ret = baidu::xpu::api::select<float>(
            context->KUNLUNHandle(), (bool *)cData, (float *)aData,
            (float *)bData, (float *)dData, cDim, aDim);
        assert(ret == 0);
        return;
    }
};

REGISTER_KERNEL(Device::KUNLUN, OpType::Where, DataType::Float32, WhereXdnn,
                "Where_xdnn_KUNLUN_Float32");
}; // namespace infini
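
WhereXdnn maps the Where operator onto xdnn's select: the condition tensor (input 2 here) picks between the two value tensors element by element. A same-shape CPU sketch of that behaviour (broadcasting between the condition and value shapes, which the kernel passes as cDim/aDim, is left to the library; helper name hypothetical):

    #include <vector>

    // out[i] = cond[i] ? x[i] : y[i]; all buffers share one length here.
    std::vector<float> whereRef(const std::vector<bool> &cond,
                                const std::vector<float> &x,
                                const std::vector<float> &y) {
        std::vector<float> out(cond.size());
        for (size_t i = 0; i < cond.size(); ++i)
            out[i] = cond[i] ? x[i] : y[i];
        return out;
    }

    int main() {
        auto out = whereRef({true, false, true}, {1, 2, 3}, {10, 20, 30});
        // out is {1, 20, 3}
        return (out == std::vector<float>({1, 20, 3})) ? 0 : 1;
    }
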
@@ -22,9 +22,6 @@ class SplitXdnn : public KUNLUNKernelWithoutConfig {
         std::vector<int> splitList;
         for (int i = 0; i < num; ++i) {
             auto dim = op->getOutput(i)->getDims();
-            if (dim.size() != 4) {
-                IT_TODO_HALT();
-            }
             splitList.push_back(dim[axis]);
         }
 
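
The split hunk drops the rank-4 restriction, so the section sizes are now collected from each output's dim[axis] regardless of rank. A short sketch of that bookkeeping, including a sanity check that the sections cover the input along the split axis (hypothetical helper):

    #include <cassert>
    #include <numeric>
    #include <vector>

    // Collect the split sections along `axis` from the output shapes and
    // verify they add up to the input's extent on that axis.
    std::vector<int> splitSections(const std::vector<std::vector<int>> &outDims,
                                   const std::vector<int> &inDims, int axis) {
        std::vector<int> sections;
        for (const auto &dim : outDims)
            sections.push_back(dim[axis]); // any rank, not just 4-D
        assert(std::accumulate(sections.begin(), sections.end(), 0) ==
               inDims[axis]);
        return sections;
    }

    int main() {
        // Split a {2, 10, 5} tensor into 10 = 3 + 7 along axis 1.
        auto s = splitSections({{2, 3, 5}, {2, 7, 5}}, {2, 10, 5}, 1);
        return (s == std::vector<int>({3, 7})) ? 0 : 1;
    }
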