feat: 前端支持 gather 及单元测试

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-14 14:16:01 +08:00
parent 45aa0237da
commit d11fb0ad5f
7 changed files with 40 additions and 5 deletions

View File

@ -59,6 +59,7 @@ class GraphHandlerObj {
Tensor flatten(Tensor s, Tensor y); Tensor flatten(Tensor s, Tensor y);
Tensor reshape(Tensor data, Tensor reshaped, Shape shape); Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
Tensor concat(TensorVec inputs, Tensor output, int dim); Tensor concat(TensorVec inputs, Tensor output, int dim);
Tensor gather(Tensor input, Tensor indices, Tensor output, int axis);
}; };
} // namespace infini } // namespace infini

View File

@ -17,11 +17,11 @@ class GatherObj : public OperatorObj {
* *
* @param graph The computation graph that this operator belongs to. * @param graph The computation graph that this operator belongs to.
* @param input The input tensor. * @param input The input tensor.
* @param index The index tensor. * @param indices The index tensor.
* @param output The output tensor. * @param output The output tensor.
* @param axis The axis to gather on. * @param axis The axis to gather on.
*/ */
GatherObj(GraphObj *graph, Tensor input, Tensor index, Tensor output, GatherObj(GraphObj *graph, Tensor input, Tensor indices, Tensor output,
int axis); int axis);
OP_CLONE(GatherObj); OP_CLONE(GatherObj);
std::string toString() const override; std::string toString() const override;

View File

@ -128,6 +128,13 @@ def from_onnx(model: onnx.ModelProto):
tensors.get(node.output[0]), tensors.get(node.output[0]),
next((attr.i for attr in node.attribute if attr.name == "axis")), next((attr.i for attr in node.attribute if attr.name == "axis")),
) )
elif node.op_type == "Gather":
tensors[node.output[0]] = handler.gather(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
next((attr.i for attr in node.attribute if attr.name == "axis")),
)
else: else:
raise Exception('Unsupported operator "{}"'.format(node.op_type)) raise Exception('Unsupported operator "{}"'.format(node.op_type))

View File

@ -162,6 +162,15 @@ class TestStringMethods(unittest.TestCase):
make_graph([concat], "concat", [input1, input2], [output]) make_graph([concat], "concat", [input1, input2], [output])
) )
def test_gather(self):
data = make_tensor_value_info("data", TensorProto.FLOAT, [1, 3, 4, 4])
indices = make_tensor_value_info("indices", TensorProto.FLOAT, [2, 1, 2])
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 2, 1, 2, 4, 4])
gather = make_node(
"Gather", ["data", "indices"], ["output"], axis=1, name="gather"
)
make_and_import_model(make_graph([gather], "gather", [data, indices], [output]))
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression> # see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
def test_linear(self): def test_linear(self):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])

View File

@ -2,6 +2,7 @@
#include "operators/batch_norm.h" #include "operators/batch_norm.h"
#include "operators/concat.h" #include "operators/concat.h"
#include "operators/element_wise.h" #include "operators/element_wise.h"
#include "operators/gather.h"
#include "operators/matmul.h" #include "operators/matmul.h"
#include "operators/reshape.h" #include "operators/reshape.h"
#include "operators/unary.h" #include "operators/unary.h"
@ -103,6 +104,20 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
} }
} }
Tensor GraphHandlerObj::gather(Tensor input, Tensor indices, Tensor output,
int axis) {
if (output) {
g->addOpWithOutputs<GatherObj>(std::move(input), std::move(indices),
output, axis);
return output;
} else {
return g
->addOp<GatherObj>(std::move(input), std::move(indices), output,
axis)
->getOutput();
}
}
static DataType dtype_repr_convert(int dtype) { static DataType dtype_repr_convert(int dtype) {
switch ((OnnxDType)dtype) { switch ((OnnxDType)dtype) {
case OnnxDType::FLOAT: case OnnxDType::FLOAT:

View File

@ -79,6 +79,9 @@ void init_graph_builder(py::module &m) {
policy::move) policy::move)
.def("concat", .def("concat",
py::overload_cast<TensorVec, Tensor, int>(&Handler::concat), py::overload_cast<TensorVec, Tensor, int>(&Handler::concat),
policy::move)
.def("gather",
py::overload_cast<Tensor, Tensor, Tensor, int>(&Handler::gather),
policy::move); policy::move);
} }

View File

@ -1,9 +1,9 @@
#include "operators/gather.h" #include "operators/gather.h"
namespace infini { namespace infini {
GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor index, Tensor output, GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor indices,
int axis) Tensor output, int axis)
: OperatorObj(OpType::Gather, {input, index}, {output}), axis(axis) { : OperatorObj(OpType::Gather, {input, indices}, {output}), axis(axis) {
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }