remove dimension limit of elementwise operators on xpu (#168)

Haojie Wang 2023-10-25 14:38:47 +08:00 committed by GitHub
parent 07ef587c65
commit 7f5188bedd
1 changed file with 1 addition and 42 deletions

@@ -15,8 +15,6 @@ class AddXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_add<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
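The same two-line deletion repeats for every elementwise kernel in this file. For orientation, here is a minimal sketch of what a binary kernel's compute body looks like after this change, reconstructed from the context lines above; the method signature, the ElementWiseObj cast, and the KUNLUNRuntimeObj name are assumptions drawn from the surrounding codebase, not part of this diff:

    // Hedged sketch of AddXdnn::compute after the rank check is removed
    // (scaffolding names outside the diff are assumed, not shown here).
    void compute(const Operator &_op, const RuntimeObj *_context) const override {
        auto op = as<ElementWiseObj>(_op);
        auto context = dynamic_cast<const KUNLUNRuntimeObj *>(_context);
        void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
        auto aDim = op->getInputs(0)->getDims();
        auto bDim = op->getInputs(1)->getDims();
        // xdnn's broadcast_* primitives take the full shape vectors, so the
        // kernel no longer needs to restrict inputs to 4-D tensors.
        auto ret = baidu::xpu::api::broadcast_add<float>(
            context->KUNLUNHandle(), (float *)aData, (float *)bData,
            (float *)cData, aDim, bDim);
        assert(ret == 0);
    }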
@@ -37,8 +35,6 @@ class SubXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_sub<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -59,8 +55,6 @@ class MulXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_mul<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -81,8 +75,6 @@ class DivXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_div<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -103,8 +95,6 @@ class PowXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_pow<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -125,8 +115,6 @@ class MaxXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_max<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -147,8 +135,6 @@ class MinXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_min<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, aDim, bDim);
@@ -171,8 +157,6 @@ class EqualXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_equal<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
@@ -197,8 +181,6 @@ class GreaterEqualXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_greater_equal<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
@@ -223,8 +205,6 @@ class GreaterThanXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_greater_than<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
@@ -249,8 +229,6 @@ class LessEqualXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_less_equal<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
@@ -275,8 +253,6 @@ class LessThanXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_less_than<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (bool *)wsData, aDim, bDim);
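The six comparison kernels above follow the same pattern but write their bool results into a scratch workspace rather than the output tensor; going by the workspace-plus-cast pattern visible in the NotXdnn hunk further below, the tail of each kernel presumably converts that buffer to the float output. A hedged sketch of that tail for EqualXdnn (the cast step here is inferred, not shown in this diff):

    // Hedged sketch: the comparison result lands in a bool workspace, then
    // is cast to the float output tensor (pattern inferred from NotXdnn).
    size_t len = op->getOutput()->size();
    KUNLUNPtr wsData = context->getWorkspace(len);
    auto ret = baidu::xpu::api::broadcast_equal<float>(
        context->KUNLUNHandle(), (float *)aData, (float *)bData,
        (bool *)wsData, aDim, bDim);
    ret = baidu::xpu::api::cast<bool, float>(
        context->KUNLUNHandle(), (bool *)wsData, (float *)cData, len);
    assert(ret == 0);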
@@ -296,18 +272,12 @@ class FloorDivXdnn : public KUNLUNKernelWithoutConfig {
         void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
-        size_t len = op->getOutput()->size();
-        KUNLUNPtr wsData = context->getWorkspace(len);
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::broadcast_floordiv<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
-            (float *)wsData, aDim, bDim);
-        ret = baidu::xpu::api::cast<int, float>(
-            context->KUNLUNHandle(), (int *)wsData, (float *)cData, len);
+            (float *)cData, aDim, bDim);
         assert(ret == 0);
         return;
     }
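Unlike the other hunks, this one is not just the rank-check deletion: broadcast_floordiv used to write into a workspace that was then cast through int to float, and now writes straight into the output buffer, dropping both the workspace allocation and the cast. Side by side, using the names from the diff:

    // Before: floordiv into a workspace, then cast<int, float> into cData.
    auto ret = baidu::xpu::api::broadcast_floordiv<float>(
        context->KUNLUNHandle(), (float *)aData, (float *)bData,
        (float *)wsData, aDim, bDim);
    ret = baidu::xpu::api::cast<int, float>(
        context->KUNLUNHandle(), (int *)wsData, (float *)cData, len);

    // After: the result goes directly to the output tensor.
    auto ret = baidu::xpu::api::broadcast_floordiv<float>(
        context->KUNLUNHandle(), (float *)aData, (float *)bData,
        (float *)cData, aDim, bDim);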
@@ -325,9 +295,6 @@ class MSELossXdnn : public KUNLUNKernelWithoutConfig {
         size_t len = op->getOutput()->size();
         auto dim = op->getInputs(0)->getDims();
-        if (dim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::mse_loss<float>(
             context->KUNLUNHandle(), (float *)aData, (float *)bData,
             (float *)cData, len);
@@ -350,8 +317,6 @@ class AndXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::logical_and<bool>(
             context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
             (bool *)wsData, len);
@@ -376,8 +341,6 @@ class OrXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::logical_or<bool>(
             context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
             (bool *)wsData, len);
@@ -402,8 +365,6 @@ class XorXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bDim = op->getInputs(1)->getDims();
-        if (aDim.size() != 4 || bDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::logical_xor<bool>(
             context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
             (bool *)wsData, len);
@@ -426,8 +387,6 @@ class NotXdnn : public KUNLUNKernelWithoutConfig {
         KUNLUNPtr wsData = context->getWorkspace(len);
         auto aDim = op->getInputs(0)->getDims();
-        if (aDim.size() != 4)
-            IT_TODO_HALT();
         auto ret = baidu::xpu::api::logical_not<bool>(
             context->KUNLUNHandle(), (bool *)aData, (bool *)wsData, len);
         ret = baidu::xpu::api::cast<bool, float>(