diff --git a/include/utils/operator_utils.h b/include/utils/operator_utils.h index 1b3a1eb2..b0871c0b 100644 --- a/include/utils/operator_utils.h +++ b/include/utils/operator_utils.h @@ -13,6 +13,11 @@ Shape infer_broadcast(const Shape &A, const Shape &B); int get_real_axis(const int &axis, const int &rank); // Check if tensor B is unidirectional broadcastable to tensor A bool is_unidirectional_broadcasting(const Shape &A, const Shape &B); +// Locate the index with size from Shape +Shape locate_index(size_t inputN, const Shape &shape); +// Delocate the ShapeIndex from Shape with broadcast +size_t delocate_index(const Shape &shapeIndex, const Shape &shape, + const Shape &stride); // Convert KernelAttrs to a string representation std::string get_kernel_attrs_str(const KernelAttrs &kernelAttrs); } // namespace infini diff --git a/src/kernels/cpu/element_wise.cc b/src/kernels/cpu/element_wise.cc index 8d225779..ff03350c 100644 --- a/src/kernels/cpu/element_wise.cc +++ b/src/kernels/cpu/element_wise.cc @@ -1,5 +1,6 @@ #include "operators/element_wise.h" #include "core/kernel.h" +#include "utils/operator_utils.h" namespace infini { template class NativeElementWise : public CpuKernelWithoutConfig { @@ -11,37 +12,34 @@ template class NativeElementWise : public CpuKernelWithoutConfig { T *inptr1 = op->getInputs(1)->getRawDataPtr(); T *outptr = op->getOutput()->getRawDataPtr(); - int a[4] = {1, 1, 1, 1}; - int b[4] = {1, 1, 1, 1}; - int c[4] = {1, 1, 1, 1}; - auto a_input = op->getInputs(0)->getDims(); - auto b_input = op->getInputs(1)->getDims(); - auto c_output = op->getOutput()->getDims(); - std::copy(a_input.begin(), a_input.end(), a + (4 - a_input.size())); - std::copy(b_input.begin(), b_input.end(), b + (4 - b_input.size())); - std::copy(c_output.begin(), c_output.end(), c + (4 - c_output.size())); + auto shapeA = op->getInputs(0)->getDims(); + auto shapeB = op->getInputs(1)->getDims(); + auto shapeC = op->getOutput()->getDims(); + auto rank = op->getOutput()->getRank(); + Shape a(rank, 1); + Shape b(rank, 1); + std::copy(shapeA.begin(), shapeA.end(), + a.begin() + (rank - shapeA.size())); + std::copy(shapeB.begin(), shapeB.end(), + b.begin() + (rank - shapeB.size())); + auto getStride = [&](const Shape &shape) { + int p = 1; + Shape stride(rank); + for (auto i = rank; i > 0; --i) { + stride[i - 1] = p; + p = p * shape[i - 1]; + } + return stride; + }; + Shape strideA = getStride(a); + Shape strideB = getStride(b); auto n = op->getOutput()->size(); for (size_t i = 0; i < n; ++i) { - int c0_index = i / (c[1] * c[2] * c[3]); - int c1_index = (i % (c[1] * c[2] * c[3])) / (c[2] * c[3]); - int c2_index = ((i % (c[1] * c[2] * c[3])) % (c[2] * c[3])) / c[3]; - int c3_index = ((i % (c[1] * c[2] * c[3])) % (c[2] * c[3])) % c[3]; - - int a0_index = c0_index % a[0]; - int a1_index = c1_index % a[1]; - int a2_index = c2_index % a[2]; - int a3_index = c3_index % a[3]; - - int b0_index = c0_index % b[0]; - int b1_index = c1_index % b[1]; - int b2_index = c2_index % b[2]; - int b3_index = c3_index % b[3]; - outptr[i] = doCompute( - inptr0[a0_index * a[1] * a[2] * a[3] + a1_index * a[2] * a[3] + - a2_index * a[3] + a3_index], - inptr1[b0_index * b[1] * b[2] * b[3] + b1_index * b[2] * b[3] + - b2_index * b[3] + b3_index]); + auto shapeIndexC = locate_index(i, shapeC); + auto indexA = delocate_index(shapeIndexC, a, strideA); + auto indexB = delocate_index(shapeIndexC, b, strideB); + outptr[i] = doCompute(inptr0[indexA], inptr1[indexB]); } } }; diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc index 76a1d91f..6687a8fd 100644 --- a/src/utils/operator_utils.cc +++ b/src/utils/operator_utils.cc @@ -66,6 +66,31 @@ bool is_unidirectional_broadcasting(const Shape &A, const Shape &B) { return true; } +Shape locate_index(size_t inputN, const Shape &shape) { + Shape ans(shape.size()); + auto i = ans.rbegin(); + auto j = shape.rbegin(), ej = shape.rend(); + while (j != ej) { + auto div = std::div(inputN, *j++); + *i++ = div.rem; + inputN = div.quot; + } + return ans; +} + +size_t delocate_index(const Shape &shapeIndex, const Shape &shape, + const Shape &stride) { + size_t ans = 0; + Shape index(shapeIndex.size()); + IT_ASSERT(shapeIndex.size() == shape.size()); + IT_ASSERT(shape.size() == stride.size()); + for (size_t i = 0; i < shape.size(); ++i) { + index[i] = shapeIndex[i] % shape[i]; + ans += index[i] * stride[i]; + } + return ans; +} + std::string device_to_str(Device device) { std::string deviceStr; switch (device) { diff --git a/test/kernels/nativecpu/test_nativecpu_elementwise.cc b/test/kernels/nativecpu/test_nativecpu_elementwise.cc new file mode 100644 index 00000000..c6ef1911 --- /dev/null +++ b/test/kernels/nativecpu/test_nativecpu_elementwise.cc @@ -0,0 +1,44 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/element_wise.h" + +#include "test.h" + +namespace infini { + +using ExpectOutput = vector; +template +void testElementWiseNativeCpu( + const std::function &generator1, + const std::function &generator2, + const Shape &shape1, const Shape &shape2, const ExpectOutput &ansVec) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + auto t1 = g->addTensor(shape1, DataType::Float32); + auto t2 = g->addTensor(shape2, DataType::Float32); + + auto op = g->addOp(t1, t2, nullptr); + g->dataMalloc(); + t1->setData(generator1); + t2->setData(generator2); + + runtime->run(g); + EXPECT_TRUE(op->getOutput()->equalData(ansVec)); +} + +TEST(ElementWise, NativeCpu) { + testElementWiseNativeCpu( + IncrementalGenerator(), IncrementalGenerator(), Shape{1, 2, 2, 3, 1}, + Shape{2, 1, 1}, ExpectOutput{0, 1, 2, 4, 5, 6, 6, 7, 8, 10, 11, 12}); + testElementWiseNativeCpu( + IncrementalGenerator(), IncrementalGenerator(), Shape{1, 2, 2, 3, 1}, + Shape{2, 1, 1}, ExpectOutput{0, 0, 0, 3, 4, 5, 0, 0, 0, 9, 10, 11}); + testElementWiseNativeCpu( + IncrementalGenerator(), IncrementalGenerator(), Shape{1, 2, 2, 3, 1}, + Shape{2, 1, 1}, ExpectOutput{0, 1, 2, 2, 3, 4, 6, 7, 8, 8, 9, 10}); + testElementWiseNativeCpu( + IncrementalGenerator(), OneGenerator(), Shape{1, 2, 2, 3, 1}, + Shape{2, 1, 1}, ExpectOutput{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); +} + +} // namespace infini