forked from jiuyuan/InfiniTensor
remove dimension limit of elementwise operators on xpu (#168)
This commit is contained in:
parent
07ef587c65
commit
7f5188bedd
|
@ -15,8 +15,6 @@ class AddXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::broadcast_add<float>(
|
auto ret = baidu::xpu::api::broadcast_add<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(float *)cData, aDim, bDim);
|
(float *)cData, aDim, bDim);
|
||||||
|
@ -37,8 +35,6 @@ class SubXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::broadcast_sub<float>(
|
auto ret = baidu::xpu::api::broadcast_sub<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(float *)cData, aDim, bDim);
|
(float *)cData, aDim, bDim);
|
||||||
|
@ -59,8 +55,6 @@ class MulXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::broadcast_mul<float>(
|
auto ret = baidu::xpu::api::broadcast_mul<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(float *)cData, aDim, bDim);
|
(float *)cData, aDim, bDim);
|
||||||
|
@ -81,8 +75,6 @@ class DivXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::broadcast_div<float>(
|
auto ret = baidu::xpu::api::broadcast_div<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(float *)cData, aDim, bDim);
|
(float *)cData, aDim, bDim);
|
||||||
|
@ -103,8 +95,6 @@ class PowXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::broadcast_pow<float>(
|
auto ret = baidu::xpu::api::broadcast_pow<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(float *)cData, aDim, bDim);
|
(float *)cData, aDim, bDim);
|
||||||
|
@ -125,8 +115,6 @@ class MaxXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::broadcast_max<float>(
|
auto ret = baidu::xpu::api::broadcast_max<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(float *)cData, aDim, bDim);
|
(float *)cData, aDim, bDim);
|
||||||
|
@ -147,8 +135,6 @@ class MinXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::broadcast_min<float>(
|
auto ret = baidu::xpu::api::broadcast_min<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(float *)cData, aDim, bDim);
|
(float *)cData, aDim, bDim);
|
||||||
|
@ -171,8 +157,6 @@ class EqualXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::broadcast_equal<float>(
|
auto ret = baidu::xpu::api::broadcast_equal<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(bool *)wsData, aDim, bDim);
|
(bool *)wsData, aDim, bDim);
|
||||||
|
@ -197,8 +181,6 @@ class GreaterEqualXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::broadcast_greater_equal<float>(
|
auto ret = baidu::xpu::api::broadcast_greater_equal<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(bool *)wsData, aDim, bDim);
|
(bool *)wsData, aDim, bDim);
|
||||||
|
@ -223,8 +205,6 @@ class GreaterThanXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::broadcast_greater_than<float>(
|
auto ret = baidu::xpu::api::broadcast_greater_than<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(bool *)wsData, aDim, bDim);
|
(bool *)wsData, aDim, bDim);
|
||||||
|
@ -249,8 +229,6 @@ class LessEqualXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::broadcast_less_equal<float>(
|
auto ret = baidu::xpu::api::broadcast_less_equal<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(bool *)wsData, aDim, bDim);
|
(bool *)wsData, aDim, bDim);
|
||||||
|
@ -275,8 +253,6 @@ class LessThanXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::broadcast_less_than<float>(
|
auto ret = baidu::xpu::api::broadcast_less_than<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(bool *)wsData, aDim, bDim);
|
(bool *)wsData, aDim, bDim);
|
||||||
|
@ -296,18 +272,12 @@ class FloorDivXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
size_t len = op->getOutput()->size();
|
|
||||||
KUNLUNPtr wsData = context->getWorkspace(len);
|
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::broadcast_floordiv<float>(
|
auto ret = baidu::xpu::api::broadcast_floordiv<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(float *)wsData, aDim, bDim);
|
(float *)cData, aDim, bDim);
|
||||||
ret = baidu::xpu::api::cast<int, float>(
|
|
||||||
context->KUNLUNHandle(), (int *)wsData, (float *)cData, len);
|
|
||||||
assert(ret == 0);
|
assert(ret == 0);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -325,9 +295,6 @@ class MSELossXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
size_t len = op->getOutput()->size();
|
size_t len = op->getOutput()->size();
|
||||||
|
|
||||||
auto dim = op->getInputs(0)->getDims();
|
auto dim = op->getInputs(0)->getDims();
|
||||||
if (dim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
|
|
||||||
auto ret = baidu::xpu::api::mse_loss<float>(
|
auto ret = baidu::xpu::api::mse_loss<float>(
|
||||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||||
(float *)cData, len);
|
(float *)cData, len);
|
||||||
|
@ -350,8 +317,6 @@ class AndXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::logical_and<bool>(
|
auto ret = baidu::xpu::api::logical_and<bool>(
|
||||||
context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
|
context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
|
||||||
(bool *)wsData, len);
|
(bool *)wsData, len);
|
||||||
|
@ -376,8 +341,6 @@ class OrXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::logical_or<bool>(
|
auto ret = baidu::xpu::api::logical_or<bool>(
|
||||||
context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
|
context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
|
||||||
(bool *)wsData, len);
|
(bool *)wsData, len);
|
||||||
|
@ -402,8 +365,6 @@ class XorXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
auto bDim = op->getInputs(1)->getDims();
|
auto bDim = op->getInputs(1)->getDims();
|
||||||
if (aDim.size() != 4 || bDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::logical_xor<bool>(
|
auto ret = baidu::xpu::api::logical_xor<bool>(
|
||||||
context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
|
context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
|
||||||
(bool *)wsData, len);
|
(bool *)wsData, len);
|
||||||
|
@ -426,8 +387,6 @@ class NotXdnn : public KUNLUNKernelWithoutConfig {
|
||||||
KUNLUNPtr wsData = context->getWorkspace(len);
|
KUNLUNPtr wsData = context->getWorkspace(len);
|
||||||
|
|
||||||
auto aDim = op->getInputs(0)->getDims();
|
auto aDim = op->getInputs(0)->getDims();
|
||||||
if (aDim.size() != 4)
|
|
||||||
IT_TODO_HALT();
|
|
||||||
auto ret = baidu::xpu::api::logical_not<bool>(
|
auto ret = baidu::xpu::api::logical_not<bool>(
|
||||||
context->KUNLUNHandle(), (bool *)aData, (bool *)wsData, len);
|
context->KUNLUNHandle(), (bool *)aData, (bool *)wsData, len);
|
||||||
ret = baidu::xpu::api::cast<bool, float>(
|
ret = baidu::xpu::api::cast<bool, float>(
|
||||||
|
|
Loading…
Reference in New Issue