feat: frontend support for identity, with unit tests

Signed-off-by: YdrMaster <ydrml@hotmail.com>
Author: YdrMaster, 2023-02-13 12:13:01 +08:00
Parent: 7f0c8ebae3
Commit: e4ec9c4230
5 changed files with 56 additions and 23 deletions

@@ -51,6 +51,7 @@ class GraphHandlerObj {
Tensor tanh(Tensor x, Tensor y);
Tensor softmax(Tensor x, Tensor y);
Tensor abs(Tensor x, Tensor y);
+ Tensor identity(Tensor x, Tensor y);
};
} // namespace infini

@@ -58,6 +58,38 @@ def from_onnx(model: onnx.ModelProto):
tensors[node.input[1]],
tensors.get(node.output[0], None),
)
elif node.op_type == "Relu":
tensors[node.output[0]] = handler.relu(
tensors[node.input[0]],
tensors.get(node.output[0], None),
)
elif node.op_type == "Sigmoid":
tensors[node.output[0]] = handler.sigmoid(
tensors[node.input[0]],
tensors.get(node.output[0], None),
)
elif node.op_type == "Tanh":
tensors[node.output[0]] = handler.tanh(
tensors[node.input[0]],
tensors.get(node.output[0], None),
)
elif node.op_type == "Softmax":
tensors[node.output[0]] = handler.softmax(
tensors[node.input[0]],
tensors.get(node.output[0], None),
)
elif node.op_type == "Abs":
tensors[node.output[0]] = handler.abs(
tensors[node.input[0]],
tensors.get(node.output[0], None),
)
elif node.op_type == "Identity":
tensors[node.output[0]] = handler.identity(
tensors[node.input[0]],
tensors.get(node.output[0], None),
)
else:
raise Exception('Unsupported operator "{}"'.format(node.op_type))
def parse_onnx(model: onnx.ModelProto):
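
Note: the importer above threads tensors through a dict keyed by ONNX value names, so a node's output is stored under its name and looked up again when a later node consumes it. Below is a hedged sketch of a model that exercises the new Identity branch this way; the import path pyinfinitensor.onnx is an assumption taken from this diff, not a verified package name.

# Hedged sketch (not part of this commit): Identity feeding Relu, so the
# tensor returned by handler.identity is reused as Relu's input via the dict.
import onnx
from onnx import TensorProto, checker
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
from pyinfinitensor.onnx import from_onnx  # import path assumed

x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 2, 3])
identity = make_node("Identity", ["x"], ["t"], name="identity")
relu = make_node("Relu", ["t"], ["y"], name="relu")
model = make_model(make_graph([identity, relu], "id_relu", [x], [y]))
checker.check_model(model)
from_onnx(model)  # takes the "Identity" branch above, then the "Relu" branch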

@@ -135,7 +135,7 @@ class TestStringMethods(unittest.TestCase):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
y = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 2, 4])
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
add = make_node("Add", ["xa", "b"], ["y"], name="add")
graph = make_graph([matmul, add], "lr", [x, a, b], [y])
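
Note: the hunk above only fixes a copy-paste slip in the existing test (the y value_info was created under the name "b"); the identity test the commit title mentions is not visible in this view. A hedged sketch of what such a case could look like in the same onnx.helper style; whether the real test calls from_onnx, parse_onnx, or another helper is an assumption.

# Hedged sketch of an identity unit test, mirroring the style of the file above.
import unittest
from onnx import TensorProto, checker
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
from pyinfinitensor.onnx import from_onnx  # entry point and import path assumed

class TestIdentity(unittest.TestCase):
    def test_identity(self):
        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 2, 3])
        identity = make_node("Identity", ["x"], ["y"], name="identity")
        model = make_model(make_graph([identity], "identity", [x], [y]))
        checker.check_model(model)
        from_onnx(model)

if __name__ == "__main__":
    unittest.main()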

@@ -1,6 +1,7 @@
#include "core/graph_handler.h"
#include "operators/element_wise.h"
#include "operators/matmul.h"
#include "operators/reshape.h"
#include "operators/unary.h"
namespace infini {
@@ -58,6 +59,8 @@ DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
DEFINE_UNARY_METHOD(tanh, Tanh)
DEFINE_UNARY_METHOD(softmax, Softmax)
DEFINE_UNARY_METHOD(abs, Abs)
+ // see operators/reshape.h
+ DEFINE_UNARY_METHOD(identity, Identity)
static DataType dtype_repr_convert(int dtype) {
switch ((OnnxDType)dtype) {

@@ -22,6 +22,8 @@ void register_operator_timer(py::module &m) {
}
void init_graph_builder(py::module &m) {
+ using Handler = GraphHandlerObj;
m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
@@ -36,40 +36,35 @@ void init_graph_builder(py::module &m) {
.value("Tanh", ActType::Tanh)
.export_values();
py::class_<GraphHandler>(m, "GraphHandler");
- py::class_<GraphHandlerObj>(m, "GraphHandlerObj")
+ py::class_<Handler>(m, "GraphHandlerObj")
.def(py::init<Runtime>())
.def("tensor", py::overload_cast<Shape, int>(&GraphHandlerObj::tensor),
policy::reference_internal)
.def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
policy::move)
.def("matmul",
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
- ActType>(&GraphHandlerObj::matmul),
+ ActType>(&Handler::matmul),
policy::move)
.def("add",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::add),
.def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
policy::move)
.def("sub",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::sub),
.def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
policy::move)
.def("mul",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::mul),
.def("mul", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::mul),
policy::move)
.def("div",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::div),
.def("div", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::div),
policy::move)
.def("pow",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::pow),
.def("pow", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::pow),
policy::move)
.def("relu", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::relu),
.def("relu", py::overload_cast<Tensor, Tensor>(&Handler::relu),
policy::move)
.def("sigmoid",
py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::sigmoid),
.def("sigmoid", py::overload_cast<Tensor, Tensor>(&Handler::sigmoid),
policy::move)
.def("tanh", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::tanh),
policy::reference_internal)
.def("softmax",
py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::softmax),
.def("tanh", py::overload_cast<Tensor, Tensor>(&Handler::tanh),
policy::move)
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::abs),
.def("softmax", py::overload_cast<Tensor, Tensor>(&Handler::softmax),
policy::move)
.def("abs", py::overload_cast<Tensor, Tensor>(&Handler::abs),
policy::move)
.def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
policy::move);
}
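
Note: taken together, these bindings expose the runtime, the handler, and the new identity method to Python directly, which is what onnx.py builds on. A hedged sketch of driving them without going through ONNX; the compiled module name backend and the dtype code 1 (ONNX FLOAT) are assumptions for illustration, not verified names.

# Hedged sketch (not part of this commit) of calling the bindings directly.
from pyinfinitensor import backend  # compiled module name assumed

runtime = backend.cpu_runtime()             # CpuRuntimeObj::getInstance()
handler = backend.GraphHandlerObj(runtime)  # bound via py::init<Runtime>()
x = handler.tensor([1, 2, 3], 1)            # Shape plus an ONNX dtype code; 1 == FLOAT assumed
y = handler.identity(x, None)               # None lets the handler create the output
z = handler.tensor([1, 2, 3], 1)
handler.identity(x, z)                      # or pass a pre-made output, as from_onnx may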