forked from jiuyuan/InfiniTensor
feat: 前端支持 concat 及单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
a7e58bd8d0
commit
45aa0237da
|
@ -58,6 +58,7 @@ class GraphHandlerObj {
|
||||||
Tensor identity(Tensor x, Tensor y);
|
Tensor identity(Tensor x, Tensor y);
|
||||||
Tensor flatten(Tensor s, Tensor y);
|
Tensor flatten(Tensor s, Tensor y);
|
||||||
Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
|
Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
|
||||||
|
Tensor concat(TensorVec inputs, Tensor output, int dim);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import onnx, backend
|
import onnx, backend
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
runtime = backend.cpu_runtime()
|
runtime = backend.cpu_runtime()
|
||||||
|
|
||||||
|
@ -6,8 +7,8 @@ runtime = backend.cpu_runtime()
|
||||||
def from_onnx(model: onnx.ModelProto):
|
def from_onnx(model: onnx.ModelProto):
|
||||||
handler = backend.GraphHandlerObj(runtime)
|
handler = backend.GraphHandlerObj(runtime)
|
||||||
|
|
||||||
tensors = dict()
|
tensors: Dict[str, backend.TensorObj] = dict()
|
||||||
data = dict()
|
data: Dict[str, onnx.TensorProto] = dict()
|
||||||
|
|
||||||
for input in model.graph.input:
|
for input in model.graph.input:
|
||||||
dims = [d.dim_value for d in input.type.tensor_type.shape.dim]
|
dims = [d.dim_value for d in input.type.tensor_type.shape.dim]
|
||||||
|
@ -121,6 +122,12 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
[int(i) for i in data[node.input[1]].int64_data],
|
[int(i) for i in data[node.input[1]].int64_data],
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "Concat":
|
||||||
|
tensors[node.output[0]] = handler.concat(
|
||||||
|
[tensors[name] for name in node.input],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||||
|
|
||||||
|
|
|
@ -136,11 +136,11 @@ class TestStringMethods(unittest.TestCase):
|
||||||
make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
|
make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
|
||||||
|
|
||||||
def test_reshape(self):
|
def test_reshape(self):
|
||||||
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
|
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
|
||||||
# shape 对于后端来说并不是一个张量,然而转换中可能没有办法分辨
|
# shape 对于后端来说并不是一个张量,然而转换中可能没有办法分辨
|
||||||
# 不知道怎么把 ValueInfoProto 转换成 TensorProto
|
# 不知道怎么把 ValueInfoProto 转换成 TensorProto
|
||||||
shape = make_tensor_value_info("shape", TensorProto.INT64, [4])
|
shape = make_tensor_value_info("shape", TensorProto.INT64, [3])
|
||||||
shape_data = make_tensor("shape", TensorProto.INT64, [4], [3, 2, 4, 3])
|
shape_data = make_tensor("shape", TensorProto.INT64, [3], [5, 3, 8])
|
||||||
reshaped = make_tensor_value_info(
|
reshaped = make_tensor_value_info(
|
||||||
"reshaped", TensorProto.FLOAT, shape_data.int64_data
|
"reshaped", TensorProto.FLOAT, shape_data.int64_data
|
||||||
)
|
)
|
||||||
|
@ -151,6 +151,17 @@ class TestStringMethods(unittest.TestCase):
|
||||||
make_graph([reshape], "reshape", [data, shape], [reshaped], [shape_data])
|
make_graph([reshape], "reshape", [data, shape], [reshaped], [shape_data])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_concat(self):
|
||||||
|
input1 = make_tensor_value_info("input1", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
|
input2 = make_tensor_value_info("input2", TensorProto.FLOAT, [1, 3, 2, 5])
|
||||||
|
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 9])
|
||||||
|
concat = make_node(
|
||||||
|
"Concat", ["input1", "input2"], ["output"], axis=3, name="concat"
|
||||||
|
)
|
||||||
|
make_and_import_model(
|
||||||
|
make_graph([concat], "concat", [input1, input2], [output])
|
||||||
|
)
|
||||||
|
|
||||||
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
||||||
def test_linear(self):
|
def test_linear(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#include "core/graph_handler.h"
|
#include "core/graph_handler.h"
|
||||||
#include "operators/batch_norm.h"
|
#include "operators/batch_norm.h"
|
||||||
|
#include "operators/concat.h"
|
||||||
#include "operators/element_wise.h"
|
#include "operators/element_wise.h"
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
#include "operators/reshape.h"
|
#include "operators/reshape.h"
|
||||||
|
@ -93,6 +94,15 @@ Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
|
||||||
|
if (output) {
|
||||||
|
g->addOpWithOutputs<ConcatObj>(std::move(inputs), output, dim);
|
||||||
|
return output;
|
||||||
|
} else {
|
||||||
|
return g->addOp<ConcatObj>(std::move(inputs), output, dim)->getOutput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static DataType dtype_repr_convert(int dtype) {
|
static DataType dtype_repr_convert(int dtype) {
|
||||||
switch ((OnnxDType)dtype) {
|
switch ((OnnxDType)dtype) {
|
||||||
case OnnxDType::FLOAT:
|
case OnnxDType::FLOAT:
|
||||||
|
|
|
@ -76,6 +76,9 @@ void init_graph_builder(py::module &m) {
|
||||||
policy::move)
|
policy::move)
|
||||||
.def("reshape",
|
.def("reshape",
|
||||||
py::overload_cast<Tensor, Tensor, Shape>(&Handler::reshape),
|
py::overload_cast<Tensor, Tensor, Shape>(&Handler::reshape),
|
||||||
|
policy::move)
|
||||||
|
.def("concat",
|
||||||
|
py::overload_cast<TensorVec, Tensor, int>(&Handler::concat),
|
||||||
policy::move);
|
policy::move);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue