forked from jiuyuan/InfiniTensor
feat: 增加 add sub mul div pow 前端
- 添加每个算子的单元测试 - 添加线性回归模型导入测试 Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
296fcc5aa0
commit
6e5beceadd
|
@ -39,6 +39,12 @@ class GraphHandlerObj {
|
||||||
|
|
||||||
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
|
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
|
||||||
Tensor bias, ActType act);
|
Tensor bias, ActType act);
|
||||||
|
|
||||||
|
Tensor add(Tensor a, Tensor b, Tensor c);
|
||||||
|
Tensor sub(Tensor a, Tensor b, Tensor c);
|
||||||
|
Tensor mul(Tensor a, Tensor b, Tensor c);
|
||||||
|
Tensor div(Tensor a, Tensor b, Tensor c);
|
||||||
|
Tensor pow(Tensor a, Tensor b, Tensor c);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -19,15 +19,45 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
|
|
||||||
for node in model.graph.node:
|
for node in model.graph.node:
|
||||||
if node.op_type == "MatMul":
|
if node.op_type == "MatMul":
|
||||||
handler.matmul(
|
tensors[node.output[0]] = handler.matmul(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors[node.input[1]],
|
tensors[node.input[1]],
|
||||||
tensors[node.output[0]],
|
tensors.get(node.output[0], None),
|
||||||
False,
|
False,
|
||||||
False,
|
False,
|
||||||
None,
|
None,
|
||||||
backend.ActType.Linear,
|
backend.ActType.Linear,
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "Add":
|
||||||
|
tensors[node.output[0]] = handler.add(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors[node.input[1]],
|
||||||
|
tensors.get(node.output[0], None),
|
||||||
|
)
|
||||||
|
elif node.op_type == "Sub":
|
||||||
|
tensors[node.output[0]] = handler.sub(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors[node.input[1]],
|
||||||
|
tensors.get(node.output[0], None),
|
||||||
|
)
|
||||||
|
elif node.op_type == "Mul":
|
||||||
|
tensors[node.output[0]] = handler.mul(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors[node.input[1]],
|
||||||
|
tensors.get(node.output[0], None),
|
||||||
|
)
|
||||||
|
elif node.op_type == "Div":
|
||||||
|
tensors[node.output[0]] = handler.div(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors[node.input[1]],
|
||||||
|
tensors.get(node.output[0], None),
|
||||||
|
)
|
||||||
|
elif node.op_type == "Pow":
|
||||||
|
tensors[node.output[0]] = handler.pow(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors[node.input[1]],
|
||||||
|
tensors.get(node.output[0], None),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_onnx(model: onnx.ModelProto):
|
def parse_onnx(model: onnx.ModelProto):
|
||||||
|
|
|
@ -18,12 +18,82 @@ class TestStringMethods(unittest.TestCase):
|
||||||
)
|
)
|
||||||
parse_onnx(onnx.load(model_file))
|
parse_onnx(onnx.load(model_file))
|
||||||
|
|
||||||
def test_import(self):
|
def test_tensor(self):
|
||||||
i = make_tensor_value_info("i", TensorProto.FLOAT, [1, 2, 3])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
w = make_tensor_value_info("w", TensorProto.FLOAT, [1, 3, 4])
|
graph = make_graph([], "tensor", [x], [x])
|
||||||
o = make_tensor_value_info("o", TensorProto.FLOAT, [1, 2, 4])
|
model = make_model(graph)
|
||||||
matmul = make_node("MatMul", ["i", "w"], ["o"], name="matmul")
|
check_model(model)
|
||||||
graph = make_graph([matmul], "mm", [i, w], [o])
|
from_onnx(model)
|
||||||
|
|
||||||
|
def test_matmul(self):
|
||||||
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
|
||||||
|
xa = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
|
||||||
|
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
|
||||||
|
graph = make_graph([matmul], "matmul", [x, a], [xa])
|
||||||
|
model = make_model(graph)
|
||||||
|
check_model(model)
|
||||||
|
from_onnx(model)
|
||||||
|
|
||||||
|
def test_add(self):
|
||||||
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
add = make_node("Add", ["a", "b"], ["c"], name="add")
|
||||||
|
graph = make_graph([add], "add", [a, b], [c])
|
||||||
|
model = make_model(graph)
|
||||||
|
check_model(model)
|
||||||
|
from_onnx(model)
|
||||||
|
|
||||||
|
def test_sub(self):
|
||||||
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
sub = make_node("Sub", ["a", "b"], ["c"], name="sub")
|
||||||
|
graph = make_graph([sub], "sub", [a, b], [c])
|
||||||
|
model = make_model(graph)
|
||||||
|
check_model(model)
|
||||||
|
from_onnx(model)
|
||||||
|
|
||||||
|
def test_mul(self):
|
||||||
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
mul = make_node("Mul", ["a", "b"], ["c"], name="mul")
|
||||||
|
graph = make_graph([mul], "mul", [a, b], [c])
|
||||||
|
model = make_model(graph)
|
||||||
|
check_model(model)
|
||||||
|
from_onnx(model)
|
||||||
|
|
||||||
|
def test_div(self):
|
||||||
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
div = make_node("Div", ["a", "b"], ["c"], name="div")
|
||||||
|
graph = make_graph([div], "div", [a, b], [c])
|
||||||
|
model = make_model(graph)
|
||||||
|
check_model(model)
|
||||||
|
from_onnx(model)
|
||||||
|
|
||||||
|
def test_pow(self):
|
||||||
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
c = make_tensor_value_info("c", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
pow = make_node("Pow", ["a", "b"], ["c"], name="pow")
|
||||||
|
graph = make_graph([pow], "pow", [a, b], [c])
|
||||||
|
model = make_model(graph)
|
||||||
|
check_model(model)
|
||||||
|
from_onnx(model)
|
||||||
|
|
||||||
|
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
||||||
|
def test_linear(self):
|
||||||
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
|
||||||
|
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
|
||||||
|
y = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
|
||||||
|
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
|
||||||
|
add = make_node("Add", ["xa", "b"], ["y"], name="add")
|
||||||
|
graph = make_graph([matmul, add], "lr", [x, a, b], [y])
|
||||||
model = make_model(graph)
|
model = make_model(graph)
|
||||||
check_model(model)
|
check_model(model)
|
||||||
print(model)
|
print(model)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
#include "core/graph_handler.h"
|
#include "core/graph_handler.h"
|
||||||
|
#include "operators/element_wise.h"
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
@ -23,6 +24,22 @@ Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define DEFINE_ELEMENT_WISE_METHOD(name, obj) \
|
||||||
|
Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) { \
|
||||||
|
if (c) { \
|
||||||
|
g->addOpWithOutputs<obj##Obj>(a, b, c); \
|
||||||
|
return c; \
|
||||||
|
} else { \
|
||||||
|
return g->addOp<obj##Obj>(a, b, c)->getOutput(); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFINE_ELEMENT_WISE_METHOD(add, Add)
|
||||||
|
DEFINE_ELEMENT_WISE_METHOD(sub, Sub)
|
||||||
|
DEFINE_ELEMENT_WISE_METHOD(mul, Mul)
|
||||||
|
DEFINE_ELEMENT_WISE_METHOD(div, Div)
|
||||||
|
DEFINE_ELEMENT_WISE_METHOD(pow, Pow)
|
||||||
|
|
||||||
static DataType dtype_repr_convert(int dtype) {
|
static DataType dtype_repr_convert(int dtype) {
|
||||||
switch ((OnnxDType)dtype) {
|
switch ((OnnxDType)dtype) {
|
||||||
case OnnxDType::FLOAT:
|
case OnnxDType::FLOAT:
|
||||||
|
|
|
@ -43,6 +43,21 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("matmul",
|
.def("matmul",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
||||||
ActType>(&GraphHandlerObj::matmul),
|
ActType>(&GraphHandlerObj::matmul),
|
||||||
|
policy::reference_internal)
|
||||||
|
.def("add",
|
||||||
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::add),
|
||||||
|
policy::reference_internal)
|
||||||
|
.def("sub",
|
||||||
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::sub),
|
||||||
|
policy::reference_internal)
|
||||||
|
.def("mul",
|
||||||
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::mul),
|
||||||
|
policy::reference_internal)
|
||||||
|
.def("div",
|
||||||
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::div),
|
||||||
|
policy::reference_internal)
|
||||||
|
.def("pow",
|
||||||
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::pow),
|
||||||
policy::reference_internal);
|
policy::reference_internal);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue