temp: 实现初始值导入,但 resnet 报错

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-23 15:29:16 +08:00
parent 4ffaa44c1e
commit ed81861375
3 changed files with 45 additions and 38 deletions

View File

@@ -2,6 +2,8 @@
#include "core/graph.h" #include "core/graph.h"
#include "core/runtime.h" #include "core/runtime.h"
#include <cstdint>
#include <iostream>
namespace infini { namespace infini {
@@ -87,6 +89,24 @@ class GraphHandlerObj {
inline void data_malloc() { g->dataMalloc(); } inline void data_malloc() { g->dataMalloc(); }
// Copy a host-side vector of int32 values into `tensor`'s backing storage.
// NOTE(review): removed the temporary std::cout trace that logged every
// transfer — library code should not write to stdout unconditionally.
inline void copy_int32(Tensor tensor, std::vector<int32_t> list) {
    tensor->copyData(list);
}
// Copy a host-side vector of int64 values into `tensor`'s backing storage.
// NOTE(review): removed the temporary std::cout trace that logged every
// transfer — library code should not write to stdout unconditionally.
inline void copy_int64(Tensor tensor, std::vector<int64_t> list) {
    tensor->copyData(list);
}
// Copy a host-side vector of float values into `tensor`'s backing storage.
// NOTE(review): removed the temporary std::cout trace that logged every
// transfer — library code should not write to stdout unconditionally.
inline void copy_float(Tensor tensor, std::vector<float> list) {
    tensor->copyData(list);
}
inline void run() { g->getRuntime()->run(g); } inline void run() { g->getRuntime()->run(g); }
}; };

View File

@@ -22,7 +22,7 @@ from onnx.checker import (
check_tensor, check_tensor,
) )
from onnx.shape_inference import infer_shapes from onnx.shape_inference import infer_shapes
from typing import Dict, List, Any, Tuple, Sequence from typing import Dict, List, Any, Tuple, Sequence, Union
from functools import reduce from functools import reduce
runtime = backend.cpu_runtime() runtime = backend.cpu_runtime()
@@ -324,6 +324,24 @@ def from_onnx(model: ModelProto) -> backend.GraphHandler:
else: else:
raise Exception('Unsupported operator "{}"'.format(node.op_type)) raise Exception('Unsupported operator "{}"'.format(node.op_type))
handler.data_malloc()
inputs = []
for name, obj in tensors.items():
tensor = data.get(name)
if tensor == None:
if any(input.name == name for input in model.graph.input):
inputs.append((name, tensor))
else:
if tensor.data_type == TensorProto.INT32:
handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
elif tensor.data_type == TensorProto.INT64:
handler.copy_int64(obj, [int(i) for i in tensor.int64_data])
elif tensor.data_type == TensorProto.FLOAT:
handler.copy_float(obj, [float(i) for i in tensor.float_data])
else:
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto: def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
class Context: class Context:
@@ -482,42 +500,6 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
return ctx.build(name) return ctx.build(name)
def parse_onnx(model: ModelProto):
    """Dump a human-readable summary of an ONNX model to stdout.

    Prints the model's top-level metadata fields, its ir_version and opset
    imports, every graph node in file order, and the names of all
    initializers (weights).
    """
    print()

    meta_fields = [
        "doc_string",
        "domain",
        "functions",
        "metadata_props",
        "model_version",
        "producer_name",
        "producer_version",
        "training_info",
    ]
    for meta in meta_fields:
        print(f"{meta}: {getattr(model, meta)}")

    print("ir_version:", model.ir_version)
    for op_set in model.opset_import:
        print(f"opset domain={op_set.domain} version={op_set.version}")

    print("layout:")
    for op in model.graph.node:
        attr_names = [attr.name for attr in op.attribute]
        print(f' {op.output} <- {op.op_type}"{op.name}"{attr_names} <- {op.input}')

    print("weight:")
    for initializer in model.graph.initializer:
        print(f" {initializer.name}")
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]: def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
for attr in node.attribute: for attr in node.attribute:
if attr.name in attrs: if attr.name in attrs:
@@ -536,11 +518,13 @@ def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[st
return attrs return attrs
def _parse_data(tensor: TensorProto) -> List[Union[int, float]]:
    """Extract an ONNX tensor's payload as a flat Python list.

    Supports INT32, INT64 and FLOAT tensors. Per the ONNX spec the values
    live either in the typed repeated field (``int32_data``/``int64_data``/
    ``float_data``) or packed little-endian in ``raw_data`` — exporters
    commonly use ``raw_data``, and the previous typed-field-only reading
    silently yielded an empty list for such tensors. Both layouts are
    handled here.

    Raises:
        ValueError: if the tensor has any other data type.
    """
    import struct  # local import keeps the module's import block untouched

    if tensor.data_type == TensorProto.INT32:
        if tensor.raw_data:
            # '<i' = little-endian 4-byte signed int, as mandated by ONNX.
            count = len(tensor.raw_data) // 4
            return list(struct.unpack("<{}i".format(count), tensor.raw_data))
        return [int(i) for i in tensor.int32_data]
    elif tensor.data_type == TensorProto.INT64:
        if tensor.raw_data:
            # '<q' = little-endian 8-byte signed int.
            count = len(tensor.raw_data) // 8
            return list(struct.unpack("<{}q".format(count), tensor.raw_data))
        return [int(i) for i in tensor.int64_data]
    elif tensor.data_type == TensorProto.FLOAT:
        if tensor.raw_data:
            # '<f' = little-endian IEEE-754 float32.
            count = len(tensor.raw_data) // 4
            return list(struct.unpack("<{}f".format(count), tensor.raw_data))
        return [float(i) for i in tensor.float_data]
    else:
        # raise instead of `assert` so the check survives `python -O`
        raise ValueError("Unsupported Tensor Type: {}".format(tensor.data_type))

View File

@@ -171,6 +171,9 @@ void init_graph_builder(py::module &m) {
.def("topo_sort", &Handler::topo_sort, policy::automatic) .def("topo_sort", &Handler::topo_sort, policy::automatic)
.def("operators", &Handler::operators, policy::move) .def("operators", &Handler::operators, policy::move)
.def("data_malloc", &Handler::data_malloc, policy::automatic) .def("data_malloc", &Handler::data_malloc, policy::automatic)
.def("copy_int32", &Handler::copy_int32, policy::automatic)
.def("copy_int64", &Handler::copy_int64, policy::automatic)
.def("copy_float", &Handler::copy_float, policy::automatic)
.def("run", &Handler::run, policy::automatic); .def("run", &Handler::run, policy::automatic);
} }