diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index dc221042..82695faa 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -2,6 +2,8 @@ #include "core/graph.h" #include "core/runtime.h" +#include +#include namespace infini { @@ -87,6 +89,24 @@ class GraphHandlerObj { inline void data_malloc() { g->dataMalloc(); } + inline void copy_int32(Tensor tensor, std::vector list) { + std::cout << "copy " << list.size() << " ints to (" << tensor->size() + << ")" << std::endl; + tensor->copyData(list); + } + + inline void copy_int64(Tensor tensor, std::vector list) { + std::cout << "copy " << list.size() << " ints to (" << tensor->size() + << ")" << std::endl; + tensor->copyData(list); + } + + inline void copy_float(Tensor tensor, std::vector list) { + std::cout << "copy " << list.size() << " floats to (" << tensor->size() + << ")" << std::endl; + tensor->copyData(list); + } + inline void run() { g->getRuntime()->run(g); } }; diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index ce315685..207280ef 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -22,7 +22,7 @@ from onnx.checker import ( check_tensor, ) from onnx.shape_inference import infer_shapes -from typing import Dict, List, Any, Tuple, Sequence +from typing import Dict, List, Any, Tuple, Sequence, Union from functools import reduce runtime = backend.cpu_runtime() @@ -324,6 +324,24 @@ def from_onnx(model: ModelProto) -> backend.GraphHandler: else: raise Exception('Unsupported operator "{}"'.format(node.op_type)) + handler.data_malloc() + + inputs = [] + for name, obj in tensors.items(): + tensor = data.get(name) + if tensor == None: + if any(input.name == name for input in model.graph.input): + inputs.append((name, tensor)) + else: + if tensor.data_type == TensorProto.INT32: + handler.copy_int32(obj, [int(i) for i in tensor.int32_data]) + elif tensor.data_type == TensorProto.INT64: + handler.copy_int64(obj, [int(i) for i in tensor.int64_data]) + elif tensor.data_type == TensorProto.FLOAT: + handler.copy_float(obj, [float(i) for i in tensor.float_data]) + else: + assert False, "Unsupported Tensor Type: {}".format(tensor.data_type) + def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto: class Context: @@ -482,42 +500,6 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto: return ctx.build(name) -def parse_onnx(model: ModelProto): - print() - - for field in [ - "doc_string", - "domain", - "functions", - "metadata_props", - "model_version", - "producer_name", - "producer_version", - "training_info", - ]: - print("{}: {}".format(field, getattr(model, field))) - - print("ir_version:", model.ir_version) - for opset in model.opset_import: - print("opset domain={} version={}".format(opset.domain, opset.version)) - - print("layout:") - for node in model.graph.node: - print( - ' {o} <- {op}"{name}"{a} <- {i}'.format( - name=node.name, - op=node.op_type, - i=node.input, - o=node.output, - a=[a.name for a in node.attribute], - ) - ) - - print("weight:") - for node in model.graph.initializer: - print(" {}".format(node.name)) - - def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]: for attr in node.attribute: if attr.name in attrs: @@ -536,11 +518,13 @@ def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[st return attrs -def _parse_data(tensor: TensorProto) -> List[int]: +def _parse_data(tensor: TensorProto) -> List[Union[int, float]]: if tensor.data_type == TensorProto.INT32: return [int(i) for i in tensor.int32_data] elif tensor.data_type == TensorProto.INT64: return [int(i) for i in tensor.int64_data] + elif tensor.data_type == TensorProto.FLOAT: + return [float(i) for i in tensor.float_data] else: assert False, "Unsupported Tensor Type: {}".format(tensor.data_type) diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index e545f43c..b791e66e 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -171,6 +171,9 @@ void init_graph_builder(py::module &m) { .def("topo_sort", &Handler::topo_sort, policy::automatic) .def("operators", &Handler::operators, policy::move) .def("data_malloc", &Handler::data_malloc, policy::automatic) + .def("copy_int32", &Handler::copy_int32, policy::automatic) + .def("copy_int64", &Handler::copy_int64, policy::automatic) + .def("copy_float", &Handler::copy_float, policy::automatic) .def("run", &Handler::run, policy::automatic); }