From 6ece3f4a777483cb11339f8892b08415504b337c Mon Sep 17 00:00:00 2001
From: PanZezhong1725 <141193946+PanZezhong1725@users.noreply.github.com>
Date: Fri, 24 Nov 2023 09:29:58 +0800
Subject: [PATCH] Add ReduceSum op and kernel (#160)

* Add reduceSum op and kernel

* fix merge and format

* Reduce: reuse cat macro, add doc string

---------

Co-authored-by: Haojie Wang
---
 include/core/graph_handler.h                  |  2 +
 include/operators/{reduce_mean.h => reduce.h} | 26 ++++--
 pyinfinitensor/src/pyinfinitensor/onnx.py     | 49 +++++++++--
 pyinfinitensor/tests/test_onnx.py             |  8 ++
 src/core/graph_handler.cc                     | 29 ++++---
 src/ffi/ffi_infinitensor.cc                   | 19 +++--
 .../cuda/{reduce_mean.cc => reduce.cc}        | 24 +++++-
 src/operators/{reduce_mean.cc => reduce.cc}   | 31 ++++---
 test/core/test_graph_replace.cc               |  2 +-
 test/kernels/cuda/test_cuda_reduce.cc         | 83 +++++++++++++++++++
 test/kernels/cuda/test_cuda_reduce_mean.cc    | 61 --------------
 .../{test_reduce_mean.cc => test_reduce.cc}   | 22 +++--
 12 files changed, 235 insertions(+), 121 deletions(-)
 rename include/operators/{reduce_mean.h => reduce.h} (59%)
 rename src/kernels/cuda/{reduce_mean.cc => reduce.cc} (87%)
 rename src/operators/{reduce_mean.cc => reduce.cc} (67%)
 create mode 100644 test/kernels/cuda/test_cuda_reduce.cc
 delete mode 100644 test/kernels/cuda/test_cuda_reduce_mean.cc
 rename test/operators/{test_reduce_mean.cc => test_reduce.cc} (68%)

diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 4b66f11a..8c4f59bc 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -73,6 +73,8 @@ class GraphHandlerObj {
     Tensor gatherElements(Tensor data, Tensor indices, Tensor output,
                           int axis);
     Tensor reduceMean(Tensor data, Tensor reduced,
                       const optional<vector<int>> &axes, bool keepdims);
+    Tensor reduceSum(Tensor data, Tensor reduced,
+                     const optional<vector<int>> &axes, bool keepdims);
     Tensor slice(Tensor input, Tensor output, const vector<int> &starts,
                  const vector<int> &ends, const optional<vector<int>> &axes,
                  const optional<vector<int>> &steps);
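For reference, the contract shared by reduceMean and the new reduceSum is the usual ONNX reduction: an optional axes list (absent meaning "reduce over every axis") plus a keepdims flag. A minimal NumPy sketch of that contract, illustrative only and not part of the patch (reduce_ref is a hypothetical name):

import numpy as np

def reduce_ref(data, axes=None, keepdims=True, op="sum"):
    # axes=None reduces over every dimension, mirroring the optional `axes`
    # parameter above; keepdims keeps reduced dimensions as size 1.
    fn = np.sum if op == "sum" else np.mean
    axis = None if axes is None else tuple(axes)
    return fn(data, axis=axis, keepdims=keepdims)

x = np.arange(24, dtype=np.float32).reshape(2, 3, 2, 2)
print(reduce_ref(x, axes=[1, 2], keepdims=False, op="sum").shape)  # (2, 2)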
diff --git a/include/operators/reduce_mean.h b/include/operators/reduce.h
similarity index 59%
rename from include/operators/reduce_mean.h
rename to include/operators/reduce.h
index 18ef38b1..defcf9b3 100644
--- a/include/operators/reduce_mean.h
+++ b/include/operators/reduce.h
@@ -3,26 +3,29 @@
 namespace infini {
 /**
- * @brief Compute the mean of input tensor's elements along certain axes.
+ * @brief Compute the reduction of input tensor's elements along certain axes.
  *
  */
-class ReduceMeanObj : public OperatorObj {
+class ReduceBaseObj : public OperatorObj {
+  protected:
     set<int> axes; // axis to reduce
     bool keepDims;

   public:
     /**
-     * @brief Construct a new ReduceMean object.
+     * @brief Construct a new Reduce object.
      *
      * @param graph The computation graph that this operator belongs to.
+     * @param opType The operation type. Should be a Reduce operation.
      * @param input The input tensor.
      * @param output The output tensor.
      * @param axes Axes to reduce.
      * @param keepDims Keep the reduced dimensions or not.
      */
-    ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
-                  const optional<vector<int>> &axes, bool keepDims = true);
-    OP_CLONE(ReduceMeanObj);
+    ReduceBaseObj(GraphObj *graph, OpType opType, Tensor input, Tensor output,
+                  const optional<vector<int>> &axes, bool keepDims);
+    virtual ~ReduceBaseObj() {}
+    OP_CLONE(ReduceBaseObj);

     optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
     std::string toString() const override;
@@ -38,4 +41,15 @@ class ReduceMeanObj : public OperatorObj {
     vector<int> getOpAttrVector() const override;
 };

+class ReduceMeanObj : public ReduceBaseObj {
+  public:
+    ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
+                  const optional<vector<int>> &axes, bool keepDims = true);
+};
+
+class ReduceSumObj : public ReduceBaseObj {
+  public:
+    ReduceSumObj(GraphObj *graph, Tensor input, Tensor output,
+                 const optional<vector<int>> &axes, bool keepDims = true);
+};
 } // namespace infini
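The shape rules implied by inferShape can be stated compactly. A minimal sketch, assuming negative axes wrap around as in ONNX and that reducing every axis without keepDims yields shape [1], which is what the unit tests later in this patch expect:

def infer_reduce_shape(dims, axes, keep_dims):
    # Illustrative Python mirror of ReduceBaseObj::inferShape, not patch code.
    rank = len(dims)
    reduced = set(range(rank)) if axes is None else {a % rank for a in axes}
    if keep_dims:
        return [1 if i in reduced else d for i, d in enumerate(dims)]
    out = [d for i, d in enumerate(dims) if i not in reduced]
    return out if out else [1]

print(infer_reduce_shape([2, 3, 3, 4], [-3, 3], True))   # [2, 1, 3, 1]
print(infer_reduce_shape([2, 3, 3, 4], [1, 3], False))   # [2, 3]
print(infer_reduce_shape([2, 3, 3, 4], None, False))     # [1]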
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index d48ef52a..ad842d5b 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -604,7 +604,7 @@ class OnnxStub:
                     ),
                 )
             elif node.op_type == "ReduceMean":
-                tensors[node.output[0]] = self.handler.reduce_mean(
+                tensors[node.output[0]] = self.handler.reduceMean(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     # NOTE(constroy): `axes` is an attribute until opset version 13.
@@ -678,12 +678,40 @@ class OnnxStub:
                     next((attr.i for attr in node.attribute if attr.name == "to")),
                 )
             elif node.op_type == "ReduceSum":
-                # ReduceSum is only implemented as allReduceSum.
-                assert any(attr.name == "communicator" for attr in node.attribute)
-                tensors[node.output[0]] = self.handler.allReduceSum(
-                    tensors[node.input[0]],
-                    tensors.get(node.output[0]),
-                )
+                if any(attr.name == "communicator" for attr in node.attribute):
+                    # ReduceSum with communicator is treated as allReduceSum.
+                    tensors[node.output[0]] = self.handler.allReduceSum(
+                        tensors[node.input[0]],
+                        tensors.get(node.output[0]),
+                    )
+                else:
+                    # NOTE: `axes` is an attribute until opset version 13.
+                    if len(node.input) > 1:
+                        axis = _parse_data(data[node.input[1]])
+                    else:
+                        axis = next(
+                            (
+                                attr.ints
+                                for attr in node.attribute
+                                if attr.name == "axes"
+                            ),
+                            None,
+                        )
+                    keepdims = next(
+                        (
+                            attr.i
+                            for attr in node.attribute
+                            if attr.name == "keepdims"
+                        ),
+                        1,
+                    ) != 0
+
+                    tensors[node.output[0]] = self.handler.reduceSum(
+                        tensors[node.input[0]],
+                        tensors.get(node.output[0]),
+                        axis,
+                        keepdims,
+                    )
             elif node.op_type == "AllReduceSum":
                 tensors[node.output[0]] = self.handler.allReduceSum(
                     tensors[node.input[0]],
@@ -1044,8 +1072,11 @@ class OnnxStub:
             elif ty == backend.OpTypeId.Gather:
                 axis = backend.gather_axis_of(op)
                 ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
-            elif ty == backend.OpTypeId.ReduceMean:
-                axes, keepdims = backend.reduce_mean_attrs_of(op)
+            elif ty in [
+                backend.OpTypeId.ReduceMean,
+                backend.OpTypeId.ReduceSum,
+            ]:
+                axes, keepdims = backend.reduce_attrs_of(op)
                 inputs.append(
                     ctx.push_data_input(
                         name, "axes", TensorProto.INT64, [len(axes)], axes

diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 79df0294..8e1587b9 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -337,6 +337,14 @@ class TestStringMethods(unittest.TestCase):
             "ReduceMean", ["data"], ["reduced"], keepdims=1, name="reduceMean"
         )
         make_and_import_model(make_graph([reduceMean], "reduceMean", [data], [reduced]))
+
+    def test_reduce_sum(self):
+        data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
+        reduced = make_tensor_value_info("reduced", TensorProto.FLOAT, [1, 1, 1, 1])
+        reduceSum = make_node(
+            "ReduceSum", ["data"], ["reduced"], keepdims=1, name="reduceSum"
+        )
+        make_and_import_model(make_graph([reduceSum], "reduceSum", [data], [reduced]))

     def test_slice(self):
         data = make_tensor_value_info("data", TensorProto.UINT32, [10, 64, 162, 162])
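The new importer branch also accepts the opset-13 form, where axes arrives as a second input rather than an attribute; the test above only covers the attribute-free default. A hedged sketch of the opset-13 form (axes must be an initializer so the importer can read it through _parse_data; not part of the patch):

from onnx import TensorProto
from onnx.helper import make_graph, make_node, make_tensor, make_tensor_value_info

data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
reduced = make_tensor_value_info("reduced", TensorProto.FLOAT, [2, 3])
# axes as an initializer, per opset 13
axes = make_tensor("axes", TensorProto.INT64, [2], [1, 3])
node = make_node("ReduceSum", ["data", "axes"], ["reduced"], keepdims=0)
graph = make_graph([node], "reduceSumAxesInput", [data], [reduced], [axes])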
"operators/pooling.h" -#include "operators/reduce_mean.h" +#include "operators/reduce.h" #include "operators/reshape.h" #include "operators/split.h" #include "operators/transpose.h" @@ -90,6 +90,7 @@ void export_values(py::module &m) { .VALUE(OpType, Gather) .VALUE(OpType, GatherElements) .VALUE(OpType, ReduceMean) + .VALUE(OpType, ReduceSum) .VALUE(OpType, Reshape) .VALUE(OpType, Flatten) .VALUE(OpType, Identity) @@ -219,12 +220,13 @@ clip_attrs_of(Operator op) { return std::make_tuple(clip->getMin(), clip->getMax()); } -static std::tuple, bool> reduce_mean_attrs_of(Operator op) { - IT_ASSERT(op->getOpType() == OpType::ReduceMean); - auto reduce_mean = dynamic_cast(op.get()); - auto &set = reduce_mean->getAxes(); +static std::tuple, bool> reduce_attrs_of(Operator op) { + IT_ASSERT(op->getOpType() == OpType::ReduceMean || + op->getOpType() == OpType::ReduceSum); + auto reduce = dynamic_cast(op.get()); + auto &set = reduce->getAxes(); return std::make_tuple(vector(set.begin(), set.end()), - reduce_mean->getKeepDims()); + reduce->getKeepDims()); } static int concat_axis_of(Operator op) { @@ -319,7 +321,7 @@ void export_functions(py::module &m) { .FUNCTION(batch_norm_attrs_of) .FUNCTION(pool_attrs_of) .FUNCTION(clip_attrs_of) - .FUNCTION(reduce_mean_attrs_of) + .FUNCTION(reduce_attrs_of) .FUNCTION(tensor_dtype) .FUNCTION(reshape_shape_of) .FUNCTION(expand_shape_of) @@ -497,7 +499,8 @@ void init_graph_builder(py::module &m) { .def("split", &Handler::split, policy::move) .def("gather", &Handler::gather, policy::move) .def("gatherElements", &Handler::gatherElements, policy::move) - .def("reduce_mean", &Handler::reduceMean, policy::move) + .def("reduceMean", &Handler::reduceMean, policy::move) + .def("reduceSum", &Handler::reduceSum, policy::move) .def("slice", &Handler::slice, policy::move) .def("pad", &Handler::pad, policy::move) .def("allReduceSum", &Handler::allReduceSum, policy::move) diff --git a/src/kernels/cuda/reduce_mean.cc b/src/kernels/cuda/reduce.cc similarity index 87% rename from src/kernels/cuda/reduce_mean.cc rename to src/kernels/cuda/reduce.cc index 6ae357c8..840a572f 100644 --- a/src/kernels/cuda/reduce_mean.cc +++ b/src/kernels/cuda/reduce.cc @@ -1,12 +1,14 @@ -#include "operators/reduce_mean.h" +#include "operators/reduce.h" #include "cuda/cuda_kernel_wihtout_config.h" #include "cuda/cuda_runtime.h" namespace infini { -class ReduceMeanCudnn : public CudaKernelWithoutConfig { +class ReduceCudnnBase : public CudaKernelWithoutConfig { + virtual cudnnReduceTensorOp_t getReduceOp() const = 0; + void compute(const Operator &_op, const RuntimeObj *_context) const override { - auto op = as(_op); + auto op = as(_op); auto input = op->getInputs(0); auto output = op->getOutput(); auto context = dynamic_cast(_context); @@ -71,7 +73,7 @@ class ReduceMeanCudnn : public CudaKernelWithoutConfig { cudnnReduceTensorDescriptor_t reduceDesc; checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduceDesc)); checkCudnnError(cudnnSetReduceTensorDescriptor( - reduceDesc, CUDNN_REDUCE_TENSOR_AVG, CUDNN_DATA_FLOAT, + reduceDesc, getReduceOp(), CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES)); @@ -106,6 +108,20 @@ class ReduceMeanCudnn : public CudaKernelWithoutConfig { } }; +class ReduceMeanCudnn : public ReduceCudnnBase { + cudnnReduceTensorOp_t getReduceOp() const override { + return CUDNN_REDUCE_TENSOR_AVG; + } +}; + +class ReduceSumCudnn : public ReduceCudnnBase { + cudnnReduceTensorOp_t getReduceOp() const override { + return 
diff --git a/src/kernels/cuda/reduce_mean.cc b/src/kernels/cuda/reduce.cc
similarity index 87%
rename from src/kernels/cuda/reduce_mean.cc
rename to src/kernels/cuda/reduce.cc
index 6ae357c8..840a572f 100644
--- a/src/kernels/cuda/reduce_mean.cc
+++ b/src/kernels/cuda/reduce.cc
@@ -1,12 +1,14 @@
-#include "operators/reduce_mean.h"
+#include "operators/reduce.h"
 #include "cuda/cuda_kernel_wihtout_config.h"
 #include "cuda/cuda_runtime.h"

 namespace infini {
-class ReduceMeanCudnn : public CudaKernelWithoutConfig {
+class ReduceCudnnBase : public CudaKernelWithoutConfig {
+    virtual cudnnReduceTensorOp_t getReduceOp() const = 0;
+
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
-        auto op = as<ReduceMeanObj>(_op);
+        auto op = as<ReduceBaseObj>(_op);
         auto input = op->getInputs(0);
         auto output = op->getOutput();
         auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
@@ -71,7 +73,7 @@ class ReduceMeanCudnn : public CudaKernelWithoutConfig {
         cudnnReduceTensorDescriptor_t reduceDesc;
         checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduceDesc));
         checkCudnnError(cudnnSetReduceTensorDescriptor(
-            reduceDesc, CUDNN_REDUCE_TENSOR_AVG, CUDNN_DATA_FLOAT,
+            reduceDesc, getReduceOp(), CUDNN_DATA_FLOAT,
             CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES,
             CUDNN_32BIT_INDICES));

@@ -106,6 +108,20 @@ class ReduceMeanCudnn : public CudaKernelWithoutConfig {
     }
 };

+class ReduceMeanCudnn : public ReduceCudnnBase {
+    cudnnReduceTensorOp_t getReduceOp() const override {
+        return CUDNN_REDUCE_TENSOR_AVG;
+    }
+};
+
+class ReduceSumCudnn : public ReduceCudnnBase {
+    cudnnReduceTensorOp_t getReduceOp() const override {
+        return CUDNN_REDUCE_TENSOR_ADD;
+    }
+};
+
 REGISTER_KERNEL(Device::CUDA, OpType::ReduceMean, DataType::Float32,
                 ReduceMeanCudnn, "ReduceMean_cuDNN_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::ReduceSum, DataType::Float32,
+                ReduceSumCudnn, "ReduceSum_cuDNN_CUDA_Float32");
 }; // namespace infini
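CUDNN_REDUCE_TENSOR_ADD and CUDNN_REDUCE_TENSOR_AVG correspond to sum and mean respectively, so the expected values used by the tests in this patch can be reproduced on the host:

import numpy as np

x = np.array([5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2], dtype=np.float32)
print(x.mean())                             # 18.25 (CUDNN_REDUCE_TENSOR_AVG)
print(np.ones(12, dtype=np.float32).sum())  # 12.0  (CUDNN_REDUCE_TENSOR_ADD)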
diff --git a/src/operators/reduce_mean.cc b/src/operators/reduce.cc
similarity index 67%
rename from src/operators/reduce_mean.cc
rename to src/operators/reduce.cc
index cf801c59..1626cb15 100644
--- a/src/operators/reduce_mean.cc
+++ b/src/operators/reduce.cc
@@ -1,10 +1,11 @@
-#include "operators/reduce_mean.h"
+#include "operators/reduce.h"
 #include "utils/operator_utils.h"

 namespace infini {
-ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
-                             const optional<vector<int>> &_axes, bool keepDims)
-    : OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {
+ReduceBaseObj::ReduceBaseObj(GraphObj *graph, OpType opType, Tensor input,
+                             Tensor output, const optional<vector<int>> &_axes,
+                             bool keepDims)
+    : OperatorObj(opType, {input}, {output}), keepDims(keepDims) {
     const auto size = input->getRank();
     if (_axes) {
         for (auto idx : *_axes) {
@@ -17,11 +18,11 @@ ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
     IT_ASSERT(checkValid(graph));
 }

-bool ReduceMeanObj::isReduced(int idx) const {
+bool ReduceBaseObj::isReduced(int idx) const {
     return axes.find(idx) != axes.end();
 }

-optional<vector<Shape>> ReduceMeanObj::inferShape(const TensorVec &inputs) {
+optional<vector<Shape>> ReduceBaseObj::inferShape(const TensorVec &inputs) {
     auto dims = inputs[0]->getDims();
     auto rank = inputs[0]->getRank();

@@ -43,10 +44,9 @@ optional<vector<Shape>> ReduceMeanObj::inferShape(const TensorVec &inputs) {
     }
 }

-std::string ReduceMeanObj::toString() const {
+std::string ReduceBaseObj::toString() const {
     std::ostringstream os;
-    os << "ReduceMean"
-       << "[" << getGuid() << "]";
+    os << type.toString() << "[" << getGuid() << "]";
     os << "(";
     os << vecToString(inputs[0]->getDims()) << ",";

@@ -66,7 +66,7 @@ std::string ReduceMeanObj::toString() const {
     return os.str();
 }

-vector<int> ReduceMeanObj::getWorkloadVector() const {
+vector<int> ReduceBaseObj::getWorkloadVector() const {
     vector<int> ret = inputs[0]->getDims();
     ret.emplace(ret.begin(), type.underlying());
     ret.emplace_back((int)keepDims);
@@ -74,9 +74,18 @@ vector<int> ReduceMeanObj::getWorkloadVector() const {
     return ret;
 }

-vector<int> ReduceMeanObj::getOpAttrVector() const {
+vector<int> ReduceBaseObj::getOpAttrVector() const {
     vector<int> ret = {type.underlying(), (int)keepDims};
     ret.insert(ret.end(), axes.begin(), axes.end());
     return ret;
 }
+
+ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
+                             const optional<vector<int>> &_axes, bool keepDims)
+    : ReduceBaseObj(graph, OpType::ReduceMean, input, output, _axes,
+                    keepDims) {}
+
+ReduceSumObj::ReduceSumObj(GraphObj *graph, Tensor input, Tensor output,
+                           const optional<vector<int>> &_axes, bool keepDims)
+    : ReduceBaseObj(graph, OpType::ReduceSum, input, output, _axes, keepDims) {}
 } // namespace infini

diff --git a/test/core/test_graph_replace.cc b/test/core/test_graph_replace.cc
index cada8860..fdecd3ce 100644
--- a/test/core/test_graph_replace.cc
+++ b/test/core/test_graph_replace.cc
@@ -7,7 +7,7 @@
 #include "operators/extend.h"
 #include "operators/pad.h"
 #include "operators/pooling.h"
-#include "operators/reduce_mean.h"
+#include "operators/reduce.h"
 #include "operators/slice.h"
 #include "operators/split.h"
 #include "operators/unary.h"

diff --git a/test/kernels/cuda/test_cuda_reduce.cc b/test/kernels/cuda/test_cuda_reduce.cc
new file mode 100644
index 00000000..9ce31032
--- /dev/null
+++ b/test/kernels/cuda/test_cuda_reduce.cc
@@ -0,0 +1,83 @@
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/reduce.h"
+
+#include "test.h"
+
+namespace infini {
+
+template <class T>
+void test_reduce(const Shape &shape, const vector<float> &data,
+                 const optional<vector<int>> &axis, bool keepDims,
+                 const vector<float> &ExpectData) {
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor icpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
+
+    // Build CUDA graph
+    Graph g = make_ref<GraphObj>(cudaRuntime);
+    auto i = g->cloneTensor(icpu);
+    auto op = g->addOp<T>(i, nullptr, axis, keepDims);
+
+    // allocate CUDA memory
+    g->dataMalloc();
+    i->copyin(data);
+
+    // Execute on CUDA
+    cudaRuntime->run(g);
+
+    // clone CUDA output to CPU
+    auto o = op->getOutput();
+    auto ocpu = o->clone(cpuRuntime);
+
+    // check results on CPU
+    EXPECT_TRUE(ocpu->equalData(ExpectData));
+}
+
+TEST(CUDA_ReduceMean, run) {
+    test_reduce<ReduceMeanObj>(
+        Shape{3, 2, 2}, vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
+        std::nullopt, true, vector<float>{18.25});
+    test_reduce<ReduceMeanObj>(
+        Shape{1, 3, 2, 2, 1},
+        vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2}, std::nullopt,
+        false, vector<float>{18.25});
+
+    test_reduce<ReduceMeanObj>(
+        Shape{2, 3, 2, 2},
+        vector<float>{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+                      12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+        vector<int>{1, 2}, false, vector<float>{5, 6, 17, 18});
+    test_reduce<ReduceMeanObj>(
+        Shape{2, 3, 2, 2, 1},
+        vector<float>{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+                      12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+        vector<int>{1, 2}, true, vector<float>{5, 6, 17, 18});
+}
+
+TEST(CUDA_ReduceSum, run) {
+    test_reduce<ReduceSumObj>(Shape{3, 2, 2},
+                              vector<float>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+                              std::nullopt, true, vector<float>{12});
+    test_reduce<ReduceSumObj>(Shape{1, 3, 2, 2, 1},
+                              vector<float>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+                              std::nullopt, false, vector<float>{12});
+
+    test_reduce<ReduceSumObj>(
+        Shape{2, 3, 2, 2},
+        vector<float>{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+                      12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+        vector<int>{1, 2}, false, vector<float>{30, 36, 102, 108});
+    test_reduce<ReduceSumObj>(
+        Shape{2, 3, 2, 2, 1},
+        vector<float>{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+                      12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+        vector<int>{1, 2}, true, vector<float>{30, 36, 102, 108});
+}
+
+} // namespace infini
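The expected vectors in the CUDA tests above check out against a NumPy oracle:

import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 3, 2, 2)
print(x.mean(axis=(1, 2)).flatten())  # [ 5.  6. 17. 18.]
print(x.sum(axis=(1, 2)).flatten())   # [ 30.  36. 102. 108.]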
diff --git a/test/kernels/cuda/test_cuda_reduce_mean.cc b/test/kernels/cuda/test_cuda_reduce_mean.cc
deleted file mode 100644
index 2ad672a7..00000000
--- a/test/kernels/cuda/test_cuda_reduce_mean.cc
+++ /dev/null
@@ -1,61 +0,0 @@
-#include "core/graph.h"
-#include "core/kernel.h"
-#include "core/runtime.h"
-#include "cuda/cuda_runtime.h"
-#include "cuda/cuda_utility.h"
-#include "operators/reduce_mean.h"
-
-#include "test.h"
-
-namespace infini {
-
-void test_reducemean(const Shape &shape, const vector<float> &data,
-                     const optional<vector<int>> &axis, bool keepDims,
-                     const vector<float> &ExpectData) {
-    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
-    auto cudaRuntime = make_ref<CudaRuntimeObj>();
-
-    // Build input data on CPU
-    Tensor icpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
-
-    // Build CUDA graph
-    Graph g = make_ref<GraphObj>(cudaRuntime);
-    auto i = g->cloneTensor(icpu);
-    auto op = g->addOp<ReduceMeanObj>(i, nullptr, axis, keepDims);
-
-    // allocate CUDA memory
-    g->dataMalloc();
-    i->copyin(data);
-
-    // Execute on CUDA
-    cudaRuntime->run(g);
-
-    // clone CUDA output to CPU
-    auto o = op->getOutput();
-    auto ocpu = o->clone(cpuRuntime);
-
-    // check results on CPU
-    EXPECT_TRUE(ocpu->equalData(ExpectData));
-}
-
-TEST(CUDA_ReduceMean, run) {
-    test_reducemean(Shape{3, 2, 2},
-                    vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
-                    std::nullopt, true, vector<float>{18.25});
-    test_reducemean(Shape{1, 3, 2, 2, 1},
-                    vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
-                    std::nullopt, false, vector<float>{18.25});
-
-    test_reducemean(Shape{2, 3, 2, 2},
-                    vector<float>{0,  1,  2,  3,  4,  5,  6,  7,
-                                  8,  9,  10, 11, 12, 13, 14, 15,
-                                  16, 17, 18, 19, 20, 21, 22, 23},
-                    vector<int>{1, 2}, false, vector<float>{5, 6, 17, 18});
-    test_reducemean(Shape{2, 3, 2, 2, 1},
-                    vector<float>{0,  1,  2,  3,  4,  5,  6,  7,
-                                  8,  9,  10, 11, 12, 13, 14, 15,
-                                  16, 17, 18, 19, 20, 21, 22, 23},
-                    vector<int>{1, 2}, true, vector<float>{5, 6, 17, 18});
-}
-
-} // namespace infini

diff --git a/test/operators/test_reduce_mean.cc b/test/operators/test_reduce.cc
similarity index 68%
rename from test/operators/test_reduce_mean.cc
rename to test/operators/test_reduce.cc
index 336d4018..83269bb0 100644
--- a/test/operators/test_reduce_mean.cc
+++ b/test/operators/test_reduce.cc
@@ -1,51 +1,55 @@
 #include "core/graph.h"
 #include "core/kernel.h"
 #include "core/runtime.h"
-#include "operators/reduce_mean.h"
+#include "operators/reduce.h"

 #include "test.h"

 namespace infini {

-TEST(ReduceMean, ShapeInference) {
+template <class T> void testShapeInference() {
     Runtime runtime = NativeCpuRuntimeObj::getInstance();
     {
         Graph g = make_ref<GraphObj>(runtime);
         Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
-        auto op = g->addOp<ReduceMeanObj>(i, nullptr, std::nullopt, true);
+        auto op = g->addOp<T>(i, nullptr, std::nullopt, true);
         EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 1, 1}));
     }
     {
         Graph g = make_ref<GraphObj>(runtime);
         Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
-        auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, true);
+        auto op = g->addOp<T>(i, nullptr, vector<int>{1, 3}, true);
         EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1}));
     }
     {
         Graph g = make_ref<GraphObj>(runtime);
         Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
-        auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{-3, 3}, true);
+        auto op = g->addOp<T>(i, nullptr, vector<int>{-3, 3}, true);
         EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1}));
     }
     {
         Graph g = make_ref<GraphObj>(runtime);
         Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
-        auto op = g->addOp<ReduceMeanObj>(i, nullptr, std::nullopt, false);
+        auto op = g->addOp<T>(i, nullptr, std::nullopt, false);
         EXPECT_EQ(op->getOutput()->getDims(), (Shape{1}));
     }
     {
         Graph g = make_ref<GraphObj>(runtime);
         Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
-        auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, false);
+        auto op = g->addOp<T>(i, nullptr, vector<int>{1, 3}, false);
         EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3}));
     }
     {
         Graph g = make_ref<GraphObj>(runtime);
         Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
-        auto op =
-            g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{-3, 3}, false);
+        auto op = g->addOp<T>(i, nullptr, vector<int>{-3, 3}, false);
         EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3}));
     }
 }

+TEST(ReduceMean, ShapeInference) {
+    testShapeInference<ReduceMeanObj>();
+    testShapeInference<ReduceSumObj>();
+}
+
 } // namespace infini
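Taken together, the patch enables the following end-to-end flow; a sketch following this repo's Python tests (building the stub only; running the graph additionally requires a kernel registered for the chosen runtime, and this patch registers the CUDA one):

from onnx import TensorProto
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
from pyinfinitensor.onnx import OnnxStub, backend

data = make_tensor_value_info("data", TensorProto.FLOAT, [3, 2, 2])
reduced = make_tensor_value_info("reduced", TensorProto.FLOAT, [1, 1, 1])
node = make_node("ReduceSum", ["data"], ["reduced"], keepdims=1)
model = make_model(make_graph([node], "reduceSum", [data], [reduced]))
stub = OnnxStub(model, backend.cpu_runtime())  # builds the InfiniTensor graph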