forked from jiuyuan/InfiniTensor
feat: export OpType and name graph nodes
Signed-off-by: YdrMaster <ydrml@hotmail.com>
parent fe81fccf76
commit f2591edbb4
@@ -309,6 +309,15 @@ def to_onnx(graph: backend.GraphHandler):
     ops = graph.operators()
 
+    names: Dict[Any, str] = dict()
+    count: Dict[backend.OpType, int] = dict()
+
+    for op in ops:
+        ty = op.op_type()
+        names[op] = "{}{}".format(ty.name, count.setdefault(ty, 0) + 1)
+        count[ty] += 1
+    print(names)
+
 
 
 def parse_onnx(model: onnx.ModelProto):
     print()
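For reference, a minimal standalone sketch of the naming scheme this hunk introduces: each operator is named as the OpType name followed by a per-type counter. The list of type names below is a hypothetical stand-in for real operators; only the setdefault-based counting pattern is taken from the diff.

    from typing import Any, Dict

    # Hypothetical stand-ins for op.op_type().name values from the bindings.
    op_types = ["Conv", "Relu", "Conv", "Add"]

    names: Dict[Any, str] = dict()
    count: Dict[str, int] = dict()
    for i, ty in enumerate(op_types):
        names[i] = "{}{}".format(ty, count.setdefault(ty, 0) + 1)
        count[ty] += 1

    print(names)  # {0: 'Conv1', 1: 'Relu1', 2: 'Conv2', 3: 'Add1'}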
@@ -21,6 +21,55 @@ void register_operator_timer(py::module &m) {
 #endif
 }
 
+void init_values(py::module &m) {
+#define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME)
+
+    py::enum_<ActType>(m, "ActType")
+        .value("Linear", ActType::None) // `None` is a Python keyword and cannot be used
+        .VALUE(ActType, Relu)
+        .VALUE(ActType, Sigmoid)
+        .VALUE(ActType, Tanh)
+        .export_values();
+
+    py::enum_<OpType>(m, "OpType")
+        .VALUE(OpType, Unknown)
+        .VALUE(OpType, Conv)
+        .VALUE(OpType, Matmul)
+        .VALUE(OpType, ConvTrans)
+        .VALUE(OpType, G2BMM)
+        .VALUE(OpType, GBMM)
+        .VALUE(OpType, Pad)
+        .VALUE(OpType, Slice)
+        .VALUE(OpType, Concat)
+        .VALUE(OpType, Split)
+        .VALUE(OpType, Transpose)
+        .VALUE(OpType, Extend)
+        .VALUE(OpType, MaxPool)
+        .VALUE(OpType, AvgPool)
+        .VALUE(OpType, Add)
+        .VALUE(OpType, Sub)
+        .VALUE(OpType, Mul)
+        .VALUE(OpType, Div)
+        .VALUE(OpType, Pow)
+        .VALUE(OpType, Gather)
+        .VALUE(OpType, ReduceMean)
+        .VALUE(OpType, Reshape)
+        .VALUE(OpType, Flatten)
+        .VALUE(OpType, Identity)
+        .VALUE(OpType, BatchNorm)
+        .VALUE(OpType, Softmax)
+        .VALUE(OpType, Activation)
+        .VALUE(OpType, Relu)
+        .VALUE(OpType, Sigmoid)
+        .VALUE(OpType, Tanh)
+        .VALUE(OpType, Abs)
+        .VALUE(OpType, Resize)
+        .VALUE(OpType, MemBound)
+        .export_values();
+
+#undef VALUE
+}
+
 void init_graph_builder(py::module &m) {
     using Handler = GraphHandlerObj;
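The VALUE macro stringifies the enumerator, so `.VALUE(OpType, Conv)` expands to `.value("Conv", OpType::Conv)` and the Python member name always matches the C++ name. Below is a hedged sketch of what the exported enums look like from Python, assuming the compiled module is importable as `backend` (the import path is an assumption, not part of this diff):

    import backend

    print(backend.OpType.Conv)             # OpType.Conv
    print(backend.OpType.Conv.name)        # "Conv" -- the .name used by to_onnx above
    print(backend.ActType.Linear)          # ActType::None is exposed under the name "Linear"
    print(list(backend.OpType.__members__)[:3])  # e.g. ['Unknown', 'Conv', 'Matmul']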
@@ -28,14 +77,10 @@ void init_graph_builder(py::module &m) {
     py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
     py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
         m, "CpuRuntime");
-    py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
-    py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator");
-    py::enum_<ActType>(m, "ActType")
-        .value("Linear", ActType::None) // `None` is Python keyword
-        .value("Relu", ActType::Relu)
-        .value("Sigmoid", ActType::Sigmoid)
-        .value("Tanh", ActType::Tanh)
-        .export_values();
+    py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj")
+        .def("src", &TensorObj::getOutputOf, policy::move);
+    py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
+        .def("op_type", &OperatorObj::getOpType, policy::move);
     py::class_<Handler>(m, "GraphHandler")
         .def(py::init<Runtime>())
         .def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
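With operators() on the handler and op_type() on each operator exposed, a small summary utility becomes possible on the Python side. A hedged sketch, where the function itself is illustrative and only the binding names come from this commit:

    from collections import Counter

    def op_histogram(handler) -> Counter:
        # Count operators per OpType name; `handler` is assumed to be a
        # backend.GraphHandler that already holds a graph.
        return Counter(op.op_type().name for op in handler.operators())

TensorObj.src (bound to getOutputOf) similarly lets Python code ask which operator produced a given tensor.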
@@ -104,10 +149,8 @@ void init_graph_builder(py::module &m) {
              py::overload_cast<Tensor, Tensor, const vector<int> &,
                                const optional<vector<int>> &>(&Handler::pad),
              policy::move)
-        .def("topo_sort", py::overload_cast<>(&Handler::topo_sort),
-             policy::automatic)
-        .def("operators", py::overload_cast<>(&Handler::operators),
-             policy::move)
+        .def("topo_sort", &Handler::topo_sort, policy::automatic)
+        .def("operators", &Handler::operators, policy::move)
         .def("data_malloc", &Handler::data_malloc, policy::automatic)
         .def("run", &Handler::run, policy::automatic);
 }
@@ -116,5 +159,6 @@ void init_graph_builder(py::module &m) {
 
 PYBIND11_MODULE(backend, m) {
     infini::register_operator_timer(m);
+    infini::init_values(m);
     infini::init_graph_builder(m);
 }