feat: export uninitialized tensors

Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster 2023-02-24 09:39:30 +08:00
parent 6dce129cb3
commit f44a4daf70
2 changed files with 8 additions and 9 deletions
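
from_onnx now returns the graph's uninitialized input tensors alongside the handler, so callers can fill them in after data_malloc(). A minimal usage sketch under stated assumptions: the module path pyinfinitensor.onnx and the file name model.onnx are illustrative, and cuda_runtime() is the wrapper visible in the first Python hunk below.

    import onnx
    from pyinfinitensor.onnx import from_onnx, cuda_runtime  # module path assumed

    model = onnx.load("model.onnx")  # illustrative file name
    # inputs maps each uninitialized graph-input name to its backend.Tensor;
    # handler owns the constructed graph and its allocated buffers.
    inputs, handler = from_onnx(model, cuda_runtime())
    for name, tensor in inputs.items():
        print(name, tensor)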


@@ -90,20 +90,14 @@ class GraphHandlerObj {
     inline void data_malloc() { g->dataMalloc(); }
 
     inline void copy_int32(Tensor tensor, std::vector<int32_t> list) {
-        std::cout << "copy " << list.size() << " ints to (" << tensor->size()
-                  << ")" << std::endl;
         tensor->copyData(list);
     }
 
     inline void copy_int64(Tensor tensor, std::vector<int64_t> list) {
-        std::cout << "copy " << list.size() << " ints to (" << tensor->size()
-                  << ")" << std::endl;
         tensor->copyData(list);
     }
 
     inline void copy_float(Tensor tensor, std::vector<float> list) {
-        std::cout << "copy " << list.size() << " floats to (" << tensor->size()
-                  << ")" << std::endl;
         tensor->copyData(list);
     }


@@ -32,7 +32,9 @@ def cuda_runtime():
     return backend.cuda_runtime()
 
 
-def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler:
+def from_onnx(
+    model: ModelProto, runtime
+) -> Tuple[Dict[str, backend.Tensor], backend.GraphHandler]:
     model = infer_shapes(model)
     handler = backend.GraphHandler(runtime)
@@ -330,12 +332,13 @@ def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler:
     handler.data_malloc()
 
-    inputs = []
+    inputs: Dict[str, backend.Tensor] = {}
     for name, obj in tensors.items():
         print("{}: {}".format(name, obj))
         tensor = data.get(name)
         if tensor == None:
-            inputs.append((name, tensor))
+            if any(input.name == name for input in model.graph.input):
+                inputs[name] = obj
         else:
             if tensor.data_type == TensorProto.INT32:
                 handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
@@ -346,6 +349,8 @@ def from_onnx(model: ModelProto, runtime) -> backend.GraphHandler:
             else:
                 assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
 
+    return inputs, handler
+
 
 def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
     class Context: