forked from jiuyuan/InfiniTensor
feat: 导出未初始化的张量
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
6dce129cb3
commit
f44a4daf70
|
@ -90,20 +90,14 @@ class GraphHandlerObj {
|
|||
inline void data_malloc() { g->dataMalloc(); }
|
||||
|
||||
inline void copy_int32(Tensor tensor, std::vector<int32_t> list) {
|
||||
std::cout << "copy " << list.size() << " ints to (" << tensor->size()
|
||||
<< ")" << std::endl;
|
||||
tensor->copyData(list);
|
||||
}
|
||||
|
||||
inline void copy_int64(Tensor tensor, std::vector<int64_t> list) {
|
||||
std::cout << "copy " << list.size() << " ints to (" << tensor->size()
|
||||
<< ")" << std::endl;
|
||||
tensor->copyData(list);
|
||||
}
|
||||
|
||||
inline void copy_float(Tensor tensor, std::vector<float> list) {
|
||||
std::cout << "copy " << list.size() << " floats to (" << tensor->size()
|
||||
<< ")" << std::endl;
|
||||
tensor->copyData(list);
|
||||
}
|
||||
|
||||
|
|
|
@ -32,7 +32,9 @@ def cuda_runtime():
|
|||
return backend.cuda_runtime()
|
||||
|
||||
|
||||
def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler:
|
||||
def from_onnx(
|
||||
model: ModelProto, runtime
|
||||
) -> Tuple[Dict[str, backend.Tensor], backend.GraphHandler]:
|
||||
model = infer_shapes(model)
|
||||
handler = backend.GraphHandler(runtime)
|
||||
|
||||
|
@ -330,12 +332,13 @@ def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler:
|
|||
|
||||
handler.data_malloc()
|
||||
|
||||
inputs = []
|
||||
inputs: Dict[str, backend.Tensor] = {}
|
||||
for name, obj in tensors.items():
|
||||
print("{}: {}".format(name, obj))
|
||||
tensor = data.get(name)
|
||||
if tensor == None:
|
||||
if any(input.name == name for input in model.graph.input):
|
||||
inputs.append((name, tensor))
|
||||
inputs[name] = obj
|
||||
else:
|
||||
if tensor.data_type == TensorProto.INT32:
|
||||
handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
|
||||
|
@ -346,6 +349,8 @@ def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler:
|
|||
else:
|
||||
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||
|
||||
return inputs, handler
|
||||
|
||||
|
||||
def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
||||
class Context:
|
||||
|
|
Loading…
Reference in New Issue