diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 82695faa..ff702cf8 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -90,20 +90,14 @@ class GraphHandlerObj { inline void data_malloc() { g->dataMalloc(); } inline void copy_int32(Tensor tensor, std::vector list) { - std::cout << "copy " << list.size() << " ints to (" << tensor->size() - << ")" << std::endl; tensor->copyData(list); } inline void copy_int64(Tensor tensor, std::vector list) { - std::cout << "copy " << list.size() << " ints to (" << tensor->size() - << ")" << std::endl; tensor->copyData(list); } inline void copy_float(Tensor tensor, std::vector list) { - std::cout << "copy " << list.size() << " floats to (" << tensor->size() - << ")" << std::endl; tensor->copyData(list); } diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index f815c037..217d230d 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -32,7 +32,9 @@ def cuda_runtime(): return backend.cuda_runtime() -def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler: +def from_onnx( + model: ModelProto, runtime +) -> Tuple[Dict[str, backend.Tensor], backend.GraphHandler]: model = infer_shapes(model) handler = backend.GraphHandler(runtime) @@ -330,12 +332,13 @@ def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler: handler.data_malloc() - inputs = [] + inputs: Dict[str, backend.Tensor] = {} for name, obj in tensors.items(): + print("{}: {}".format(name, obj)) tensor = data.get(name) if tensor == None: if any(input.name == name for input in model.graph.input): - inputs.append((name, tensor)) + inputs[name] = obj else: if tensor.data_type == TensorProto.INT32: handler.copy_int32(obj, [int(i) for i in tensor.int32_data]) @@ -346,6 +349,8 @@ def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler: else: assert False, "Unsupported Tensor Type: {}".format(tensor.data_type) + return inputs, handler + def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto: class Context: