forked from jiuyuan/InfiniTensor
feat: 导出 OpType,为节点命名
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
fe81fccf76
commit
f2591edbb4
|
@ -309,6 +309,15 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
|
||||
ops = graph.operators()
|
||||
|
||||
names: Dict[Any, str] = dict()
|
||||
count: Dict[backend.OpType, int] = dict()
|
||||
|
||||
for op in ops:
|
||||
ty = op.op_type()
|
||||
names[op] = "{}{}".format(ty.name, count.setdefault(ty, 0) + 1)
|
||||
count[ty] += 1
|
||||
print(names)
|
||||
|
||||
|
||||
def parse_onnx(model: onnx.ModelProto):
|
||||
print()
|
||||
|
|
|
@ -21,6 +21,55 @@ void register_operator_timer(py::module &m) {
|
|||
#endif
|
||||
}
|
||||
|
||||
void init_values(py::module &m) {
|
||||
#define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME)
|
||||
|
||||
py::enum_<ActType>(m, "ActType")
|
||||
.value("Linear", ActType::None) // None 是 Python 关键字,不能用
|
||||
.VALUE(ActType, Relu)
|
||||
.VALUE(ActType, Sigmoid)
|
||||
.VALUE(ActType, Tanh)
|
||||
.export_values();
|
||||
|
||||
py::enum_<OpType>(m, "OpType")
|
||||
.VALUE(OpType, Unknown)
|
||||
.VALUE(OpType, Conv)
|
||||
.VALUE(OpType, Matmul)
|
||||
.VALUE(OpType, ConvTrans)
|
||||
.VALUE(OpType, G2BMM)
|
||||
.VALUE(OpType, GBMM)
|
||||
.VALUE(OpType, Pad)
|
||||
.VALUE(OpType, Slice)
|
||||
.VALUE(OpType, Concat)
|
||||
.VALUE(OpType, Split)
|
||||
.VALUE(OpType, Transpose)
|
||||
.VALUE(OpType, Extend)
|
||||
.VALUE(OpType, MaxPool)
|
||||
.VALUE(OpType, AvgPool)
|
||||
.VALUE(OpType, Add)
|
||||
.VALUE(OpType, Sub)
|
||||
.VALUE(OpType, Mul)
|
||||
.VALUE(OpType, Div)
|
||||
.VALUE(OpType, Pow)
|
||||
.VALUE(OpType, Gather)
|
||||
.VALUE(OpType, ReduceMean)
|
||||
.VALUE(OpType, Reshape)
|
||||
.VALUE(OpType, Flatten)
|
||||
.VALUE(OpType, Identity)
|
||||
.VALUE(OpType, BatchNorm)
|
||||
.VALUE(OpType, Softmax)
|
||||
.VALUE(OpType, Activation)
|
||||
.VALUE(OpType, Relu)
|
||||
.VALUE(OpType, Sigmoid)
|
||||
.VALUE(OpType, Tanh)
|
||||
.VALUE(OpType, Abs)
|
||||
.VALUE(OpType, Resize)
|
||||
.VALUE(OpType, MemBound)
|
||||
.export_values();
|
||||
|
||||
#undef VALUE
|
||||
}
|
||||
|
||||
void init_graph_builder(py::module &m) {
|
||||
using Handler = GraphHandlerObj;
|
||||
|
||||
|
@ -28,14 +77,10 @@ void init_graph_builder(py::module &m) {
|
|||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||
m, "CpuRuntime");
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
|
||||
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator");
|
||||
py::enum_<ActType>(m, "ActType")
|
||||
.value("Linear", ActType::None) // `None` is Python keyword
|
||||
.value("Relu", ActType::Relu)
|
||||
.value("Sigmoid", ActType::Sigmoid)
|
||||
.value("Tanh", ActType::Tanh)
|
||||
.export_values();
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj")
|
||||
.def("src", &TensorObj::getOutputOf, policy::move);
|
||||
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
||||
.def("op_type", &OperatorObj::getOpType, policy::move);
|
||||
py::class_<Handler>(m, "GraphHandler")
|
||||
.def(py::init<Runtime>())
|
||||
.def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
|
||||
|
@ -104,10 +149,8 @@ void init_graph_builder(py::module &m) {
|
|||
py::overload_cast<Tensor, Tensor, const vector<int> &,
|
||||
const optional<vector<int>> &>(&Handler::pad),
|
||||
policy::move)
|
||||
.def("topo_sort", py::overload_cast<>(&Handler::topo_sort),
|
||||
policy::automatic)
|
||||
.def("operators", py::overload_cast<>(&Handler::operators),
|
||||
policy::move)
|
||||
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||
.def("operators", &Handler::operators, policy::move)
|
||||
.def("data_malloc", &Handler::data_malloc, policy::automatic)
|
||||
.def("run", &Handler::run, policy::automatic);
|
||||
}
|
||||
|
@ -116,5 +159,6 @@ void init_graph_builder(py::module &m) {
|
|||
|
||||
PYBIND11_MODULE(backend, m) {
|
||||
infini::register_operator_timer(m);
|
||||
infini::init_values(m);
|
||||
infini::init_graph_builder(m);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue