forked from jiuyuan/InfiniTensor
feat: 导出 OperatorObj
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
45a3cdfa30
commit
fe81fccf76
|
@ -38,12 +38,14 @@ class GraphHandlerObj {
|
||||||
|
|
||||||
Tensor tensor(Shape dims, int dtype);
|
Tensor tensor(Shape dims, int dtype);
|
||||||
|
|
||||||
|
//------ operators
|
||||||
|
|
||||||
|
inline OpVec operators() { return g->getOperators(); }
|
||||||
|
|
||||||
Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
|
Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
|
||||||
int sh, int sw, int dh, int dw);
|
int sh, int sw, int dh, int dw);
|
||||||
|
|
||||||
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
|
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
|
||||||
Tensor bias, ActType act);
|
Tensor bias, ActType act);
|
||||||
|
|
||||||
Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
|
Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
|
||||||
Tensor scale, Tensor bias, float momentum, float eps,
|
Tensor scale, Tensor bias, float momentum, float eps,
|
||||||
bool training);
|
bool training);
|
||||||
|
@ -77,6 +79,10 @@ class GraphHandlerObj {
|
||||||
Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
|
Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
|
||||||
const optional<vector<int>> &axes);
|
const optional<vector<int>> &axes);
|
||||||
|
|
||||||
|
//------ modifiers
|
||||||
|
|
||||||
|
inline bool topo_sort() { return g->topo_sort(); }
|
||||||
|
|
||||||
//------ runtime
|
//------ runtime
|
||||||
|
|
||||||
inline void data_malloc() { g->dataMalloc(); }
|
inline void data_malloc() { g->dataMalloc(); }
|
||||||
|
|
|
@ -6,11 +6,11 @@ from functools import reduce
|
||||||
runtime = backend.cpu_runtime()
|
runtime = backend.cpu_runtime()
|
||||||
|
|
||||||
|
|
||||||
def from_onnx(model: onnx.ModelProto):
|
def from_onnx(model: onnx.ModelProto) -> backend.GraphHandler:
|
||||||
model = infer_shapes(model)
|
model = infer_shapes(model)
|
||||||
handler = backend.GraphHandlerObj(runtime)
|
handler = backend.GraphHandler(runtime)
|
||||||
|
|
||||||
tensors: Dict[str, backend.TensorObj] = dict()
|
tensors: Dict[str, backend.Tensor] = dict()
|
||||||
data: Dict[str, onnx.TensorProto] = dict()
|
data: Dict[str, onnx.TensorProto] = dict()
|
||||||
|
|
||||||
for input in model.graph.input:
|
for input in model.graph.input:
|
||||||
|
@ -303,6 +303,13 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||||
|
|
||||||
|
|
||||||
|
def to_onnx(graph: backend.GraphHandler):
|
||||||
|
if not graph.topo_sort():
|
||||||
|
raise Exception("Sorting fails")
|
||||||
|
|
||||||
|
ops = graph.operators()
|
||||||
|
|
||||||
|
|
||||||
def parse_onnx(model: onnx.ModelProto):
|
def parse_onnx(model: onnx.ModelProto):
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from onnx.helper import (
|
||||||
make_tensor_value_info,
|
make_tensor_value_info,
|
||||||
)
|
)
|
||||||
from onnx.checker import check_model
|
from onnx.checker import check_model
|
||||||
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
|
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime, to_onnx
|
||||||
|
|
||||||
|
|
||||||
def make_and_import_model(graph: onnx.GraphProto):
|
def make_and_import_model(graph: onnx.GraphProto):
|
||||||
|
@ -293,12 +293,14 @@ class TestStringMethods(unittest.TestCase):
|
||||||
parse_onnx(model)
|
parse_onnx(model)
|
||||||
|
|
||||||
def test_frontend(self):
|
def test_frontend(self):
|
||||||
handler = backend.GraphHandlerObj(runtime)
|
handler = backend.GraphHandler(runtime)
|
||||||
i = handler.tensor([1, 2, 3], 12)
|
i = handler.tensor([1, 2, 3], 12)
|
||||||
w = handler.tensor([1, 3, 4], 12)
|
w = handler.tensor([1, 3, 4], 12)
|
||||||
o = handler.tensor([1, 2, 4], 12)
|
o = handler.tensor([1, 2, 4], 12)
|
||||||
handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
|
handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
|
||||||
|
|
||||||
|
to_onnx(handler)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -25,17 +25,18 @@ void init_graph_builder(py::module &m) {
|
||||||
using Handler = GraphHandlerObj;
|
using Handler = GraphHandlerObj;
|
||||||
|
|
||||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
|
m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
|
||||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
|
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||||
m, "CpuRuntimeObj");
|
m, "CpuRuntime");
|
||||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
|
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
|
||||||
|
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator");
|
||||||
py::enum_<ActType>(m, "ActType")
|
py::enum_<ActType>(m, "ActType")
|
||||||
.value("Linear", ActType::None) // `None` is Python keyword
|
.value("Linear", ActType::None) // `None` is Python keyword
|
||||||
.value("Relu", ActType::Relu)
|
.value("Relu", ActType::Relu)
|
||||||
.value("Sigmoid", ActType::Sigmoid)
|
.value("Sigmoid", ActType::Sigmoid)
|
||||||
.value("Tanh", ActType::Tanh)
|
.value("Tanh", ActType::Tanh)
|
||||||
.export_values();
|
.export_values();
|
||||||
py::class_<Handler>(m, "GraphHandlerObj")
|
py::class_<Handler>(m, "GraphHandler")
|
||||||
.def(py::init<Runtime>())
|
.def(py::init<Runtime>())
|
||||||
.def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
|
.def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
|
||||||
policy::move)
|
policy::move)
|
||||||
|
@ -103,6 +104,10 @@ void init_graph_builder(py::module &m) {
|
||||||
py::overload_cast<Tensor, Tensor, const vector<int> &,
|
py::overload_cast<Tensor, Tensor, const vector<int> &,
|
||||||
const optional<vector<int>> &>(&Handler::pad),
|
const optional<vector<int>> &>(&Handler::pad),
|
||||||
policy::move)
|
policy::move)
|
||||||
|
.def("topo_sort", py::overload_cast<>(&Handler::topo_sort),
|
||||||
|
policy::automatic)
|
||||||
|
.def("operators", py::overload_cast<>(&Handler::operators),
|
||||||
|
policy::move)
|
||||||
.def("data_malloc", &Handler::data_malloc, policy::automatic)
|
.def("data_malloc", &Handler::data_malloc, policy::automatic)
|
||||||
.def("run", &Handler::run, policy::automatic);
|
.def("run", &Handler::run, policy::automatic);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue