forked from jiuyuan/InfiniTensor
feat: 前端支持 flatten 及单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
e4ec9c4230
commit
e194dd943b
|
@ -52,6 +52,7 @@ class GraphHandlerObj {
|
||||||
Tensor softmax(Tensor x, Tensor y);
|
Tensor softmax(Tensor x, Tensor y);
|
||||||
Tensor abs(Tensor x, Tensor y);
|
Tensor abs(Tensor x, Tensor y);
|
||||||
Tensor identity(Tensor x, Tensor y);
|
Tensor identity(Tensor x, Tensor y);
|
||||||
|
Tensor flatten(Tensor s, Tensor y);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -88,6 +88,16 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0], None),
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "Flatten":
|
||||||
|
# TODO 后端算子不支持沿任意轴展开
|
||||||
|
axis = next(
|
||||||
|
(attr.i for attr in node.attribute if attr.name == "axis"), None
|
||||||
|
)
|
||||||
|
assert axis == None or axis == 1
|
||||||
|
tensors[node.output[0]] = handler.flatten(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0], None),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,12 @@ from onnx.checker import check_model
|
||||||
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
|
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
|
||||||
|
|
||||||
|
|
||||||
|
def make_and_import_model(graph: onnx.GraphProto):
|
||||||
|
model = make_model(graph)
|
||||||
|
check_model(model)
|
||||||
|
from_onnx(model)
|
||||||
|
|
||||||
|
|
||||||
class TestStringMethods(unittest.TestCase):
|
class TestStringMethods(unittest.TestCase):
|
||||||
def test_load(self):
|
def test_load(self):
|
||||||
model_file = next(
|
model_file = next(
|
||||||
|
@ -20,115 +26,91 @@ class TestStringMethods(unittest.TestCase):
|
||||||
|
|
||||||
def test_tensor(self):
|
def test_tensor(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
graph = make_graph([], "tensor", [x], [x])
|
make_and_import_model(make_graph([], "tensor", [x], [x]))
|
||||||
model = make_model(graph)
|
|
||||||
check_model(model)
|
|
||||||
from_onnx(model)
|
|
||||||
|
|
||||||
def test_matmul(self):
|
def test_matmul(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
|
||||||
xa = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
|
xa = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
|
||||||
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
|
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
|
||||||
graph = make_graph([matmul], "matmul", [x, a], [xa])
|
make_and_import_model(make_graph([matmul], "matmul", [x, a], [xa]))
|
||||||
model = make_model(graph)
|
|
||||||
check_model(model)
|
|
||||||
from_onnx(model)
|
|
||||||
|
|
||||||
def test_add(self):
|
def test_add(self):
|
||||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
add = make_node("Add", ["a", "b"], ["c"], name="add")
|
add = make_node("Add", ["a", "b"], ["c"], name="add")
|
||||||
graph = make_graph([add], "add", [a, b], [c])
|
make_and_import_model(make_graph([add], "add", [a, b], [c]))
|
||||||
model = make_model(graph)
|
|
||||||
check_model(model)
|
|
||||||
from_onnx(model)
|
|
||||||
|
|
||||||
def test_sub(self):
|
def test_sub(self):
|
||||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
sub = make_node("Sub", ["a", "b"], ["c"], name="sub")
|
sub = make_node("Sub", ["a", "b"], ["c"], name="sub")
|
||||||
graph = make_graph([sub], "sub", [a, b], [c])
|
make_and_import_model(make_graph([sub], "sub", [a, b], [c]))
|
||||||
model = make_model(graph)
|
|
||||||
check_model(model)
|
|
||||||
from_onnx(model)
|
|
||||||
|
|
||||||
def test_mul(self):
|
def test_mul(self):
|
||||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
mul = make_node("Mul", ["a", "b"], ["c"], name="mul")
|
mul = make_node("Mul", ["a", "b"], ["c"], name="mul")
|
||||||
graph = make_graph([mul], "mul", [a, b], [c])
|
make_and_import_model(make_graph([mul], "mul", [a, b], [c]))
|
||||||
model = make_model(graph)
|
|
||||||
check_model(model)
|
|
||||||
from_onnx(model)
|
|
||||||
|
|
||||||
def test_div(self):
|
def test_div(self):
|
||||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
div = make_node("Div", ["a", "b"], ["c"], name="div")
|
div = make_node("Div", ["a", "b"], ["c"], name="div")
|
||||||
graph = make_graph([div], "div", [a, b], [c])
|
make_and_import_model(make_graph([div], "div", [a, b], [c]))
|
||||||
model = make_model(graph)
|
|
||||||
check_model(model)
|
|
||||||
from_onnx(model)
|
|
||||||
|
|
||||||
def test_pow(self):
|
def test_pow(self):
|
||||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
pow = make_node("Pow", ["a", "b"], ["c"], name="pow")
|
pow = make_node("Pow", ["a", "b"], ["c"], name="pow")
|
||||||
graph = make_graph([pow], "pow", [a, b], [c])
|
make_and_import_model(make_graph([pow], "pow", [a, b], [c]))
|
||||||
model = make_model(graph)
|
|
||||||
check_model(model)
|
|
||||||
from_onnx(model)
|
|
||||||
|
|
||||||
def test_relu(self):
|
def test_relu(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
relu = make_node("Relu", ["x"], ["y"], name="relu")
|
relu = make_node("Relu", ["x"], ["y"], name="relu")
|
||||||
graph = make_graph([relu], "relu", [x], [y])
|
make_and_import_model(make_graph([relu], "relu", [x], [y]))
|
||||||
model = make_model(graph)
|
|
||||||
check_model(model)
|
|
||||||
from_onnx(model)
|
|
||||||
|
|
||||||
def test_sigmoid(self):
|
def test_sigmoid(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
sigmoid = make_node("Sigmoid", ["x"], ["y"], name="sigmoid")
|
sigmoid = make_node("Sigmoid", ["x"], ["y"], name="sigmoid")
|
||||||
graph = make_graph([sigmoid], "sigmoid", [x], [y])
|
make_and_import_model(make_graph([sigmoid], "sigmoid", [x], [y]))
|
||||||
model = make_model(graph)
|
|
||||||
check_model(model)
|
|
||||||
from_onnx(model)
|
|
||||||
|
|
||||||
def test_tanh(self):
|
def test_tanh(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
tanh = make_node("Tanh", ["x"], ["y"], name="tanh")
|
tanh = make_node("Tanh", ["x"], ["y"], name="tanh")
|
||||||
graph = make_graph([tanh], "tanh", [x], [y])
|
make_and_import_model(make_graph([tanh], "tanh", [x], [y]))
|
||||||
model = make_model(graph)
|
|
||||||
check_model(model)
|
|
||||||
from_onnx(model)
|
|
||||||
|
|
||||||
def test_softmax(self):
|
def test_softmax(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
softmax = make_node("Softmax", ["x"], ["y"], name="softmax")
|
softmax = make_node("Softmax", ["x"], ["y"], name="softmax")
|
||||||
graph = make_graph([softmax], "softmax", [x], [y])
|
make_and_import_model(make_graph([softmax], "softmax", [x], [y]))
|
||||||
model = make_model(graph)
|
|
||||||
check_model(model)
|
|
||||||
from_onnx(model)
|
|
||||||
|
|
||||||
def test_abs(self):
|
def test_abs(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
abs = make_node("Abs", ["x"], ["y"], name="abs")
|
abs = make_node("Abs", ["x"], ["y"], name="abs")
|
||||||
graph = make_graph([abs], "abs", [x], [y])
|
make_and_import_model(make_graph([abs], "abs", [x], [y]))
|
||||||
model = make_model(graph)
|
|
||||||
check_model(model)
|
def test_identity(self):
|
||||||
from_onnx(model)
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
identity = make_node("Identity", ["x"], ["y"], name="identity")
|
||||||
|
make_and_import_model(make_graph([identity], "identity", [x], [y]))
|
||||||
|
|
||||||
|
def test_flatten(self):
|
||||||
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1 * 3 * 5 * 7])
|
||||||
|
flatten = make_node("Flatten", ["x"], ["y"], name="flatten")
|
||||||
|
make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
|
||||||
|
|
||||||
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
||||||
def test_linear(self):
|
def test_linear(self):
|
||||||
|
@ -141,8 +123,6 @@ class TestStringMethods(unittest.TestCase):
|
||||||
graph = make_graph([matmul, add], "lr", [x, a, b], [y])
|
graph = make_graph([matmul, add], "lr", [x, a, b], [y])
|
||||||
model = make_model(graph)
|
model = make_model(graph)
|
||||||
check_model(model)
|
check_model(model)
|
||||||
print(model)
|
|
||||||
|
|
||||||
from_onnx(model)
|
from_onnx(model)
|
||||||
parse_onnx(model)
|
parse_onnx(model)
|
||||||
|
|
||||||
|
|
|
@ -61,6 +61,7 @@ DEFINE_UNARY_METHOD(softmax, Softmax)
|
||||||
DEFINE_UNARY_METHOD(abs, Abs)
|
DEFINE_UNARY_METHOD(abs, Abs)
|
||||||
// see operators/reshape.h
|
// see operators/reshape.h
|
||||||
DEFINE_UNARY_METHOD(identity, Identity)
|
DEFINE_UNARY_METHOD(identity, Identity)
|
||||||
|
DEFINE_UNARY_METHOD(flatten, Flatten)
|
||||||
|
|
||||||
static DataType dtype_repr_convert(int dtype) {
|
static DataType dtype_repr_convert(int dtype) {
|
||||||
switch ((OnnxDType)dtype) {
|
switch ((OnnxDType)dtype) {
|
||||||
|
|
|
@ -67,6 +67,8 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("abs", py::overload_cast<Tensor, Tensor>(&Handler::abs),
|
.def("abs", py::overload_cast<Tensor, Tensor>(&Handler::abs),
|
||||||
policy::move)
|
policy::move)
|
||||||
.def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
|
.def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
|
||||||
|
policy::move)
|
||||||
|
.def("flatten", py::overload_cast<Tensor, Tensor>(&Handler::flatten),
|
||||||
policy::move);
|
policy::move);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue