forked from jiuyuan/InfiniTensor
temp: 实现初始值导入,但 resnet 报错
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
4ffaa44c1e
commit
ed81861375
|
@ -2,6 +2,8 @@
|
|||
|
||||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
@ -87,6 +89,24 @@ class GraphHandlerObj {
|
|||
|
||||
inline void data_malloc() { g->dataMalloc(); }
|
||||
|
||||
inline void copy_int32(Tensor tensor, std::vector<int32_t> list) {
|
||||
std::cout << "copy " << list.size() << " ints to (" << tensor->size()
|
||||
<< ")" << std::endl;
|
||||
tensor->copyData(list);
|
||||
}
|
||||
|
||||
inline void copy_int64(Tensor tensor, std::vector<int64_t> list) {
|
||||
std::cout << "copy " << list.size() << " ints to (" << tensor->size()
|
||||
<< ")" << std::endl;
|
||||
tensor->copyData(list);
|
||||
}
|
||||
|
||||
inline void copy_float(Tensor tensor, std::vector<float> list) {
|
||||
std::cout << "copy " << list.size() << " floats to (" << tensor->size()
|
||||
<< ")" << std::endl;
|
||||
tensor->copyData(list);
|
||||
}
|
||||
|
||||
inline void run() { g->getRuntime()->run(g); }
|
||||
};
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ from onnx.checker import (
|
|||
check_tensor,
|
||||
)
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from typing import Dict, List, Any, Tuple, Sequence
|
||||
from typing import Dict, List, Any, Tuple, Sequence, Union
|
||||
from functools import reduce
|
||||
|
||||
runtime = backend.cpu_runtime()
|
||||
|
@ -324,6 +324,24 @@ def from_onnx(model: ModelProto) -> backend.GraphHandler:
|
|||
else:
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
|
||||
handler.data_malloc()
|
||||
|
||||
inputs = []
|
||||
for name, obj in tensors.items():
|
||||
tensor = data.get(name)
|
||||
if tensor == None:
|
||||
if any(input.name == name for input in model.graph.input):
|
||||
inputs.append((name, tensor))
|
||||
else:
|
||||
if tensor.data_type == TensorProto.INT32:
|
||||
handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
|
||||
elif tensor.data_type == TensorProto.INT64:
|
||||
handler.copy_int64(obj, [int(i) for i in tensor.int64_data])
|
||||
elif tensor.data_type == TensorProto.FLOAT:
|
||||
handler.copy_float(obj, [float(i) for i in tensor.float_data])
|
||||
else:
|
||||
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||
|
||||
|
||||
def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
||||
class Context:
|
||||
|
@ -482,42 +500,6 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
|||
return ctx.build(name)
|
||||
|
||||
|
||||
def parse_onnx(model: ModelProto):
|
||||
print()
|
||||
|
||||
for field in [
|
||||
"doc_string",
|
||||
"domain",
|
||||
"functions",
|
||||
"metadata_props",
|
||||
"model_version",
|
||||
"producer_name",
|
||||
"producer_version",
|
||||
"training_info",
|
||||
]:
|
||||
print("{}: {}".format(field, getattr(model, field)))
|
||||
|
||||
print("ir_version:", model.ir_version)
|
||||
for opset in model.opset_import:
|
||||
print("opset domain={} version={}".format(opset.domain, opset.version))
|
||||
|
||||
print("layout:")
|
||||
for node in model.graph.node:
|
||||
print(
|
||||
' {o} <- {op}"{name}"{a} <- {i}'.format(
|
||||
name=node.name,
|
||||
op=node.op_type,
|
||||
i=node.input,
|
||||
o=node.output,
|
||||
a=[a.name for a in node.attribute],
|
||||
)
|
||||
)
|
||||
|
||||
print("weight:")
|
||||
for node in model.graph.initializer:
|
||||
print(" {}".format(node.name))
|
||||
|
||||
|
||||
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
|
||||
for attr in node.attribute:
|
||||
if attr.name in attrs:
|
||||
|
@ -536,11 +518,13 @@ def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[st
|
|||
return attrs
|
||||
|
||||
|
||||
def _parse_data(tensor: TensorProto) -> List[int]:
|
||||
def _parse_data(tensor: TensorProto) -> List[Union[int, float]]:
|
||||
if tensor.data_type == TensorProto.INT32:
|
||||
return [int(i) for i in tensor.int32_data]
|
||||
elif tensor.data_type == TensorProto.INT64:
|
||||
return [int(i) for i in tensor.int64_data]
|
||||
elif tensor.data_type == TensorProto.FLOAT:
|
||||
return [float(i) for i in tensor.float_data]
|
||||
else:
|
||||
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||
|
||||
|
|
|
@ -171,6 +171,9 @@ void init_graph_builder(py::module &m) {
|
|||
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||
.def("operators", &Handler::operators, policy::move)
|
||||
.def("data_malloc", &Handler::data_malloc, policy::automatic)
|
||||
.def("copy_int32", &Handler::copy_int32, policy::automatic)
|
||||
.def("copy_int64", &Handler::copy_int64, policy::automatic)
|
||||
.def("copy_float", &Handler::copy_float, policy::automatic)
|
||||
.def("run", &Handler::run, policy::automatic);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue