forked from jiuyuan/InfiniTensor
Add ReduceSum op and kernel (#160)
* Add reduceSum op and kernel
* fix merge and format
* Reduce: reuse cat macro, add doc string

---------

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
595a9906d2
commit
6ece3f4a77
|
@ -73,6 +73,8 @@ class GraphHandlerObj {
|
|||
Tensor gatherElements(Tensor data, Tensor indices, Tensor output, int axis);
|
||||
Tensor reduceMean(Tensor data, Tensor reduced,
|
||||
const optional<vector<int>> &axes, bool keepdims);
|
||||
Tensor reduceSum(Tensor data, Tensor reduced,
|
||||
const optional<vector<int>> &axes, bool keepdims);
|
||||
Tensor slice(Tensor input, Tensor output, const vector<int> &starts,
|
||||
const vector<int> &ends, const optional<vector<int>> &axes,
|
||||
const optional<vector<int>> &steps);
|
||||
|
|
|
@ -3,26 +3,29 @@
|
|||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Compute the mean of input tensor's elements along certain axes.
|
||||
* @brief Compute the reduction of input tensor's elements along certain axes.
|
||||
*
|
||||
*/
|
||||
class ReduceMeanObj : public OperatorObj {
|
||||
class ReduceBaseObj : public OperatorObj {
|
||||
protected:
|
||||
set<int> axes; // axis to reduce
|
||||
bool keepDims;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new ReduceMean object.
|
||||
* @brief Construct a new Reduce object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param opType The operation type. Should be a Reduce operation.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
* @param axes Axes to reduce.
|
||||
* @param keepDims Keep the reduced dimensions or not.
|
||||
*/
|
||||
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const optional<vector<int>> &axes, bool keepDims = true);
|
||||
OP_CLONE(ReduceMeanObj);
|
||||
ReduceBaseObj(GraphObj *graph, OpType opType, Tensor input, Tensor output,
|
||||
const optional<vector<int>> &axes, bool keepDims);
|
||||
virtual ~ReduceBaseObj() {}
|
||||
OP_CLONE(ReduceBaseObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
|
@ -38,4 +41,15 @@ class ReduceMeanObj : public OperatorObj {
|
|||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
class ReduceMeanObj : public ReduceBaseObj {
|
||||
public:
|
||||
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const optional<vector<int>> &axes, bool keepDims = true);
|
||||
};
|
||||
|
||||
class ReduceSumObj : public ReduceBaseObj {
|
||||
public:
|
||||
ReduceSumObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const optional<vector<int>> &axes, bool keepDims = true);
|
||||
};
|
||||
} // namespace infini
|
|
@ -604,7 +604,7 @@ class OnnxStub:
|
|||
),
|
||||
)
|
||||
elif node.op_type == "ReduceMean":
|
||||
tensors[node.output[0]] = self.handler.reduce_mean(
|
||||
tensors[node.output[0]] = self.handler.reduceMean(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
# NOTE(constroy): `axes` is an attribute until opset version 13.
|
||||
|
@ -678,12 +678,40 @@ class OnnxStub:
|
|||
next((attr.i for attr in node.attribute if attr.name == "to")),
|
||||
)
|
||||
elif node.op_type == "ReduceSum":
|
||||
# ReduceSum is only implemented as allReduceSum.
|
||||
assert any(attr.name == "communicator" for attr in node.attribute)
|
||||
tensors[node.output[0]] = self.handler.allReduceSum(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
if any(attr.name == "communicator" for attr in node.attribute):
|
||||
# ReduceSum with communicator is treated as allReduceSum.
|
||||
tensors[node.output[0]] = self.handler.allReduceSum(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
else:
|
||||
# NOTE: `axes` is an attribute until opset version 13.
|
||||
if len(node.input) > 1:
|
||||
axis = _parse_data(data[node.input[1]])
|
||||
else:
|
||||
axis = next(
|
||||
(
|
||||
attr.ints
|
||||
for attr in node.attribute
|
||||
if attr.name == "axes"
|
||||
),
|
||||
None,
|
||||
)
|
||||
keepdims = next(
|
||||
(
|
||||
attr.i
|
||||
for attr in node.attribute
|
||||
if attr.name == "keepdims"
|
||||
),
|
||||
1,
|
||||
) != 0
|
||||
|
||||
tensors[node.output[0]] = self.handler.reduceSum(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
axis,
|
||||
keepdims,
|
||||
)
|
||||
elif node.op_type == "AllReduceSum":
|
||||
tensors[node.output[0]] = self.handler.allReduceSum(
|
||||
tensors[node.input[0]],
|
||||
|
@ -1044,8 +1072,11 @@ class OnnxStub:
|
|||
elif ty == backend.OpTypeId.Gather:
|
||||
axis = backend.gather_axis_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||
elif ty == backend.OpTypeId.ReduceMean:
|
||||
axes, keepdims = backend.reduce_mean_attrs_of(op)
|
||||
elif ty in [
|
||||
backend.OpTypeId.ReduceMean,
|
||||
backend.OpTypeId.ReduceSum
|
||||
]:
|
||||
axes, keepdims = backend.reduce_attrs_of(op)
|
||||
inputs.append(
|
||||
ctx.push_data_input(
|
||||
name, "axes", TensorProto.INT64, [len(axes)], axes
|
||||
|
|
|
@ -337,6 +337,14 @@ class TestStringMethods(unittest.TestCase):
|
|||
"ReduceMean", ["data"], ["reduced"], keepdims=1, name="reduceMean"
|
||||
)
|
||||
make_and_import_model(make_graph([reduceMean], "reduceMean", [data], [reduced]))
|
||||
|
||||
def test_reduce_sum(self):
    """Build a one-node ReduceSum model (no `axes`, keepdims=1) and import it."""
    data_info = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
    reduced_info = make_tensor_value_info(
        "reduced", TensorProto.FLOAT, [1, 1, 1, 1]
    )
    node = make_node(
        "ReduceSum", ["data"], ["reduced"], keepdims=1, name="reduceSum"
    )
    graph = make_graph([node], "reduceSum", [data_info], [reduced_info])
    make_and_import_model(graph)
|
||||
|
||||
def test_slice(self):
|
||||
data = make_tensor_value_info("data", TensorProto.UINT32, [10, 64, 162, 162])
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
#include "operators/matmul.h"
|
||||
#include "operators/pad.h"
|
||||
#include "operators/pooling.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
#include "operators/reduce.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/slice.h"
|
||||
#include "operators/softmax.h"
|
||||
|
@ -302,18 +302,23 @@ Tensor GraphHandlerObj::gatherElements(Tensor data, Tensor indices,
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::reduceMean(Tensor data, Tensor reduced,
|
||||
const optional<vector<int>> &axes,
|
||||
bool keepdims) {
|
||||
if (reduced) {
|
||||
g->addOpWithOutputs<ReduceMeanObj>(std::move(data), reduced, axes,
|
||||
keepdims);
|
||||
return reduced;
|
||||
} else {
|
||||
return g->addOp<ReduceMeanObj>(std::move(data), reduced, axes, keepdims)
|
||||
->getOutput();
|
||||
#define DEFINE_REDUCE_METHOD(name, obj) \
|
||||
Tensor GraphHandlerObj::name(Tensor data, Tensor reduced, \
|
||||
const optional<vector<int>> &axes, \
|
||||
bool keepdims) { \
|
||||
if (reduced) { \
|
||||
g->addOpWithOutputs<_CAT(obj, Obj)>(std::move(data), reduced, \
|
||||
axes, keepdims); \
|
||||
return reduced; \
|
||||
} else { \
|
||||
return g \
|
||||
->addOp<_CAT(obj, Obj)>(std::move(data), reduced, axes, \
|
||||
keepdims) \
|
||||
->getOutput(); \
|
||||
} \
|
||||
}
|
||||
}
|
||||
DEFINE_REDUCE_METHOD(reduceMean, ReduceMean)
|
||||
DEFINE_REDUCE_METHOD(reduceSum, ReduceSum)
|
||||
|
||||
Tensor GraphHandlerObj::slice(Tensor input, Tensor output,
|
||||
const vector<int> &starts,
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
#include "operators/matmul.h"
|
||||
#include "operators/pad.h"
|
||||
#include "operators/pooling.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
#include "operators/reduce.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/split.h"
|
||||
#include "operators/transpose.h"
|
||||
|
@ -90,6 +90,7 @@ void export_values(py::module &m) {
|
|||
.VALUE(OpType, Gather)
|
||||
.VALUE(OpType, GatherElements)
|
||||
.VALUE(OpType, ReduceMean)
|
||||
.VALUE(OpType, ReduceSum)
|
||||
.VALUE(OpType, Reshape)
|
||||
.VALUE(OpType, Flatten)
|
||||
.VALUE(OpType, Identity)
|
||||
|
@ -219,12 +220,13 @@ clip_attrs_of(Operator op) {
|
|||
return std::make_tuple(clip->getMin(), clip->getMax());
|
||||
}
|
||||
|
||||
static std::tuple<vector<int>, bool> reduce_mean_attrs_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::ReduceMean);
|
||||
auto reduce_mean = dynamic_cast<const ReduceMeanObj *>(op.get());
|
||||
auto &set = reduce_mean->getAxes();
|
||||
static std::tuple<vector<int>, bool> reduce_attrs_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::ReduceMean ||
|
||||
op->getOpType() == OpType::ReduceSum);
|
||||
auto reduce = dynamic_cast<const ReduceBaseObj *>(op.get());
|
||||
auto &set = reduce->getAxes();
|
||||
return std::make_tuple(vector(set.begin(), set.end()),
|
||||
reduce_mean->getKeepDims());
|
||||
reduce->getKeepDims());
|
||||
}
|
||||
|
||||
static int concat_axis_of(Operator op) {
|
||||
|
@ -319,7 +321,7 @@ void export_functions(py::module &m) {
|
|||
.FUNCTION(batch_norm_attrs_of)
|
||||
.FUNCTION(pool_attrs_of)
|
||||
.FUNCTION(clip_attrs_of)
|
||||
.FUNCTION(reduce_mean_attrs_of)
|
||||
.FUNCTION(reduce_attrs_of)
|
||||
.FUNCTION(tensor_dtype)
|
||||
.FUNCTION(reshape_shape_of)
|
||||
.FUNCTION(expand_shape_of)
|
||||
|
@ -497,7 +499,8 @@ void init_graph_builder(py::module &m) {
|
|||
.def("split", &Handler::split, policy::move)
|
||||
.def("gather", &Handler::gather, policy::move)
|
||||
.def("gatherElements", &Handler::gatherElements, policy::move)
|
||||
.def("reduce_mean", &Handler::reduceMean, policy::move)
|
||||
.def("reduceMean", &Handler::reduceMean, policy::move)
|
||||
.def("reduceSum", &Handler::reduceSum, policy::move)
|
||||
.def("slice", &Handler::slice, policy::move)
|
||||
.def("pad", &Handler::pad, policy::move)
|
||||
.def("allReduceSum", &Handler::allReduceSum, policy::move)
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
#include "operators/reduce_mean.h"
|
||||
#include "operators/reduce.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class ReduceMeanCudnn : public CudaKernelWithoutConfig {
|
||||
class ReduceCudnnBase : public CudaKernelWithoutConfig {
|
||||
virtual cudnnReduceTensorOp_t getReduceOp() const = 0;
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ReduceMeanObj>(_op);
|
||||
auto op = as<ReduceBaseObj>(_op);
|
||||
auto input = op->getInputs(0);
|
||||
auto output = op->getOutput();
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
|
@ -71,7 +73,7 @@ class ReduceMeanCudnn : public CudaKernelWithoutConfig {
|
|||
cudnnReduceTensorDescriptor_t reduceDesc;
|
||||
checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduceDesc));
|
||||
checkCudnnError(cudnnSetReduceTensorDescriptor(
|
||||
reduceDesc, CUDNN_REDUCE_TENSOR_AVG, CUDNN_DATA_FLOAT,
|
||||
reduceDesc, getReduceOp(), CUDNN_DATA_FLOAT,
|
||||
CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES,
|
||||
CUDNN_32BIT_INDICES));
|
||||
|
||||
|
@ -106,6 +108,20 @@ class ReduceMeanCudnn : public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
class ReduceMeanCudnn : public ReduceCudnnBase {
|
||||
cudnnReduceTensorOp_t getReduceOp() const override {
|
||||
return CUDNN_REDUCE_TENSOR_AVG;
|
||||
}
|
||||
};
|
||||
|
||||
class ReduceSumCudnn : public ReduceCudnnBase {
|
||||
cudnnReduceTensorOp_t getReduceOp() const override {
|
||||
return CUDNN_REDUCE_TENSOR_ADD;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::ReduceMean, DataType::Float32,
|
||||
ReduceMeanCudnn, "ReduceMean_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::ReduceSum, DataType::Float32,
|
||||
ReduceSumCudnn, "ReduceSum_cuDNN_CUDA_Float32");
|
||||
}; // namespace infini
|
|
@ -1,10 +1,11 @@
|
|||
#include "operators/reduce_mean.h"
|
||||
#include "operators/reduce.h"
|
||||
#include "utils/operator_utils.h"
|
||||
|
||||
namespace infini {
|
||||
ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const optional<vector<int>> &_axes, bool keepDims)
|
||||
: OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {
|
||||
ReduceBaseObj::ReduceBaseObj(GraphObj *graph, OpType opType, Tensor input,
|
||||
Tensor output, const optional<vector<int>> &_axes,
|
||||
bool keepDims)
|
||||
: OperatorObj(opType, {input}, {output}), keepDims(keepDims) {
|
||||
const auto size = input->getRank();
|
||||
if (_axes) {
|
||||
for (auto idx : *_axes) {
|
||||
|
@ -17,11 +18,11 @@ ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
bool ReduceMeanObj::isReduced(int idx) const {
|
||||
bool ReduceBaseObj::isReduced(int idx) const {
|
||||
return axes.find(idx) != axes.end();
|
||||
}
|
||||
|
||||
optional<vector<Shape>> ReduceMeanObj::inferShape(const TensorVec &inputs) {
|
||||
optional<vector<Shape>> ReduceBaseObj::inferShape(const TensorVec &inputs) {
|
||||
auto dims = inputs[0]->getDims();
|
||||
auto rank = inputs[0]->getRank();
|
||||
|
||||
|
@ -43,10 +44,9 @@ optional<vector<Shape>> ReduceMeanObj::inferShape(const TensorVec &inputs) {
|
|||
}
|
||||
}
|
||||
|
||||
std::string ReduceMeanObj::toString() const {
|
||||
std::string ReduceBaseObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << "ReduceMean"
|
||||
<< "[" << getGuid() << "]";
|
||||
os << type.toString() << "[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
|
||||
|
@ -66,7 +66,7 @@ std::string ReduceMeanObj::toString() const {
|
|||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> ReduceMeanObj::getWorkloadVector() const {
|
||||
vector<int> ReduceBaseObj::getWorkloadVector() const {
|
||||
vector<int> ret = inputs[0]->getDims();
|
||||
ret.emplace(ret.begin(), type.underlying());
|
||||
ret.emplace_back((int)keepDims);
|
||||
|
@ -74,9 +74,18 @@ vector<int> ReduceMeanObj::getWorkloadVector() const {
|
|||
return ret;
|
||||
}
|
||||
|
||||
vector<int> ReduceMeanObj::getOpAttrVector() const {
|
||||
vector<int> ReduceBaseObj::getOpAttrVector() const {
|
||||
vector<int> ret = {type.underlying(), (int)keepDims};
|
||||
ret.insert(ret.end(), axes.begin(), axes.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Forwards every argument to ReduceBaseObj, tagging the operator as
// OpType::ReduceMean.
ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
                             const optional<vector<int>> &_axes, bool keepDims)
    : ReduceBaseObj(graph, OpType::ReduceMean, input, output, _axes,
                    keepDims) {}
|
||||
|
||||
// Forwards every argument to ReduceBaseObj, tagging the operator as
// OpType::ReduceSum.
ReduceSumObj::ReduceSumObj(GraphObj *graph, Tensor input, Tensor output,
                           const optional<vector<int>> &_axes, bool keepDims)
    : ReduceBaseObj(graph, OpType::ReduceSum, input, output, _axes,
                    keepDims) {}
|
||||
} // namespace infini
|
|
@ -7,7 +7,7 @@
|
|||
#include "operators/extend.h"
|
||||
#include "operators/pad.h"
|
||||
#include "operators/pooling.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
#include "operators/reduce.h"
|
||||
#include "operators/slice.h"
|
||||
#include "operators/split.h"
|
||||
#include "operators/unary.h"
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/reduce.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <typename ReduceObjT>
|
||||
void test_reduce(const Shape &shape, const vector<float> &data,
|
||||
const optional<const vector<int>> &axis, bool keepDims,
|
||||
const vector<float> &ExpectData) {
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor icpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
|
||||
// Build CUDA graph
|
||||
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||
auto i = g->cloneTensor(icpu);
|
||||
auto op = g->addOp<ReduceObjT>(i, nullptr, axis, keepDims);
|
||||
|
||||
// allocate CUDA memory
|
||||
g->dataMalloc();
|
||||
i->copyin(data);
|
||||
|
||||
// Execute on CUDA
|
||||
cudaRuntime->run(g);
|
||||
|
||||
// clone CUDA output to CPU
|
||||
auto o = op->getOutput();
|
||||
auto ocpu = o->clone(cpuRuntime);
|
||||
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(ocpu->equalData(ExpectData));
|
||||
}
|
||||
|
||||
// ReduceMean on CUDA: full reduction (no axes) and reduction over axes {1, 2},
// each with and without keepDims.
TEST(CUDA_ReduceMean, run) {
    const vector<float> mixed{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2};
    const vector<float> ramp{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23};
    const vector<float> meanAll{18.25};
    const vector<float> meanAxes12{5, 6, 17, 18};

    // No axes given.
    test_reduce<ReduceMeanObj>(Shape{3, 2, 2}, mixed, std::nullopt, true,
                               meanAll);
    test_reduce<ReduceMeanObj>(Shape{1, 3, 2, 2, 1}, mixed, std::nullopt,
                               false, meanAll);

    // Axes 1 and 2.
    test_reduce<ReduceMeanObj>(Shape{2, 3, 2, 2}, ramp, vector<int>{1, 2},
                               false, meanAxes12);
    test_reduce<ReduceMeanObj>(Shape{2, 3, 2, 2, 1}, ramp, vector<int>{1, 2},
                               true, meanAxes12);
}
|
||||
|
||||
// ReduceSum on CUDA: full reduction (no axes) and reduction over axes {1, 2},
// each with and without keepDims.
TEST(CUDA_ReduceSum, run) {
    const vector<float> ones{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    const vector<float> ramp{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23};
    const vector<float> sumAll{12};
    const vector<float> sumAxes12{30, 36, 102, 108};

    // No axes given.
    test_reduce<ReduceSumObj>(Shape{3, 2, 2}, ones, std::nullopt, true,
                              sumAll);
    test_reduce<ReduceSumObj>(Shape{1, 3, 2, 2, 1}, ones, std::nullopt, false,
                              sumAll);

    // Axes 1 and 2.
    test_reduce<ReduceSumObj>(Shape{2, 3, 2, 2}, ramp, vector<int>{1, 2},
                              false, sumAxes12);
    test_reduce<ReduceSumObj>(Shape{2, 3, 2, 2, 1}, ramp, vector<int>{1, 2},
                              true, sumAxes12);
}
|
||||
|
||||
} // namespace infini
|
|
@ -1,61 +0,0 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void test_reducemean(const Shape &shape, const vector<float> &data,
|
||||
const optional<const vector<int>> &axis, bool keepDims,
|
||||
const vector<float> &ExpectData) {
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor icpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
|
||||
// Build CUDA graph
|
||||
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||
auto i = g->cloneTensor(icpu);
|
||||
auto op = g->addOp<ReduceMeanObj>(i, nullptr, axis, keepDims);
|
||||
|
||||
// allocate CUDA memory
|
||||
g->dataMalloc();
|
||||
i->copyin(data);
|
||||
|
||||
// Execute on CUDA
|
||||
cudaRuntime->run(g);
|
||||
|
||||
// clone CUDA output to CPU
|
||||
auto o = op->getOutput();
|
||||
auto ocpu = o->clone(cpuRuntime);
|
||||
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(ocpu->equalData(ExpectData));
|
||||
}
|
||||
|
||||
TEST(CUDA_ReduceMean, run) {
|
||||
test_reducemean(Shape{3, 2, 2},
|
||||
vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
|
||||
std::nullopt, true, vector<float>{18.25});
|
||||
test_reducemean(Shape{1, 3, 2, 2, 1},
|
||||
vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
|
||||
std::nullopt, false, vector<float>{18.25});
|
||||
|
||||
test_reducemean(Shape{2, 3, 2, 2},
|
||||
vector<float>{0, 1, 2, 3, 4, 5, 6, 7,
|
||||
8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23},
|
||||
vector<int>{1, 2}, false, vector<float>{5, 6, 17, 18});
|
||||
test_reducemean(Shape{2, 3, 2, 2, 1},
|
||||
vector<float>{0, 1, 2, 3, 4, 5, 6, 7,
|
||||
8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23},
|
||||
vector<int>{1, 2}, true, vector<float>{5, 6, 17, 18});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -1,51 +1,55 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
#include "operators/reduce.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(ReduceMean, ShapeInference) {
|
||||
template <typename ReduceObjT> void testShapeInference() {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<ReduceMeanObj>(i, nullptr, std::nullopt, true);
|
||||
auto op = g->addOp<ReduceObjT>(i, nullptr, std::nullopt, true);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 1, 1}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, true);
|
||||
auto op = g->addOp<ReduceObjT>(i, nullptr, vector<int>{1, 3}, true);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{-3, 3}, true);
|
||||
auto op = g->addOp<ReduceObjT>(i, nullptr, vector<int>{-3, 3}, true);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<ReduceMeanObj>(i, nullptr, std::nullopt, false);
|
||||
auto op = g->addOp<ReduceObjT>(i, nullptr, std::nullopt, false);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, false);
|
||||
auto op = g->addOp<ReduceObjT>(i, nullptr, vector<int>{1, 3}, false);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op =
|
||||
g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{-3, 3}, false);
|
||||
auto op = g->addOp<ReduceObjT>(i, nullptr, vector<int>{-3, 3}, false);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3}));
|
||||
}
|
||||
}
|
||||
|
||||
// Each reduce operator gets its own test case so that a shape-inference
// failure is reported under the operator that actually broke, instead of
// ReduceSum failures showing up as ReduceMean.ShapeInference.
TEST(ReduceMean, ShapeInference) { testShapeInference<ReduceMeanObj>(); }

TEST(ReduceSum, ShapeInference) { testShapeInference<ReduceSumObj>(); }
|
||||
|
||||
} // namespace infini
|
Loading…
Reference in New Issue