forked from jiuyuan/InfiniTensor
feat: 导出 Gather Concat 到 onnx
- 并优化 python 代码 Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
9d9fbd44af
commit
6b7af7077b
|
@ -336,9 +336,10 @@ def to_onnx(graph: backend.GraphHandler):
|
||||||
self.count_op[ty] += 1
|
self.count_op[ty] += 1
|
||||||
return ty, name
|
return ty, name
|
||||||
|
|
||||||
def push_output(self, name: str, tensor: backend.Tensor) -> None:
|
def push_output(self, name: str, tensor: backend.Tensor) -> str:
|
||||||
self.names[tensor] = name
|
self.names[tensor] = name
|
||||||
# TODO 需要判断全图输出并保存到 outputs
|
# TODO 需要判断全图输出并保存到 outputs
|
||||||
|
return name
|
||||||
|
|
||||||
def push_input(self, tensor: backend.Tensor) -> str:
|
def push_input(self, tensor: backend.Tensor) -> str:
|
||||||
name = self.names.get(tensor)
|
name = self.names.get(tensor)
|
||||||
|
@ -375,17 +376,17 @@ def to_onnx(graph: backend.GraphHandler):
|
||||||
|
|
||||||
ops = graph.operators() # 图中所有算子(节点)
|
ops = graph.operators() # 图中所有算子(节点)
|
||||||
|
|
||||||
context = Context()
|
ctx = Context()
|
||||||
|
|
||||||
for op in ops:
|
for op in ops:
|
||||||
ty, name = context.name_op(op)
|
ty, name = ctx.name_op(op)
|
||||||
inputs = op.inputs()
|
inputs = [ctx.push_input(it) for it in op.inputs()]
|
||||||
outputs = op.outputs()
|
outputs = [
|
||||||
|
ctx.push_output("{}_{}".format(name, i), it)
|
||||||
|
for (i, it) in enumerate(op.outputs())
|
||||||
|
]
|
||||||
if ty == backend.OpType.Matmul:
|
if ty == backend.OpType.Matmul:
|
||||||
context.push_output(name, outputs[0])
|
ctx.push_node(make_node("MatMul", inputs, outputs, name))
|
||||||
a = context.push_input(inputs[0])
|
|
||||||
b = context.push_input(inputs[1])
|
|
||||||
context.push_node(make_node("MatMul", [a, b], [name], name))
|
|
||||||
elif ty == backend.OpType.BatchNorm:
|
elif ty == backend.OpType.BatchNorm:
|
||||||
raise Exception("TODO")
|
raise Exception("TODO")
|
||||||
elif ty == backend.OpType.MaxPool:
|
elif ty == backend.OpType.MaxPool:
|
||||||
|
@ -398,12 +399,6 @@ def to_onnx(graph: backend.GraphHandler):
|
||||||
backend.OpType.Mul,
|
backend.OpType.Mul,
|
||||||
backend.OpType.Div,
|
backend.OpType.Div,
|
||||||
backend.OpType.Pow,
|
backend.OpType.Pow,
|
||||||
]:
|
|
||||||
context.push_output(name, outputs[0])
|
|
||||||
a = context.push_input(inputs[0])
|
|
||||||
b = context.push_input(inputs[1])
|
|
||||||
context.push_node(make_node(ty.name, [a, b], [name], name))
|
|
||||||
elif ty in [
|
|
||||||
backend.OpType.Relu,
|
backend.OpType.Relu,
|
||||||
backend.OpType.Sigmoid,
|
backend.OpType.Sigmoid,
|
||||||
backend.OpType.Tanh,
|
backend.OpType.Tanh,
|
||||||
|
@ -411,14 +406,11 @@ def to_onnx(graph: backend.GraphHandler):
|
||||||
backend.OpType.Abs,
|
backend.OpType.Abs,
|
||||||
backend.OpType.Identity,
|
backend.OpType.Identity,
|
||||||
]:
|
]:
|
||||||
context.push_output(name, outputs[0])
|
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||||
x = context.push_input(inputs[0])
|
|
||||||
context.push_node(make_node(ty.name, [x], [name], name))
|
|
||||||
elif ty == backend.OpType.Flatten:
|
elif ty == backend.OpType.Flatten:
|
||||||
raise Exception("TODO")
|
raise Exception("TODO")
|
||||||
elif ty == backend.OpType.Reshape:
|
elif ty == backend.OpType.Reshape:
|
||||||
context.push_output(name, outputs[0])
|
data = ctx.push_input(inputs[0])
|
||||||
data = context.push_input(inputs[0])
|
|
||||||
# shape = context.push_data_input(
|
# shape = context.push_data_input(
|
||||||
# name,
|
# name,
|
||||||
# "shape",
|
# "shape",
|
||||||
|
@ -429,13 +421,11 @@ def to_onnx(graph: backend.GraphHandler):
|
||||||
# context.push_node(make_node(ty.name, [data, shape], [name], name))
|
# context.push_node(make_node(ty.name, [data, shape], [name], name))
|
||||||
raise Exception("TODO")
|
raise Exception("TODO")
|
||||||
elif ty == backend.OpType.Concat:
|
elif ty == backend.OpType.Concat:
|
||||||
context.push_output(name, outputs[0])
|
axis = backend.concat_axis_of(op)
|
||||||
a = context.push_input(inputs[0])
|
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||||
b = context.push_input(inputs[1])
|
|
||||||
axis = backend.concat_dim_of(op)
|
|
||||||
context.push_node(make_node("Concat", [a, b], [name], name, axis=axis))
|
|
||||||
elif ty == backend.OpType.Gather:
|
elif ty == backend.OpType.Gather:
|
||||||
raise Exception("TODO")
|
axis = backend.gather_axis_of(op)
|
||||||
|
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||||
elif ty == backend.OpType.ReduceMean:
|
elif ty == backend.OpType.ReduceMean:
|
||||||
raise Exception("TODO")
|
raise Exception("TODO")
|
||||||
elif ty == backend.OpType.Slice:
|
elif ty == backend.OpType.Slice:
|
||||||
|
@ -446,13 +436,13 @@ def to_onnx(graph: backend.GraphHandler):
|
||||||
raise Exception("Unsupported OpType {}".format(ty.name))
|
raise Exception("Unsupported OpType {}".format(ty.name))
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(context.names)
|
print(ctx.names)
|
||||||
print()
|
print()
|
||||||
print(context.inputs)
|
print(ctx.inputs)
|
||||||
print()
|
print()
|
||||||
print(context.outputs)
|
print(ctx.outputs)
|
||||||
print()
|
print()
|
||||||
print(context.nodes)
|
print(ctx.nodes)
|
||||||
|
|
||||||
|
|
||||||
def parse_onnx(model: ModelProto):
|
def parse_onnx(model: ModelProto):
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#include "core/graph_handler.h"
|
#include "core/graph_handler.h"
|
||||||
#include "operators/concat.h"
|
#include "operators/concat.h"
|
||||||
|
#include "operators/gather.h"
|
||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
#ifdef USE_CUDA
|
#ifdef USE_CUDA
|
||||||
|
@ -91,18 +92,24 @@ static int tensor_dtype(Tensor t) {
|
||||||
IT_ASSERT(false, "Unsupported data type");
|
IT_ASSERT(false, "Unsupported data type");
|
||||||
}
|
}
|
||||||
|
|
||||||
static int concat_dim_of(Operator op) {
|
static int concat_axis_of(Operator op) {
|
||||||
IT_ASSERT(op->getOpType() == OpType::Concat);
|
IT_ASSERT(op->getOpType() == OpType::Concat);
|
||||||
return reinterpret_cast<const ConcatObj *>(op.get())->getDim();
|
return reinterpret_cast<const ConcatObj *>(op.get())->getDim();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int gather_axis_of(Operator op) {
|
||||||
|
IT_ASSERT(op->getOpType() == OpType::Gather);
|
||||||
|
return reinterpret_cast<const GatherObj *>(op.get())->getAxis();
|
||||||
|
}
|
||||||
|
|
||||||
void init_graph_builder(py::module &m) {
|
void init_graph_builder(py::module &m) {
|
||||||
|
|
||||||
using Handler = GraphHandlerObj;
|
using Handler = GraphHandlerObj;
|
||||||
|
|
||||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
||||||
.def("tensor_dtype", &tensor_dtype)
|
.def("tensor_dtype", &tensor_dtype)
|
||||||
.def("concat_dim_of", &concat_dim_of);
|
.def("concat_axis_of", &concat_axis_of)
|
||||||
|
.def("gather_axis_of", &gather_axis_of);
|
||||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||||
m, "CpuRuntime");
|
m, "CpuRuntime");
|
||||||
|
|
Loading…
Reference in New Issue