feat: 支持打印结果

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-24 10:08:04 +08:00
parent f44a4daf70
commit cf9bdb0562
2 changed files with 7 additions and 2 deletions

View File

@ -34,7 +34,7 @@ def cuda_runtime():
def from_onnx(
model: ModelProto, runtime
) -> Tuple[Dict[str, backend.Tensor], backend.GraphHandler]:
) -> Tuple[Dict[str, backend.Tensor], Dict[str, backend.Tensor], backend.GraphHandler]:
model = infer_shapes(model)
handler = backend.GraphHandler(runtime)
@ -349,7 +349,11 @@ def from_onnx(
else:
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
return inputs, handler
outputs: Dict[str, backend.Tensor] = {}
for output in model.graph.output:
outputs[output.name] = tensors[output.name]
return inputs, outputs, handler
def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:

View File

@ -146,6 +146,7 @@ void init_graph_builder(py::module &m) {
#endif
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
.def("shape", &TensorObj::getDims, policy::move)
.def("printData", &TensorObj::printData, policy::automatic)
.def("src", &TensorObj::getOutputOf, policy::move);
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
.def("op_type", &OperatorObj::getOpType, policy::automatic)