forked from jiuyuan/InfiniTensor
feat: 前端支持 reduceMean 及单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
fb9d84dbb7
commit
62ceb78ae3
|
@ -59,7 +59,9 @@ class GraphHandlerObj {
|
||||||
Tensor flatten(Tensor s, Tensor y);
|
Tensor flatten(Tensor s, Tensor y);
|
||||||
Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
|
Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
|
||||||
Tensor concat(TensorVec inputs, Tensor output, int dim);
|
Tensor concat(TensorVec inputs, Tensor output, int dim);
|
||||||
Tensor gather(Tensor input, Tensor indices, Tensor output, int axis);
|
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
|
||||||
|
Tensor reduceMean(Tensor data, Tensor reduced,
|
||||||
|
const optional<vector<int>> &axes, bool keepdims);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -21,8 +21,7 @@ class ReduceMeanObj : public OperatorObj {
|
||||||
* @param keepDims Keep the reduced dimensions or not.
|
* @param keepDims Keep the reduced dimensions or not.
|
||||||
*/
|
*/
|
||||||
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
||||||
const optional<const vector<int>> &axes,
|
const optional<vector<int>> &axes, bool keepDims = true);
|
||||||
bool keepDims = true);
|
|
||||||
OP_CLONE(ReduceMeanObj);
|
OP_CLONE(ReduceMeanObj);
|
||||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
|
||||||
|
|
|
@ -135,6 +135,14 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "ReduceMean":
|
||||||
|
tensors[node.output[0]] = handler.reduceMean(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
tensors[node.input[1]] if len(node.input) > 1 else None,
|
||||||
|
next((attr.i for attr in node.attribute if attr.name == "keepdims"))
|
||||||
|
!= 0,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||||
|
|
||||||
|
|
|
@ -171,6 +171,14 @@ class TestStringMethods(unittest.TestCase):
|
||||||
)
|
)
|
||||||
make_and_import_model(make_graph([gather], "gather", [data, indices], [output]))
|
make_and_import_model(make_graph([gather], "gather", [data, indices], [output]))
|
||||||
|
|
||||||
|
def test_reduce_mean(self):
|
||||||
|
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
|
||||||
|
reduced = make_tensor_value_info("reduced", TensorProto.FLOAT, [1, 1, 1, 1])
|
||||||
|
reduceMean = make_node(
|
||||||
|
"ReduceMean", ["data"], ["reduced"], keepdims=1, name="reduceMean"
|
||||||
|
)
|
||||||
|
make_and_import_model(make_graph([reduceMean], "reduceMean", [data], [reduced]))
|
||||||
|
|
||||||
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
||||||
def test_linear(self):
|
def test_linear(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
|
|
|
@ -4,9 +4,9 @@
|
||||||
#include "operators/element_wise.h"
|
#include "operators/element_wise.h"
|
||||||
#include "operators/gather.h"
|
#include "operators/gather.h"
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
|
#include "operators/reduce_mean.h"
|
||||||
#include "operators/reshape.h"
|
#include "operators/reshape.h"
|
||||||
#include "operators/unary.h"
|
#include "operators/unary.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
static DataType dtype_repr_convert(int);
|
static DataType dtype_repr_convert(int);
|
||||||
|
@ -104,20 +104,33 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor GraphHandlerObj::gather(Tensor input, Tensor indices, Tensor output,
|
Tensor GraphHandlerObj::gather(Tensor data, Tensor indices, Tensor output,
|
||||||
int axis) {
|
int axis) {
|
||||||
if (output) {
|
if (output) {
|
||||||
g->addOpWithOutputs<GatherObj>(std::move(input), std::move(indices),
|
g->addOpWithOutputs<GatherObj>(std::move(data), std::move(indices),
|
||||||
output, axis);
|
output, axis);
|
||||||
return output;
|
return output;
|
||||||
} else {
|
} else {
|
||||||
return g
|
return g
|
||||||
->addOp<GatherObj>(std::move(input), std::move(indices), output,
|
->addOp<GatherObj>(std::move(data), std::move(indices), output,
|
||||||
axis)
|
axis)
|
||||||
->getOutput();
|
->getOutput();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tensor GraphHandlerObj::reduceMean(Tensor data, Tensor reduced,
|
||||||
|
const optional<vector<int>> &axes,
|
||||||
|
bool keepdims) {
|
||||||
|
if (reduced) {
|
||||||
|
g->addOpWithOutputs<ReduceMeanObj>(std::move(data), reduced, axes,
|
||||||
|
keepdims);
|
||||||
|
return reduced;
|
||||||
|
} else {
|
||||||
|
return g->addOp<ReduceMeanObj>(std::move(data), reduced, axes, keepdims)
|
||||||
|
->getOutput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static DataType dtype_repr_convert(int dtype) {
|
static DataType dtype_repr_convert(int dtype) {
|
||||||
switch ((OnnxDType)dtype) {
|
switch ((OnnxDType)dtype) {
|
||||||
case OnnxDType::FLOAT:
|
case OnnxDType::FLOAT:
|
||||||
|
|
|
@ -82,6 +82,10 @@ void init_graph_builder(py::module &m) {
|
||||||
policy::move)
|
policy::move)
|
||||||
.def("gather",
|
.def("gather",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, int>(&Handler::gather),
|
py::overload_cast<Tensor, Tensor, Tensor, int>(&Handler::gather),
|
||||||
|
policy::move)
|
||||||
|
.def("reduceMean",
|
||||||
|
py::overload_cast<Tensor, Tensor, const optional<vector<int>> &,
|
||||||
|
bool>(&Handler::reduceMean),
|
||||||
policy::move);
|
policy::move);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,7 @@
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
||||||
const optional<const vector<int>> &_axes,
|
const optional<vector<int>> &_axes, bool keepDims)
|
||||||
bool keepDims)
|
|
||||||
: OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {
|
: OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {
|
||||||
const auto size = input->getDims().size();
|
const auto size = input->getDims().size();
|
||||||
if (_axes) {
|
if (_axes) {
|
||||||
|
|
Loading…
Reference in New Issue