diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 832e398c..68d0f110 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -309,6 +309,15 @@ def to_onnx(graph: backend.GraphHandler):
     ops = graph.operators()
 
+    names: Dict[Any, str] = dict()
+    count: Dict[backend.OpType, int] = dict()
+
+    for op in ops:
+        ty = op.op_type()
+        names[op] = "{}{}".format(ty.name, count.setdefault(ty, 0) + 1)
+        count[ty] += 1
+    print(names)
+
 
 def parse_onnx(model: onnx.ModelProto):
     print()
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index b0de6d08..5a86b56a 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -21,6 +21,55 @@ void register_operator_timer(py::module &m) {
 #endif
 }
 
+void init_values(py::module &m) {
+#define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME)
+
+    py::enum_<ActType>(m, "ActType")
+        .value("Linear", ActType::None) // None is a Python keyword, cannot be used
+        .VALUE(ActType, Relu)
+        .VALUE(ActType, Sigmoid)
+        .VALUE(ActType, Tanh)
+        .export_values();
+
+    py::enum_<OpType>(m, "OpType")
+        .VALUE(OpType, Unknown)
+        .VALUE(OpType, Conv)
+        .VALUE(OpType, Matmul)
+        .VALUE(OpType, ConvTrans)
+        .VALUE(OpType, G2BMM)
+        .VALUE(OpType, GBMM)
+        .VALUE(OpType, Pad)
+        .VALUE(OpType, Slice)
+        .VALUE(OpType, Concat)
+        .VALUE(OpType, Split)
+        .VALUE(OpType, Transpose)
+        .VALUE(OpType, Extend)
+        .VALUE(OpType, MaxPool)
+        .VALUE(OpType, AvgPool)
+        .VALUE(OpType, Add)
+        .VALUE(OpType, Sub)
+        .VALUE(OpType, Mul)
+        .VALUE(OpType, Div)
+        .VALUE(OpType, Pow)
+        .VALUE(OpType, Gather)
+        .VALUE(OpType, ReduceMean)
+        .VALUE(OpType, Reshape)
+        .VALUE(OpType, Flatten)
+        .VALUE(OpType, Identity)
+        .VALUE(OpType, BatchNorm)
+        .VALUE(OpType, Softmax)
+        .VALUE(OpType, Activation)
+        .VALUE(OpType, Relu)
+        .VALUE(OpType, Sigmoid)
+        .VALUE(OpType, Tanh)
+        .VALUE(OpType, Abs)
+        .VALUE(OpType, Resize)
+        .VALUE(OpType, MemBound)
+        .export_values();
+
+#undef VALUE
+}
+
 void init_graph_builder(py::module &m) {
     using Handler = GraphHandlerObj;
 
@@ -28,14 +77,10 @@
     py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
     py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
        m, "CpuRuntime");
-    py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
-    py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator");
-    py::enum_<ActType>(m, "ActType")
-        .value("Linear", ActType::None) // `None` is Python keyword
-        .value("Relu", ActType::Relu)
-        .value("Sigmoid", ActType::Sigmoid)
-        .value("Tanh", ActType::Tanh)
-        .export_values();
+    py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj")
+        .def("src", &TensorObj::getOutputOf, policy::move);
+    py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
+        .def("op_type", &OperatorObj::getOpType, policy::move);
     py::class_<Handler>(m, "GraphHandler")
         .def(py::init<Runtime>())
         .def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
             policy::move)
@@ -104,10 +149,8 @@
             py::overload_cast<Tensor, Tensor, const vector<int> &,
                               const optional<vector<int>> &>(&Handler::pad),
             policy::move)
-        .def("topo_sort", py::overload_cast<>(&Handler::topo_sort),
-             policy::automatic)
-        .def("operators", py::overload_cast<>(&Handler::operators),
-             policy::move)
+        .def("topo_sort", &Handler::topo_sort, policy::automatic)
+        .def("operators", &Handler::operators, policy::move)
         .def("data_malloc", &Handler::data_malloc, policy::automatic)
         .def("run", &Handler::run, policy::automatic);
 }
@@ -116,5 +159,6 @@
 
 PYBIND11_MODULE(backend, m) {
     infini::register_operator_timer(m);
+    infini::init_values(m);
     infini::init_graph_builder(m);
 }
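
Reviewer note: the naming loop added to `to_onnx` can be exercised in isolation. Below is a minimal, runnable sketch of the same `count.setdefault` pattern; the `OpType` enum here is a plain-Python stand-in for the pybind11-exported `backend.OpType`, of which only `.name` is used (py::enum_ provides that attribute on the Python side as well).

```python
# Standalone sketch of the naming scheme used in to_onnx above.
from enum import Enum
from typing import Any, Dict


class OpType(Enum):  # stand-in for backend.OpType
    Conv = 1
    Matmul = 2
    Relu = 3


ops = [OpType.Conv, OpType.Conv, OpType.Matmul, OpType.Relu]

names: Dict[Any, str] = dict()
count: Dict[OpType, int] = dict()

for i, ty in enumerate(ops):
    # setdefault seeds the per-type counter with 0 the first time a type
    # is seen, so generated names are 1-based: Conv1, Conv2, Matmul1, ...
    names[i] = "{}{}".format(ty.name, count.setdefault(ty, 0) + 1)
    count[ty] += 1

print(names)  # {0: 'Conv1', 1: 'Conv2', 2: 'Matmul1', 3: 'Relu1'}
```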
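For context, a hypothetical consumer of the new `operators` / `op_type` bindings might look like the sketch below. It assumes an already-built `GraphHandler`; how a `Runtime` is obtained on the Python side is not part of this diff, so the function takes the handler as a parameter.

```python
# Hypothetical driver for the bindings added above; not runnable standalone,
# since it needs the compiled `backend` module and an already-built graph.
from pyinfinitensor import backend


def dump_op_types(handler: "backend.GraphHandler") -> None:
    handler.topo_sort()             # bound above with policy::automatic
    for op in handler.operators():  # Operator wrappers from the graph
        ty = op.op_type()           # backend.OpType value via the new binding
        print(ty.name)              # py::enum_ exposes .name, as to_onnx relies on
```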