From 7f5188bedd274ef6e9f46170b4770b0b10919385 Mon Sep 17 00:00:00 2001
From: Haojie Wang
Date: Wed, 25 Oct 2023 14:38:47 +0800
Subject: [PATCH] remove dimension limit of elementwise operators on xpu (#168)

---
 src/kernels/kunlun/element_wise.cc | 43 +-----------------------------
 1 file changed, 1 insertion(+), 42 deletions(-)

diff --git a/src/kernels/kunlun/element_wise.cc b/src/kernels/kunlun/element_wise.cc
index 03ce74b1..f71c11bc 100644
--- a/src/kernels/kunlun/element_wise.cc
+++ b/src/kernels/kunlun/element_wise.cc
@@ -15,8 +15,6 @@ class AddXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_add(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -37,8 +35,6 @@ class SubXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_sub(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -59,8 +55,6 @@ class MulXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_mul(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -81,8 +75,6 @@ class DivXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_div(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -103,8 +95,6 @@ class PowXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_pow(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -125,8 +115,6 @@ class MaxXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_max(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -147,8 +135,6 @@ class MinXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_min(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -171,8 +157,6 @@ class EqualXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_equal(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
@@ -197,8 +181,6 @@ class GreaterEqualXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_greater_equal(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
@@ -223,8 +205,6 @@ class GreaterThanXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_greater_than(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
@@ -249,8 +229,6 @@ class LessEqualXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_less_equal(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
@@ -275,8 +253,6 @@ class LessThanXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_less_than(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
@@ -296,18 +272,12 @@ class FloorDivXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
-        size_t len = op->getOutput()->size();
-        KUNLUNPtr wsData = context->getWorkspace(len);
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_floordiv(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
-            (float *)wsData, aDim, bDim);
-        ret = baidu::xpu::api::cast(
-            context->KUNLUNHandle(), (int *)wsData, (float *)cData, len);
+            (float *)cData, aDim, bDim);
         assert(ret == 0);
         return;
     }
@@ -325,9 +295,6 @@ class MSELossXdnn : public KUNLUNKernelWithoutConfig {
         size_t len = op->getOutput()->size();
 
         auto dim = op->getInputs(0)->getDims();
-        if (dim.size() != 4)
-            IT_TODO_HALT();
-
         auto ret = baidu::xpu::api::mse_loss(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, len);
@@ -350,8 +317,6 @@ class AndXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::logical_and(
             context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
             (bool *)wsData, len);
@@ -376,8 +341,6 @@ class OrXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::logical_or(
             context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
             (bool *)wsData, len);
@@ -402,8 +365,6 @@ class XorXdnn : public KUNLUNKernelWithoutConfig {
 
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::logical_xor(
             context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
             (bool *)wsData, len);
@@ -426,8 +387,6 @@ class NotXdnn : public KUNLUNKernelWithoutConfig {
         KUNLUNPtr wsData = context->getWorkspace(len);
 
         auto aDim = op->getInputs(0)->getDims();
-        if (aDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::logical_not(
             context->KUNLUNHandle(), (bool *)aData, (bool *)wsData, len);
         ret = baidu::xpu::api::cast(
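
Note for reviewers: the xdnn broadcast_* primitives accept shape vectors of arbitrary rank, which is why the hard-coded 4-D checks above can simply be dropped. Below is a minimal sketch of the now-general call pattern. It is illustrative only and not part of the patch: the header path, the Context * handle type, the example shapes, and the helper name broadcastAddAnyRank are assumptions; the shape-vector signature itself mirrors the calls shown in the diff.

    // Hedged sketch (not from the patch): calls broadcast_add with non-4-D
    // shapes, mirroring what AddXdnn now does without the rank check.
    #include <cassert>
    #include <vector>
    #include "xpu/xdnn.h" // assumption: actual header path depends on the Kunlun SDK

    namespace xdnn = baidu::xpu::api;

    // a has shape {2, 3}, b has shape {3}; xdnn broadcasts b across a's rows.
    // The pointers are device buffers already resident on the XPU.
    void broadcastAddAnyRank(xdnn::Context *handle, float *a, float *b, float *c) {
        std::vector<int> aDim = {2, 3};
        std::vector<int> bDim = {3};
        int ret = xdnn::broadcast_add(handle, a, b, c, aDim, bDim);
        assert(ret == 0); // xdnn kernels return 0 on success, as asserted in the kernels above
    }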