From 5b6698bac73adc641627da8122e671d07d3734f5 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 24 Feb 2023 15:02:52 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AF=BC=E5=87=BA=E5=85=A8=E5=9B=BE?= =?UTF-8?q?=E7=9A=84=E8=BE=93=E5=87=BA=E5=BC=A0=E9=87=8F=E5=88=B0=20onnx?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- include/core/tensor.h | 8 ++++---- include/core/tensor_base.h | 15 +++++++++------ pyinfinitensor/src/pyinfinitensor/onnx.py | 7 ++++++- src/core/tensor.cc | 10 +++++----- src/ffi/ffi_infinitensor.cc | 1 + 5 files changed, 25 insertions(+), 16 deletions(-) diff --git a/include/core/tensor.h b/include/core/tensor.h index 422355e9..1edc950a 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -56,16 +56,16 @@ class TensorObj : public TensorBaseObj { Tensor clone() const { auto obj = make_ref(*this); obj->freeData(); - obj->inputOf.clear(); - obj->outputOf.reset(); + obj->targets.clear(); + obj->source.reset(); return obj; } Tensor clone(Runtime runtime) const { auto obj = make_ref(*this); obj->runtime = runtime; obj->freeData(); - obj->inputOf.clear(); - obj->outputOf.reset(); + obj->targets.clear(); + obj->source.reset(); if (hasData()) { obj->dataMalloc(); obj->copyData(this); diff --git a/include/core/tensor_base.h b/include/core/tensor_base.h index af2b97c3..09286a63 100644 --- a/include/core/tensor_base.h +++ b/include/core/tensor_base.h @@ -19,8 +19,8 @@ class TensorBaseObj : public Object { int dim; DataType dtype; - vector> inputOf; - WRef outputOf; + vector> targets; + WRef source; Blob data; Runtime runtime; @@ -46,10 +46,13 @@ class TensorBaseObj : public Object { DataType getDType() const { return dtype; } Runtime getRuntime() const { return runtime; } - void addInputOf(const Operator &op) { inputOf.emplace_back(op); } - void setOutputOf(const Operator &op) { outputOf = op; } - OpVec getInputOf() { return wrefs_to_refs(inputOf); } - Operator getOutputOf() { return outputOf.lock(); } + void addInputOf(const Operator &op) { targets.emplace_back(op); } + void setOutputOf(const Operator &op) { source = op; } + + bool hasTarget() const { return !targets.empty(); } + + OpVec getInputOf() const { return wrefs_to_refs(targets); } + Operator getOutputOf() const { return source.lock(); } // std::pair getOutputOfWithIndex(); // bool setScalar(VType val) { diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 6b3e394c..f8fd662c 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -334,7 +334,6 @@ def from_onnx( inputs: Dict[str, backend.Tensor] = {} for name, obj in tensors.items(): - print("{}: {}".format(name, obj)) tensor = data.get(name) if tensor == None: if any(input.name == name for input in model.graph.input): @@ -382,6 +381,12 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto: def push_output(self, name: str, tensor: backend.Tensor) -> str: self.names[tensor] = name + if not tensor.has_target(): + shape = tensor.shape() + dtype = backend.tensor_dtype(tensor) + value_info = make_tensor_value_info(name, dtype, shape) + check_value_info(value_info) + self.outputs.append(value_info) return name def push_input(self, tensor: backend.Tensor) -> str: diff --git a/src/core/tensor.cc b/src/core/tensor.cc index 89294a8e..5f04e114 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -18,13 +18,13 @@ string TensorObj::toString() const { std::to_string(fuid) + ", shape " + vecToString(shape) + ", dtype " + dtype.toString(); vector inputOfGuid; - for (const auto &op : inputOf) + for (const auto &op : targets) inputOfGuid.emplace_back(op.lock()->getGuid()); - if (auto o = outputOf.lock()) - ret += ", outputOf " + std::to_string(o->getGuid()); + if (auto o = source.lock()) + ret += ", source " + std::to_string(o->getGuid()); else - ret += ", outputOf None"; - ret += ", inputOf " + vecToString(inputOfGuid); + ret += ", source None"; + ret += ", targets " + vecToString(inputOfGuid); return ret; } diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 2be4971b..53e1376c 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -147,6 +147,7 @@ void init_graph_builder(py::module &m) { py::class_>(m, "Tensor") .def("shape", &TensorObj::getDims, policy::move) .def("cloneFloats", &TensorObj::cloneFloats, policy::move) + .def("has_target", &TensorObj::hasTarget, policy::automatic) .def("src", &TensorObj::getOutputOf, policy::move); py::class_>(m, "Operator") .def("op_type", &OperatorObj::getOpType, policy::automatic)