forked from jiuyuan/InfiniTensor
feat: 前端支持 identity 及单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
7f0c8ebae3
commit
e4ec9c4230
|
@ -51,6 +51,7 @@ class GraphHandlerObj {
|
|||
Tensor tanh(Tensor x, Tensor y);
|
||||
Tensor softmax(Tensor x, Tensor y);
|
||||
Tensor abs(Tensor x, Tensor y);
|
||||
Tensor identity(Tensor x, Tensor y);
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -58,6 +58,38 @@ def from_onnx(model: onnx.ModelProto):
|
|||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0], None),
|
||||
)
|
||||
elif node.op_type == "Relu":
|
||||
tensors[node.output[0]] = handler.relu(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0], None),
|
||||
)
|
||||
elif node.op_type == "Sigmoid":
|
||||
tensors[node.output[0]] = handler.sigmoid(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0], None),
|
||||
)
|
||||
elif node.op_type == "Tanh":
|
||||
tensors[node.output[0]] = handler.tanh(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0], None),
|
||||
)
|
||||
elif node.op_type == "Softmax":
|
||||
tensors[node.output[0]] = handler.softmax(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0], None),
|
||||
)
|
||||
elif node.op_type == "Abs":
|
||||
tensors[node.output[0]] = handler.abs(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0], None),
|
||||
)
|
||||
elif node.op_type == "Identity":
|
||||
tensors[node.output[0]] = handler.identity(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0], None),
|
||||
)
|
||||
else:
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
|
||||
|
||||
def parse_onnx(model: onnx.ModelProto):
|
||||
|
|
|
@ -135,7 +135,7 @@ class TestStringMethods(unittest.TestCase):
|
|||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
|
||||
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
|
||||
y = make_tensor_value_info("b", TensorProto.FLOAT, [1, 2, 4])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 2, 4])
|
||||
matmul = make_node("MatMul", ["x", "a"], ["xa"], name="matmul")
|
||||
add = make_node("Add", ["xa", "b"], ["y"], name="add")
|
||||
graph = make_graph([matmul, add], "lr", [x, a, b], [y])
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include "operators/element_wise.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
namespace infini {
|
||||
|
@ -58,6 +59,8 @@ DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
|
|||
DEFINE_UNARY_METHOD(tanh, Tanh)
|
||||
DEFINE_UNARY_METHOD(softmax, Softmax)
|
||||
DEFINE_UNARY_METHOD(abs, Abs)
|
||||
// see operators/reshape.h
|
||||
DEFINE_UNARY_METHOD(identity, Identity)
|
||||
|
||||
static DataType dtype_repr_convert(int dtype) {
|
||||
switch ((OnnxDType)dtype) {
|
||||
|
|
|
@ -22,6 +22,8 @@ void register_operator_timer(py::module &m) {
|
|||
}
|
||||
|
||||
void init_graph_builder(py::module &m) {
|
||||
using Handler = GraphHandlerObj;
|
||||
|
||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
|
||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
|
||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||
|
@ -36,40 +38,35 @@ void init_graph_builder(py::module &m) {
|
|||
.value("Tanh", ActType::Tanh)
|
||||
.export_values();
|
||||
py::class_<GraphHandler>(m, "GraphHandler");
|
||||
py::class_<GraphHandlerObj>(m, "GraphHandlerObj")
|
||||
py::class_<Handler>(m, "GraphHandlerObj")
|
||||
.def(py::init<Runtime>())
|
||||
.def("tensor", py::overload_cast<Shape, int>(&GraphHandlerObj::tensor),
|
||||
policy::reference_internal)
|
||||
.def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
|
||||
policy::move)
|
||||
.def("matmul",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
||||
ActType>(&GraphHandlerObj::matmul),
|
||||
ActType>(&Handler::matmul),
|
||||
policy::move)
|
||||
.def("add",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::add),
|
||||
.def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
|
||||
policy::move)
|
||||
.def("sub",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::sub),
|
||||
.def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
|
||||
policy::move)
|
||||
.def("mul",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::mul),
|
||||
.def("mul", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::mul),
|
||||
policy::move)
|
||||
.def("div",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::div),
|
||||
.def("div", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::div),
|
||||
policy::move)
|
||||
.def("pow",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphHandlerObj::pow),
|
||||
.def("pow", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::pow),
|
||||
policy::move)
|
||||
.def("relu", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::relu),
|
||||
.def("relu", py::overload_cast<Tensor, Tensor>(&Handler::relu),
|
||||
policy::move)
|
||||
.def("sigmoid",
|
||||
py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::sigmoid),
|
||||
.def("sigmoid", py::overload_cast<Tensor, Tensor>(&Handler::sigmoid),
|
||||
policy::move)
|
||||
.def("tanh", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::tanh),
|
||||
policy::reference_internal)
|
||||
.def("softmax",
|
||||
py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::softmax),
|
||||
.def("tanh", py::overload_cast<Tensor, Tensor>(&Handler::tanh),
|
||||
policy::move)
|
||||
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphHandlerObj::abs),
|
||||
.def("softmax", py::overload_cast<Tensor, Tensor>(&Handler::softmax),
|
||||
policy::move)
|
||||
.def("abs", py::overload_cast<Tensor, Tensor>(&Handler::abs),
|
||||
policy::move)
|
||||
.def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
|
||||
policy::move);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue