forked from jiuyuan/InfiniTensor
Add: different ONNX names for inputs and weights
parent 8bc2d3e48d
commit 0cb8729bc1
@@ -49,13 +49,20 @@ class OnnxStub:
         tensors: Dict[str, backend.Tensor] = dict()
         data: Dict[str, TensorProto] = dict()
 
+        cnt_infini_inputs = 0
         for input in model.graph.input:
             dims = _take_shape_dim(input.type.tensor_type.shape)
+            if input.name.startswith('input'):
+                tensor_type = backend.TensorType.Input
+                cnt_infini_inputs += 1
+            else:
+                tensor_type = backend.TensorType.Initialized
             tensors[input.name] = ans.handler.tensor(
                 dims,
                 input.type.tensor_type.elem_type,
-                backend.TensorType.Input,
+                tensor_type,
             )
+        assert cnt_infini_inputs == 1, f'{cnt_infini_inputs} tensor names start with "input" found.'
 
         for output in model.graph.output:
             dims = _take_shape_dim(output.type.tensor_type.shape)
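The import-side rule above keys tensor types purely off ONNX tensor names. A minimal sketch of the same classification, assuming a plain onnx.ModelProto (the helper name classify_inputs is hypothetical):

    import onnx

    def classify_inputs(model: onnx.ModelProto) -> dict:
        # Graph inputs named "input..." become runtime inputs; everything
        # else (e.g. "weight0", "bias1") is treated as initialized data.
        kinds = {
            inp.name: ("Input" if inp.name.startswith("input") else "Initialized")
            for inp in model.graph.input
        }
        # Mirrors the assert in the hunk: exactly one runtime input expected.
        assert list(kinds.values()).count("Input") == 1, kinds
        return kinds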
@@ -625,7 +632,10 @@ class OnnxStub:
            # means that this input is a global input
            if name is None:
                self.count_in += 1
-               name = "input{}".format(self.count_in)
+               if tensor.getTensorType() == backend.TensorType.Input:
+                   name = "input{}".format(self.count_in)
+               else:
+                   name = "weight{}".format(self.count_in)
            self.names[tensor] = name
            if init != None:
                init.name = name
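On export, unnamed global tensors are now numbered by kind. A hedged sketch of the effect (the backend enum and the names/count_in attributes are assumed from the surrounding context class):

    def assign_name(self, tensor, init=None):
        # Illustrative only: one shared counter, two prefixes, so a model
        # may yield e.g. "input1", "weight2", "weight3".
        self.count_in += 1
        if tensor.getTensorType() == backend.TensorType.Input:
            name = "input{}".format(self.count_in)
        else:
            name = "weight{}".format(self.count_in)
        self.names[tensor] = name
        if init is not None:  # idiomatic form of `init != None`
            init.name = name
        return name

Note that both kinds share the single count_in counter, so the numbering is global rather than per prefix.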
@@ -105,8 +105,7 @@ void export_values(py::module &m) {
     py::enum_<TensorType>(m, "TensorType")
         .VALUE(TensorType, Input)
         .VALUE(TensorType, Initialized)
-        .VALUE(TensorType, Other)
-        .export_values();
+        .VALUE(TensorType, Other);
 #undef VALUE
 }
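Dropping .export_values() keeps the enumerators scoped to the enum type on the Python side; a sketch of the visible difference (the import path is an assumption):

    from pyinfinitensor import backend  # assumed import path

    t = backend.TensorType.Input  # scoped access, still available
    # backend.Input would only exist if .export_values() had injected the
    # enumerators into module scope, which this change removes.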
@@ -309,7 +308,8 @@ void init_graph_builder(py::module &m) {
         .def("has_target", &TensorObj::hasTarget, policy::automatic)
         .def("src", &TensorObj::getSource, policy::move)
         .def("print_data", &TensorObj::printData)
-        .def("data_malloc", &TensorObj::dataMalloc);
+        .def("data_malloc", &TensorObj::dataMalloc)
+        .def("getTensorType", &TensorObj::getTensorType);
     py::class_<OperatorObj, std::shared_ptr<OperatorObj>, Object>(m, "Operator")
         .def("op_type", &OperatorObj::getOpType, policy::automatic)
         .def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
@@ -378,22 +378,12 @@ void init_graph_builder(py::module &m) {
         .def("topo_sort", &GraphObj::topo_sort);
 }
 
-Graph getInfoGAN(int batch, Runtime runtime, int nLayers) {
-    IT_TODO_HALT();
-    return nullptr;
-}
-vector<Tensor> runInfoGAN(int nLayers) {
-    IT_TODO_HALT();
-    return {};
-}
-Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId) {
-    IT_TODO_HALT();
-    return nullptr;
-}
-Graph optimizeGraph(Graph g, Runtime runtime, bool tuning) {
-    IT_TODO_HALT();
-    return nullptr;
-}
+#ifdef USE_CUDA
+Graph getInfoGAN(int batch, Runtime runtime, int nLayers);
+vector<Tensor> runInfoGAN(int nLayers);
+Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId);
+Graph optimizeGraph(Graph g, Runtime runtime, bool tuning);
+
 void export_test_model(py::module &m) {
     m.def("runInfoGAN", &runInfoGAN);
     m.def("getInfoGAN", &getInfoGAN);
@@ -401,6 +391,7 @@ void export_test_model(py::module &m) {
     m.def("optimizeGraph", &optimizeGraph, "graph"_a, "runtime"_a,
           "tuning"_a = false);
 }
+#endif
 
 } // namespace infini
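With the TODO stubs gone and export_test_model guarded by USE_CUDA, a CPU-only build no longer exposes these test-model functions at all, instead of exporting stubs that halt at call time. A hedged Python-side check (module name assumed):

    if hasattr(backend, "runInfoGAN"):
        tensors = backend.runInfoGAN(1)  # nLayers, positional
    else:
        print("test models unavailable: built without USE_CUDA")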