forked from jiuyuan/InfiniTensor
feat: 支持打印结果
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
f44a4daf70
commit
cf9bdb0562
|
@ -34,7 +34,7 @@ def cuda_runtime():
|
|||
|
||||
def from_onnx(
|
||||
model: ModelProto, runtime
|
||||
) -> Tuple[Dict[str, backend.Tensor], backend.GraphHandler]:
|
||||
) -> Tuple[Dict[str, backend.Tensor], Dict[str, backend.Tensor], backend.GraphHandler]:
|
||||
model = infer_shapes(model)
|
||||
handler = backend.GraphHandler(runtime)
|
||||
|
||||
|
@ -349,7 +349,11 @@ def from_onnx(
|
|||
else:
|
||||
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||
|
||||
return inputs, handler
|
||||
outputs: Dict[str, backend.Tensor] = {}
|
||||
for output in model.graph.output:
|
||||
outputs[output.name] = tensors[output.name]
|
||||
|
||||
return inputs, outputs, handler
|
||||
|
||||
|
||||
def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
||||
|
|
|
@ -146,6 +146,7 @@ void init_graph_builder(py::module &m) {
|
|||
#endif
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
||||
.def("shape", &TensorObj::getDims, policy::move)
|
||||
.def("printData", &TensorObj::printData, policy::automatic)
|
||||
.def("src", &TensorObj::getOutputOf, policy::move);
|
||||
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
||||
.def("op_type", &OperatorObj::getOpType, policy::automatic)
|
||||
|
|
Loading…
Reference in New Issue