Add: different ONNX names for inputs and weights

This commit is contained in:
Liyan Zheng 2023-04-20 21:48:17 +08:00
parent 8bc2d3e48d
commit 0cb8729bc1
2 changed files with 22 additions and 21 deletions

View File

@ -49,13 +49,20 @@ class OnnxStub:
tensors: Dict[str, backend.Tensor] = dict()
data: Dict[str, TensorProto] = dict()
cnt_infini_inputs = 0
for input in model.graph.input:
dims = _take_shape_dim(input.type.tensor_type.shape)
if input.name.startswith('input'):
tensor_type = backend.TensorType.Input
cnt_infini_inputs += 1
else:
tensor_type = backend.TensorType.Initialized
tensors[input.name] = ans.handler.tensor(
dims,
input.type.tensor_type.elem_type,
backend.TensorType.Input,
tensor_type,
)
assert cnt_infini_inputs == 1, f'{cnt_infini_inputs} tensor names start with "input" found.'
for output in model.graph.output:
dims = _take_shape_dim(output.type.tensor_type.shape)
@ -625,7 +632,10 @@ class OnnxStub:
# means that this input is a global input
if name is None:
self.count_in += 1
name = "input{}".format(self.count_in)
if tensor.getTensorType() == backend.TensorType.Input:
name = "input{}".format(self.count_in)
else:
name = "weight{}".format(self.count_in)
self.names[tensor] = name
if init != None:
init.name = name

View File

@ -105,8 +105,7 @@ void export_values(py::module &m) {
py::enum_<TensorType>(m, "TensorType")
.VALUE(TensorType, Input)
.VALUE(TensorType, Initialized)
.VALUE(TensorType, Other)
.export_values();
.VALUE(TensorType, Other);
#undef VALUE
}
@ -309,7 +308,8 @@ void init_graph_builder(py::module &m) {
.def("has_target", &TensorObj::hasTarget, policy::automatic)
.def("src", &TensorObj::getSource, policy::move)
.def("print_data", &TensorObj::printData)
.def("data_malloc", &TensorObj::dataMalloc);
.def("data_malloc", &TensorObj::dataMalloc)
.def("getTensorType", &TensorObj::getTensorType);
py::class_<OperatorObj, std::shared_ptr<OperatorObj>, Object>(m, "Operator")
.def("op_type", &OperatorObj::getOpType, policy::automatic)
.def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
@ -378,22 +378,12 @@ void init_graph_builder(py::module &m) {
.def("topo_sort", &GraphObj::topo_sort);
}
Graph getInfoGAN(int batch, Runtime runtime, int nLayers) {
IT_TODO_HALT();
return nullptr;
}
vector<Tensor> runInfoGAN(int nLayers) {
IT_TODO_HALT();
return {};
}
Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId) {
IT_TODO_HALT();
return nullptr;
}
Graph optimizeGraph(Graph g, Runtime runtime, bool tuning) {
IT_TODO_HALT();
return nullptr;
}
#ifdef USE_CUDA
Graph getInfoGAN(int batch, Runtime runtime, int nLayers);
vector<Tensor> runInfoGAN(int nLayers);
Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId);
Graph optimizeGraph(Graph g, Runtime runtime, bool tuning);
void export_test_model(py::module &m) {
m.def("runInfoGAN", &runInfoGAN);
m.def("getInfoGAN", &getInfoGAN);
@ -401,6 +391,7 @@ void export_test_model(py::module &m) {
m.def("optimizeGraph", &optimizeGraph, "graph"_a, "runtime"_a,
"tuning"_a = false);
}
#endif
} // namespace infini