feat: 导出 OperatorObj (export OperatorObj)

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-17 15:00:34 +08:00
parent 45a3cdfa30
commit fe81fccf76
4 changed files with 30 additions and 10 deletions

View File

@@ -38,12 +38,14 @@ class GraphHandlerObj {
     Tensor tensor(Shape dims, int dtype);

+    //------ operators
+    inline OpVec operators() { return g->getOperators(); }
+
     Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
                 int sh, int sw, int dh, int dw);
     Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
                   Tensor bias, ActType act);
     Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
                      Tensor scale, Tensor bias, float momentum, float eps,
                      bool training);
@@ -77,6 +79,10 @@ class GraphHandlerObj {
     Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
               const optional<vector<int>> &axes);

+    //------ modifiers
+    inline bool topo_sort() { return g->topo_sort(); }
+
     //------ runtime
     inline void data_malloc() { g->dataMalloc(); }

View File

@@ -6,11 +6,11 @@ from functools import reduce
 runtime = backend.cpu_runtime()

-def from_onnx(model: onnx.ModelProto):
+def from_onnx(model: onnx.ModelProto) -> backend.GraphHandler:
     model = infer_shapes(model)
-    handler = backend.GraphHandlerObj(runtime)
-    tensors: Dict[str, backend.TensorObj] = dict()
+    handler = backend.GraphHandler(runtime)
+    tensors: Dict[str, backend.Tensor] = dict()
     data: Dict[str, onnx.TensorProto] = dict()
     for input in model.graph.input:
@@ -303,6 +303,13 @@ def from_onnx(model: onnx.ModelProto):
             raise Exception('Unsupported operator "{}"'.format(node.op_type))

+def to_onnx(graph: backend.GraphHandler):
+    if not graph.topo_sort():
+        raise Exception("Sorting fails")
+    ops = graph.operators()

 def parse_onnx(model: onnx.ModelProto):
     print()

View File

@@ -8,7 +8,7 @@ from onnx.helper import (
     make_tensor_value_info,
 )
 from onnx.checker import check_model
-from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
+from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime, to_onnx

 def make_and_import_model(graph: onnx.GraphProto):
@@ -293,12 +293,14 @@ class TestStringMethods(unittest.TestCase):
         parse_onnx(model)

     def test_frontend(self):
-        handler = backend.GraphHandlerObj(runtime)
+        handler = backend.GraphHandler(runtime)
         i = handler.tensor([1, 2, 3], 12)
         w = handler.tensor([1, 3, 4], 12)
         o = handler.tensor([1, 2, 4], 12)
         handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
+        to_onnx(handler)

 if __name__ == "__main__":
     unittest.main()

View File

@@ -25,17 +25,18 @@ void init_graph_builder(py::module &m) {
     using Handler = GraphHandlerObj;

     m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
-    py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
+    py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
     py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
-        m, "CpuRuntimeObj");
+        m, "CpuRuntime");
     py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
+    py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator");
     py::enum_<ActType>(m, "ActType")
         .value("Linear", ActType::None) // `None` is Python keyword
         .value("Relu", ActType::Relu)
         .value("Sigmoid", ActType::Sigmoid)
         .value("Tanh", ActType::Tanh)
         .export_values();
-    py::class_<Handler>(m, "GraphHandlerObj")
+    py::class_<Handler>(m, "GraphHandler")
         .def(py::init<Runtime>())
         .def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
              policy::move)
@@ -103,6 +104,10 @@ void init_graph_builder(py::module &m) {
              py::overload_cast<Tensor, Tensor, const vector<int> &,
                                const optional<vector<int>> &>(&Handler::pad),
              policy::move)
+        .def("topo_sort", py::overload_cast<>(&Handler::topo_sort),
+             policy::automatic)
+        .def("operators", py::overload_cast<>(&Handler::operators),
+             policy::move)
         .def("data_malloc", &Handler::data_malloc, policy::automatic)
         .def("run", &Handler::run, policy::automatic);
 }