forked from jiuyuan/InfiniTensor
remove dimension limit of elementwise operators on xpu (#168)
This commit is contained in:
parent
07ef587c65
commit
7f5188bedd
|
@ -15,8 +15,6 @@ class AddXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_add<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim);
|
||||
|
@ -37,8 +35,6 @@ class SubXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_sub<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim);
|
||||
|
@ -59,8 +55,6 @@ class MulXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_mul<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim);
|
||||
|
@ -81,8 +75,6 @@ class DivXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_div<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim);
|
||||
|
@ -103,8 +95,6 @@ class PowXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_pow<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim);
|
||||
|
@ -125,8 +115,6 @@ class MaxXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_max<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim);
|
||||
|
@ -147,8 +135,6 @@ class MinXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_min<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim);
|
||||
|
@ -171,8 +157,6 @@ class EqualXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_equal<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(bool *)wsData, aDim, bDim);
|
||||
|
@ -197,8 +181,6 @@ class GreaterEqualXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_greater_equal<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(bool *)wsData, aDim, bDim);
|
||||
|
@ -223,8 +205,6 @@ class GreaterThanXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_greater_than<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(bool *)wsData, aDim, bDim);
|
||||
|
@ -249,8 +229,6 @@ class LessEqualXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_less_equal<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(bool *)wsData, aDim, bDim);
|
||||
|
@ -275,8 +253,6 @@ class LessThanXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_less_than<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(bool *)wsData, aDim, bDim);
|
||||
|
@ -296,18 +272,12 @@ class FloorDivXdnn : public KUNLUNKernelWithoutConfig {
|
|||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
size_t len = op->getOutput()->size();
|
||||
KUNLUNPtr wsData = context->getWorkspace(len);
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_floordiv<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)wsData, aDim, bDim);
|
||||
ret = baidu::xpu::api::cast<int, float>(
|
||||
context->KUNLUNHandle(), (int *)wsData, (float *)cData, len);
|
||||
(float *)cData, aDim, bDim);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
}
|
||||
|
@ -325,9 +295,6 @@ class MSELossXdnn : public KUNLUNKernelWithoutConfig {
|
|||
size_t len = op->getOutput()->size();
|
||||
|
||||
auto dim = op->getInputs(0)->getDims();
|
||||
if (dim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
auto ret = baidu::xpu::api::mse_loss<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, len);
|
||||
|
@ -350,8 +317,6 @@ class AndXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::logical_and<bool>(
|
||||
context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
|
||||
(bool *)wsData, len);
|
||||
|
@ -376,8 +341,6 @@ class OrXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::logical_or<bool>(
|
||||
context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
|
||||
(bool *)wsData, len);
|
||||
|
@ -402,8 +365,6 @@ class XorXdnn : public KUNLUNKernelWithoutConfig {
|
|||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::logical_xor<bool>(
|
||||
context->KUNLUNHandle(), (bool *)aData, (bool *)bData,
|
||||
(bool *)wsData, len);
|
||||
|
@ -426,8 +387,6 @@ class NotXdnn : public KUNLUNKernelWithoutConfig {
|
|||
KUNLUNPtr wsData = context->getWorkspace(len);
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
if (aDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::logical_not<bool>(
|
||||
context->KUNLUNHandle(), (bool *)aData, (bool *)wsData, len);
|
||||
ret = baidu::xpu::api::cast<bool, float>(
|
||||
|
|
Loading…
Reference in New Issue