From cf9bdb05629f15a6bfcd9514bbd8a1936728dc11 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 24 Feb 2023 10:08:04 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E6=89=93=E5=8D=B0?= =?UTF-8?q?=E7=BB=93=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- pyinfinitensor/src/pyinfinitensor/onnx.py | 8 ++++++-- src/ffi/ffi_infinitensor.cc | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 217d230d..6b3e394c 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -34,7 +34,7 @@ def cuda_runtime(): def from_onnx( model: ModelProto, runtime -) -> Tuple[Dict[str, backend.Tensor], backend.GraphHandler]: +) -> Tuple[Dict[str, backend.Tensor], Dict[str, backend.Tensor], backend.GraphHandler]: model = infer_shapes(model) handler = backend.GraphHandler(runtime) @@ -349,7 +349,11 @@ def from_onnx( else: assert False, "Unsupported Tensor Type: {}".format(tensor.data_type) - return inputs, handler + outputs: Dict[str, backend.Tensor] = {} + for output in model.graph.output: + outputs[output.name] = tensors[output.name] + + return inputs, outputs, handler def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto: diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 5d6085ea..402d5306 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -146,6 +146,7 @@ void init_graph_builder(py::module &m) { #endif py::class_>(m, "Tensor") .def("shape", &TensorObj::getDims, policy::move) + .def("printData", &TensorObj::printData, policy::automatic) .def("src", &TensorObj::getOutputOf, policy::move); py::class_>(m, "Operator") .def("op_type", &OperatorObj::getOpType, policy::automatic)