feat: 导出 OpType,为节点命名

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-17 15:58:26 +08:00
parent fe81fccf76
commit f2591edbb4
2 changed files with 65 additions and 12 deletions

View File

@ -309,6 +309,15 @@ def to_onnx(graph: backend.GraphHandler):
ops = graph.operators() ops = graph.operators()
names: Dict[Any, str] = dict()
count: Dict[backend.OpType, int] = dict()
for op in ops:
ty = op.op_type()
names[op] = "{}{}".format(ty.name, count.setdefault(ty, 0) + 1)
count[ty] += 1
print(names)
def parse_onnx(model: onnx.ModelProto): def parse_onnx(model: onnx.ModelProto):
print() print()

View File

@ -21,6 +21,55 @@ void register_operator_timer(py::module &m) {
#endif #endif
} }
void init_values(py::module &m) {
#define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME)
py::enum_<ActType>(m, "ActType")
.value("Linear", ActType::None) // None 是 Python 关键字,不能用
.VALUE(ActType, Relu)
.VALUE(ActType, Sigmoid)
.VALUE(ActType, Tanh)
.export_values();
py::enum_<OpType>(m, "OpType")
.VALUE(OpType, Unknown)
.VALUE(OpType, Conv)
.VALUE(OpType, Matmul)
.VALUE(OpType, ConvTrans)
.VALUE(OpType, G2BMM)
.VALUE(OpType, GBMM)
.VALUE(OpType, Pad)
.VALUE(OpType, Slice)
.VALUE(OpType, Concat)
.VALUE(OpType, Split)
.VALUE(OpType, Transpose)
.VALUE(OpType, Extend)
.VALUE(OpType, MaxPool)
.VALUE(OpType, AvgPool)
.VALUE(OpType, Add)
.VALUE(OpType, Sub)
.VALUE(OpType, Mul)
.VALUE(OpType, Div)
.VALUE(OpType, Pow)
.VALUE(OpType, Gather)
.VALUE(OpType, ReduceMean)
.VALUE(OpType, Reshape)
.VALUE(OpType, Flatten)
.VALUE(OpType, Identity)
.VALUE(OpType, BatchNorm)
.VALUE(OpType, Softmax)
.VALUE(OpType, Activation)
.VALUE(OpType, Relu)
.VALUE(OpType, Sigmoid)
.VALUE(OpType, Tanh)
.VALUE(OpType, Abs)
.VALUE(OpType, Resize)
.VALUE(OpType, MemBound)
.export_values();
#undef VALUE
}
void init_graph_builder(py::module &m) { void init_graph_builder(py::module &m) {
using Handler = GraphHandlerObj; using Handler = GraphHandlerObj;
@ -28,14 +77,10 @@ void init_graph_builder(py::module &m) {
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime"); py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>( py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
m, "CpuRuntime"); m, "CpuRuntime");
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj"); py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj")
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator"); .def("src", &TensorObj::getOutputOf, policy::move);
py::enum_<ActType>(m, "ActType") py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
.value("Linear", ActType::None) // `None` is Python keyword .def("op_type", &OperatorObj::getOpType, policy::move);
.value("Relu", ActType::Relu)
.value("Sigmoid", ActType::Sigmoid)
.value("Tanh", ActType::Tanh)
.export_values();
py::class_<Handler>(m, "GraphHandler") py::class_<Handler>(m, "GraphHandler")
.def(py::init<Runtime>()) .def(py::init<Runtime>())
.def("tensor", py::overload_cast<Shape, int>(&Handler::tensor), .def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
@ -104,10 +149,8 @@ void init_graph_builder(py::module &m) {
py::overload_cast<Tensor, Tensor, const vector<int> &, py::overload_cast<Tensor, Tensor, const vector<int> &,
const optional<vector<int>> &>(&Handler::pad), const optional<vector<int>> &>(&Handler::pad),
policy::move) policy::move)
.def("topo_sort", py::overload_cast<>(&Handler::topo_sort), .def("topo_sort", &Handler::topo_sort, policy::automatic)
policy::automatic) .def("operators", &Handler::operators, policy::move)
.def("operators", py::overload_cast<>(&Handler::operators),
policy::move)
.def("data_malloc", &Handler::data_malloc, policy::automatic) .def("data_malloc", &Handler::data_malloc, policy::automatic)
.def("run", &Handler::run, policy::automatic); .def("run", &Handler::run, policy::automatic);
} }
@ -116,5 +159,6 @@ void init_graph_builder(py::module &m) {
PYBIND11_MODULE(backend, m) { PYBIND11_MODULE(backend, m) {
infini::register_operator_timer(m); infini::register_operator_timer(m);
infini::init_values(m);
infini::init_graph_builder(m); infini::init_graph_builder(m);
} }