diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 217d230d..6b3e394c 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -34,7 +34,7 @@ def cuda_runtime(): def from_onnx( model: ModelProto, runtime -) -> Tuple[Dict[str, backend.Tensor], backend.GraphHandler]: +) -> Tuple[Dict[str, backend.Tensor], Dict[str, backend.Tensor], backend.GraphHandler]: model = infer_shapes(model) handler = backend.GraphHandler(runtime) @@ -349,7 +349,11 @@ def from_onnx( else: assert False, "Unsupported Tensor Type: {}".format(tensor.data_type) - return inputs, handler + outputs: Dict[str, backend.Tensor] = {} + for output in model.graph.output: + outputs[output.name] = tensors[output.name] + + return inputs, outputs, handler def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto: diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 5d6085ea..402d5306 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -146,6 +146,7 @@ void init_graph_builder(py::module &m) { #endif py::class_>(m, "Tensor") .def("shape", &TensorObj::getDims, policy::move) + .def("printData", &TensorObj::printData, policy::automatic) .def("src", &TensorObj::getOutputOf, policy::move); py::class_>(m, "Operator") .def("op_type", &OperatorObj::getOpType, policy::automatic)