Add ReduceSum op and kernel (#160)

* Add reduceSum op and kernel

* fix merge and format

* Reduce: reuse cat macro, add doc string

---------

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
PanZezhong1725 2023-11-24 09:29:58 +08:00 committed by GitHub
parent 595a9906d2
commit 6ece3f4a77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 235 additions and 121 deletions

View File

@ -73,6 +73,8 @@ class GraphHandlerObj {
Tensor gatherElements(Tensor data, Tensor indices, Tensor output, int axis);
Tensor reduceMean(Tensor data, Tensor reduced,
const optional<vector<int>> &axes, bool keepdims);
Tensor reduceSum(Tensor data, Tensor reduced,
const optional<vector<int>> &axes, bool keepdims);
Tensor slice(Tensor input, Tensor output, const vector<int> &starts,
const vector<int> &ends, const optional<vector<int>> &axes,
const optional<vector<int>> &steps);

View File

@ -3,26 +3,29 @@
namespace infini {
/**
* @brief Compute the mean of input tensor's elements along certain axes.
* @brief Compute the reduction of input tensor's elements along certain axes.
*
*/
class ReduceMeanObj : public OperatorObj {
class ReduceBaseObj : public OperatorObj {
protected:
set<int> axes; // axis to reduce
bool keepDims;
public:
/**
* @brief Construct a new ReduceMean object.
* @brief Construct a new Reduce object.
*
* @param graph The computation graph that this operator belongs to.
* @param opType The operation type. Should be a Reduce operation.
* @param input The input tensor.
* @param output The output tensor.
* @param axes Axes to reduce.
* @param keepDims Keep the reduced dimensions or not.
*/
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
const optional<vector<int>> &axes, bool keepDims = true);
OP_CLONE(ReduceMeanObj);
ReduceBaseObj(GraphObj *graph, OpType opType, Tensor input, Tensor output,
const optional<vector<int>> &axes, bool keepDims);
virtual ~ReduceBaseObj() {}
OP_CLONE(ReduceBaseObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override;
@ -38,4 +41,15 @@ class ReduceMeanObj : public OperatorObj {
vector<int> getOpAttrVector() const override;
};
class ReduceMeanObj : public ReduceBaseObj {
public:
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
const optional<vector<int>> &axes, bool keepDims = true);
};
class ReduceSumObj : public ReduceBaseObj {
public:
ReduceSumObj(GraphObj *graph, Tensor input, Tensor output,
const optional<vector<int>> &axes, bool keepDims = true);
};
} // namespace infini

View File

@ -604,7 +604,7 @@ class OnnxStub:
),
)
elif node.op_type == "ReduceMean":
tensors[node.output[0]] = self.handler.reduce_mean(
tensors[node.output[0]] = self.handler.reduceMean(
tensors[node.input[0]],
tensors.get(node.output[0]),
# NOTE(constroy): `axes` is an attribute until opset version 13.
@ -678,12 +678,40 @@ class OnnxStub:
next((attr.i for attr in node.attribute if attr.name == "to")),
)
elif node.op_type == "ReduceSum":
# ReduceSum is only implemented as allReduceSum.
assert any(attr.name == "communicator" for attr in node.attribute)
tensors[node.output[0]] = self.handler.allReduceSum(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
if any(attr.name == "communicator" for attr in node.attribute):
# ReduceSum with communicator is treated as allReduceSum.
tensors[node.output[0]] = self.handler.allReduceSum(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
else:
# NOTE: `axes` is an attribute until opset version 13.
if len(node.input) > 1:
axis = _parse_data(data[node.input[1]])
else:
axis = next(
(
attr.ints
for attr in node.attribute
if attr.name == "axes"
),
None,
)
keepdims = next(
(
attr.i
for attr in node.attribute
if attr.name == "keepdims"
),
1,
) != 0
tensors[node.output[0]] = self.handler.reduceSum(
tensors[node.input[0]],
tensors.get(node.output[0]),
axis,
keepdims,
)
elif node.op_type == "AllReduceSum":
tensors[node.output[0]] = self.handler.allReduceSum(
tensors[node.input[0]],
@ -1044,8 +1072,11 @@ class OnnxStub:
elif ty == backend.OpTypeId.Gather:
axis = backend.gather_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
elif ty == backend.OpTypeId.ReduceMean:
axes, keepdims = backend.reduce_mean_attrs_of(op)
elif ty in [
backend.OpTypeId.ReduceMean,
backend.OpTypeId.ReduceSum
]:
axes, keepdims = backend.reduce_attrs_of(op)
inputs.append(
ctx.push_data_input(
name, "axes", TensorProto.INT64, [len(axes)], axes

View File

@ -337,6 +337,14 @@ class TestStringMethods(unittest.TestCase):
"ReduceMean", ["data"], ["reduced"], keepdims=1, name="reduceMean"
)
make_and_import_model(make_graph([reduceMean], "reduceMean", [data], [reduced]))
def test_reduce_sum(self):
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
reduced = make_tensor_value_info("reduced", TensorProto.FLOAT, [1, 1, 1, 1])
reduceSum = make_node(
"ReduceSum", ["data"], ["reduced"], keepdims=1, name="reduceSum"
)
make_and_import_model(make_graph([reduceSum], "reduceSum", [data], [reduced]))
def test_slice(self):
data = make_tensor_value_info("data", TensorProto.UINT32, [10, 64, 162, 162])

View File

@ -12,7 +12,7 @@
#include "operators/matmul.h"
#include "operators/pad.h"
#include "operators/pooling.h"
#include "operators/reduce_mean.h"
#include "operators/reduce.h"
#include "operators/reshape.h"
#include "operators/slice.h"
#include "operators/softmax.h"
@ -302,18 +302,23 @@ Tensor GraphHandlerObj::gatherElements(Tensor data, Tensor indices,
}
}
Tensor GraphHandlerObj::reduceMean(Tensor data, Tensor reduced,
const optional<vector<int>> &axes,
bool keepdims) {
if (reduced) {
g->addOpWithOutputs<ReduceMeanObj>(std::move(data), reduced, axes,
keepdims);
return reduced;
} else {
return g->addOp<ReduceMeanObj>(std::move(data), reduced, axes, keepdims)
->getOutput();
#define DEFINE_REDUCE_METHOD(name, obj) \
Tensor GraphHandlerObj::name(Tensor data, Tensor reduced, \
const optional<vector<int>> &axes, \
bool keepdims) { \
if (reduced) { \
g->addOpWithOutputs<_CAT(obj, Obj)>(std::move(data), reduced, \
axes, keepdims); \
return reduced; \
} else { \
return g \
->addOp<_CAT(obj, Obj)>(std::move(data), reduced, axes, \
keepdims) \
->getOutput(); \
} \
}
}
DEFINE_REDUCE_METHOD(reduceMean, ReduceMean)
DEFINE_REDUCE_METHOD(reduceSum, ReduceSum)
Tensor GraphHandlerObj::slice(Tensor input, Tensor output,
const vector<int> &starts,

View File

@ -8,7 +8,7 @@
#include "operators/matmul.h"
#include "operators/pad.h"
#include "operators/pooling.h"
#include "operators/reduce_mean.h"
#include "operators/reduce.h"
#include "operators/reshape.h"
#include "operators/split.h"
#include "operators/transpose.h"
@ -90,6 +90,7 @@ void export_values(py::module &m) {
.VALUE(OpType, Gather)
.VALUE(OpType, GatherElements)
.VALUE(OpType, ReduceMean)
.VALUE(OpType, ReduceSum)
.VALUE(OpType, Reshape)
.VALUE(OpType, Flatten)
.VALUE(OpType, Identity)
@ -219,12 +220,13 @@ clip_attrs_of(Operator op) {
return std::make_tuple(clip->getMin(), clip->getMax());
}
static std::tuple<vector<int>, bool> reduce_mean_attrs_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::ReduceMean);
auto reduce_mean = dynamic_cast<const ReduceMeanObj *>(op.get());
auto &set = reduce_mean->getAxes();
static std::tuple<vector<int>, bool> reduce_attrs_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::ReduceMean ||
op->getOpType() == OpType::ReduceSum);
auto reduce = dynamic_cast<const ReduceBaseObj *>(op.get());
auto &set = reduce->getAxes();
return std::make_tuple(vector(set.begin(), set.end()),
reduce_mean->getKeepDims());
reduce->getKeepDims());
}
static int concat_axis_of(Operator op) {
@ -319,7 +321,7 @@ void export_functions(py::module &m) {
.FUNCTION(batch_norm_attrs_of)
.FUNCTION(pool_attrs_of)
.FUNCTION(clip_attrs_of)
.FUNCTION(reduce_mean_attrs_of)
.FUNCTION(reduce_attrs_of)
.FUNCTION(tensor_dtype)
.FUNCTION(reshape_shape_of)
.FUNCTION(expand_shape_of)
@ -497,7 +499,8 @@ void init_graph_builder(py::module &m) {
.def("split", &Handler::split, policy::move)
.def("gather", &Handler::gather, policy::move)
.def("gatherElements", &Handler::gatherElements, policy::move)
.def("reduce_mean", &Handler::reduceMean, policy::move)
.def("reduceMean", &Handler::reduceMean, policy::move)
.def("reduceSum", &Handler::reduceSum, policy::move)
.def("slice", &Handler::slice, policy::move)
.def("pad", &Handler::pad, policy::move)
.def("allReduceSum", &Handler::allReduceSum, policy::move)

View File

@ -1,12 +1,14 @@
#include "operators/reduce_mean.h"
#include "operators/reduce.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class ReduceMeanCudnn : public CudaKernelWithoutConfig {
class ReduceCudnnBase : public CudaKernelWithoutConfig {
virtual cudnnReduceTensorOp_t getReduceOp() const = 0;
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ReduceMeanObj>(_op);
auto op = as<ReduceBaseObj>(_op);
auto input = op->getInputs(0);
auto output = op->getOutput();
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
@ -71,7 +73,7 @@ class ReduceMeanCudnn : public CudaKernelWithoutConfig {
cudnnReduceTensorDescriptor_t reduceDesc;
checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduceDesc));
checkCudnnError(cudnnSetReduceTensorDescriptor(
reduceDesc, CUDNN_REDUCE_TENSOR_AVG, CUDNN_DATA_FLOAT,
reduceDesc, getReduceOp(), CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES,
CUDNN_32BIT_INDICES));
@ -106,6 +108,20 @@ class ReduceMeanCudnn : public CudaKernelWithoutConfig {
}
};
class ReduceMeanCudnn : public ReduceCudnnBase {
cudnnReduceTensorOp_t getReduceOp() const override {
return CUDNN_REDUCE_TENSOR_AVG;
}
};
class ReduceSumCudnn : public ReduceCudnnBase {
cudnnReduceTensorOp_t getReduceOp() const override {
return CUDNN_REDUCE_TENSOR_ADD;
}
};
REGISTER_KERNEL(Device::CUDA, OpType::ReduceMean, DataType::Float32,
ReduceMeanCudnn, "ReduceMean_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::ReduceSum, DataType::Float32,
ReduceSumCudnn, "ReduceSum_cuDNN_CUDA_Float32");
}; // namespace infini

View File

@ -1,10 +1,11 @@
#include "operators/reduce_mean.h"
#include "operators/reduce.h"
#include "utils/operator_utils.h"
namespace infini {
ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
const optional<vector<int>> &_axes, bool keepDims)
: OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {
ReduceBaseObj::ReduceBaseObj(GraphObj *graph, OpType opType, Tensor input,
Tensor output, const optional<vector<int>> &_axes,
bool keepDims)
: OperatorObj(opType, {input}, {output}), keepDims(keepDims) {
const auto size = input->getRank();
if (_axes) {
for (auto idx : *_axes) {
@ -17,11 +18,11 @@ ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
IT_ASSERT(checkValid(graph));
}
bool ReduceMeanObj::isReduced(int idx) const {
bool ReduceBaseObj::isReduced(int idx) const {
return axes.find(idx) != axes.end();
}
optional<vector<Shape>> ReduceMeanObj::inferShape(const TensorVec &inputs) {
optional<vector<Shape>> ReduceBaseObj::inferShape(const TensorVec &inputs) {
auto dims = inputs[0]->getDims();
auto rank = inputs[0]->getRank();
@ -43,10 +44,9 @@ optional<vector<Shape>> ReduceMeanObj::inferShape(const TensorVec &inputs) {
}
}
std::string ReduceMeanObj::toString() const {
std::string ReduceBaseObj::toString() const {
std::ostringstream os;
os << "ReduceMean"
<< "[" << getGuid() << "]";
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
@ -66,7 +66,7 @@ std::string ReduceMeanObj::toString() const {
return os.str();
}
vector<int> ReduceMeanObj::getWorkloadVector() const {
vector<int> ReduceBaseObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), type.underlying());
ret.emplace_back((int)keepDims);
@ -74,9 +74,18 @@ vector<int> ReduceMeanObj::getWorkloadVector() const {
return ret;
}
vector<int> ReduceMeanObj::getOpAttrVector() const {
vector<int> ReduceBaseObj::getOpAttrVector() const {
vector<int> ret = {type.underlying(), (int)keepDims};
ret.insert(ret.end(), axes.begin(), axes.end());
return ret;
}
ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
const optional<vector<int>> &_axes, bool keepDims)
: ReduceBaseObj(graph, OpType::ReduceMean, input, output, _axes, keepDims) {
}
ReduceSumObj::ReduceSumObj(GraphObj *graph, Tensor input, Tensor output,
const optional<vector<int>> &_axes, bool keepDims)
: ReduceBaseObj(graph, OpType::ReduceSum, input, output, _axes, keepDims) {}
} // namespace infini

View File

@ -7,7 +7,7 @@
#include "operators/extend.h"
#include "operators/pad.h"
#include "operators/pooling.h"
#include "operators/reduce_mean.h"
#include "operators/reduce.h"
#include "operators/slice.h"
#include "operators/split.h"
#include "operators/unary.h"

View File

@ -0,0 +1,83 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/reduce.h"
#include "test.h"
namespace infini {
template <typename ReduceObjT>
void test_reduce(const Shape &shape, const vector<float> &data,
const optional<const vector<int>> &axis, bool keepDims,
const vector<float> &ExpectData) {
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto cudaRuntime = make_ref<CudaRuntimeObj>();
// Build input data on CPU
Tensor icpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
// Build CUDA graph
Graph g = make_ref<GraphObj>(cudaRuntime);
auto i = g->cloneTensor(icpu);
auto op = g->addOp<ReduceObjT>(i, nullptr, axis, keepDims);
// allocate CUDA memory
g->dataMalloc();
i->copyin(data);
// Execute on CUDA
cudaRuntime->run(g);
// clone CUDA output to CPU
auto o = op->getOutput();
auto ocpu = o->clone(cpuRuntime);
// check results on CPU
EXPECT_TRUE(ocpu->equalData(ExpectData));
}
TEST(CUDA_ReduceMean, run) {
test_reduce<ReduceMeanObj>(
Shape{3, 2, 2}, vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
std::nullopt, true, vector<float>{18.25});
test_reduce<ReduceMeanObj>(
Shape{1, 3, 2, 2, 1},
vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2}, std::nullopt,
false, vector<float>{18.25});
test_reduce<ReduceMeanObj>(
Shape{2, 3, 2, 2},
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
vector<int>{1, 2}, false, vector<float>{5, 6, 17, 18});
test_reduce<ReduceMeanObj>(
Shape{2, 3, 2, 2, 1},
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
vector<int>{1, 2}, true, vector<float>{5, 6, 17, 18});
}
TEST(CUDA_ReduceSum, run) {
test_reduce<ReduceSumObj>(Shape{3, 2, 2},
vector<float>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
std::nullopt, true, vector<float>{12});
test_reduce<ReduceSumObj>(Shape{1, 3, 2, 2, 1},
vector<float>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
std::nullopt, false, vector<float>{12});
test_reduce<ReduceSumObj>(
Shape{2, 3, 2, 2},
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
vector<int>{1, 2}, false, vector<float>{30, 36, 102, 108});
test_reduce<ReduceSumObj>(
Shape{2, 3, 2, 2, 1},
vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
vector<int>{1, 2}, true, vector<float>{30, 36, 102, 108});
}
} // namespace infini

View File

@ -1,61 +0,0 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/reduce_mean.h"
#include "test.h"
namespace infini {
void test_reducemean(const Shape &shape, const vector<float> &data,
const optional<const vector<int>> &axis, bool keepDims,
const vector<float> &ExpectData) {
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
auto cudaRuntime = make_ref<CudaRuntimeObj>();
// Build input data on CPU
Tensor icpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
// Build CUDA graph
Graph g = make_ref<GraphObj>(cudaRuntime);
auto i = g->cloneTensor(icpu);
auto op = g->addOp<ReduceMeanObj>(i, nullptr, axis, keepDims);
// allocate CUDA memory
g->dataMalloc();
i->copyin(data);
// Execute on CUDA
cudaRuntime->run(g);
// clone CUDA output to CPU
auto o = op->getOutput();
auto ocpu = o->clone(cpuRuntime);
// check results on CPU
EXPECT_TRUE(ocpu->equalData(ExpectData));
}
TEST(CUDA_ReduceMean, run) {
test_reducemean(Shape{3, 2, 2},
vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
std::nullopt, true, vector<float>{18.25});
test_reducemean(Shape{1, 3, 2, 2, 1},
vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
std::nullopt, false, vector<float>{18.25});
test_reducemean(Shape{2, 3, 2, 2},
vector<float>{0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23},
vector<int>{1, 2}, false, vector<float>{5, 6, 17, 18});
test_reducemean(Shape{2, 3, 2, 2, 1},
vector<float>{0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23},
vector<int>{1, 2}, true, vector<float>{5, 6, 17, 18});
}
} // namespace infini

View File

@ -1,51 +1,55 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/reduce_mean.h"
#include "operators/reduce.h"
#include "test.h"
namespace infini {
TEST(ReduceMean, ShapeInference) {
template <typename ReduceObjT> void testShapeInference() {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
auto op = g->addOp<ReduceMeanObj>(i, nullptr, std::nullopt, true);
auto op = g->addOp<ReduceObjT>(i, nullptr, std::nullopt, true);
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 1, 1}));
}
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, true);
auto op = g->addOp<ReduceObjT>(i, nullptr, vector<int>{1, 3}, true);
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1}));
}
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{-3, 3}, true);
auto op = g->addOp<ReduceObjT>(i, nullptr, vector<int>{-3, 3}, true);
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1}));
}
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
auto op = g->addOp<ReduceMeanObj>(i, nullptr, std::nullopt, false);
auto op = g->addOp<ReduceObjT>(i, nullptr, std::nullopt, false);
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1}));
}
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, false);
auto op = g->addOp<ReduceObjT>(i, nullptr, vector<int>{1, 3}, false);
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3}));
}
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
auto op =
g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{-3, 3}, false);
auto op = g->addOp<ReduceObjT>(i, nullptr, vector<int>{-3, 3}, false);
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3}));
}
}
TEST(ReduceMean, ShapeInference) {
testShapeInference<ReduceMeanObj>();
testShapeInference<ReduceSumObj>();
}
} // namespace infini