From f44a4daf70d734e61dcbfa47d579743620eeb35e Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Fri, 24 Feb 2023 09:39:30 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AF=BC=E5=87=BA=E6=9C=AA=E5=88=9D?=
 =?UTF-8?q?=E5=A7=8B=E5=8C=96=E7=9A=84=E5=BC=A0=E9=87=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 include/core/graph_handler.h              |  6 ------
 pyinfinitensor/src/pyinfinitensor/onnx.py | 11 ++++++++---
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 82695faa..ff702cf8 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -90,20 +90,14 @@ class GraphHandlerObj {
     inline void data_malloc() { g->dataMalloc(); }
 
     inline void copy_int32(Tensor tensor, std::vector<int> list) {
-        std::cout << "copy " << list.size() << " ints to (" << tensor->size()
-                  << ")" << std::endl;
         tensor->copyData(list);
     }
 
     inline void copy_int64(Tensor tensor, std::vector<int64_t> list) {
-        std::cout << "copy " << list.size() << " ints to (" << tensor->size()
-                  << ")" << std::endl;
         tensor->copyData(list);
     }
 
     inline void copy_float(Tensor tensor, std::vector<float> list) {
-        std::cout << "copy " << list.size() << " floats to (" << tensor->size()
-                  << ")" << std::endl;
         tensor->copyData(list);
     }
 
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index f815c037..217d230d 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -32,7 +32,9 @@ def cuda_runtime():
     return backend.cuda_runtime()
 
 
-def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler:
+def from_onnx(
+    model: ModelProto, runtime
+) -> Tuple[Dict[str, backend.Tensor], backend.GraphHandler]:
     model = infer_shapes(model)
     handler = backend.GraphHandler(runtime)
 
@@ -330,12 +332,13 @@ def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler:
 
     handler.data_malloc()
 
-    inputs = []
+    inputs: Dict[str, backend.Tensor] = {}
     for name, obj in tensors.items():
+        print("{}: {}".format(name, obj))
         tensor = data.get(name)
         if tensor == None:
             if any(input.name == name for input in model.graph.input):
-                inputs.append((name, tensor))
+                inputs[name] = obj
         else:
             if tensor.data_type == TensorProto.INT32:
                 handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
@@ -346,6 +349,8 @@ def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler:
             else:
                 assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
 
+    return inputs, handler
+
 
 def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
     class Context: