feat: 导出全图的输出张量到 onnx

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-24 15:02:52 +08:00
parent 59bf59c10b
commit 5b6698bac7
5 changed files with 25 additions and 16 deletions

View File

@ -56,16 +56,16 @@ class TensorObj : public TensorBaseObj {
Tensor clone() const {
auto obj = make_ref<TensorObj>(*this);
obj->freeData();
obj->inputOf.clear();
obj->outputOf.reset();
obj->targets.clear();
obj->source.reset();
return obj;
}
Tensor clone(Runtime runtime) const {
auto obj = make_ref<TensorObj>(*this);
obj->runtime = runtime;
obj->freeData();
obj->inputOf.clear();
obj->outputOf.reset();
obj->targets.clear();
obj->source.reset();
if (hasData()) {
obj->dataMalloc();
obj->copyData(this);

View File

@ -19,8 +19,8 @@ class TensorBaseObj : public Object {
int dim;
DataType dtype;
vector<WRef<OperatorObj>> inputOf;
WRef<OperatorObj> outputOf;
vector<WRef<OperatorObj>> targets;
WRef<OperatorObj> source;
Blob data;
Runtime runtime;
@ -46,10 +46,13 @@ class TensorBaseObj : public Object {
DataType getDType() const { return dtype; }
Runtime getRuntime() const { return runtime; }
void addInputOf(const Operator &op) { inputOf.emplace_back(op); }
void setOutputOf(const Operator &op) { outputOf = op; }
OpVec getInputOf() { return wrefs_to_refs(inputOf); }
Operator getOutputOf() { return outputOf.lock(); }
void addInputOf(const Operator &op) { targets.emplace_back(op); }
void setOutputOf(const Operator &op) { source = op; }
bool hasTarget() const { return !targets.empty(); }
OpVec getInputOf() const { return wrefs_to_refs(targets); }
Operator getOutputOf() const { return source.lock(); }
// std::pair<Operator *, int> getOutputOfWithIndex();
// bool setScalar(VType val) {

View File

@ -334,7 +334,6 @@ def from_onnx(
inputs: Dict[str, backend.Tensor] = {}
for name, obj in tensors.items():
print("{}: {}".format(name, obj))
tensor = data.get(name)
if tensor == None:
if any(input.name == name for input in model.graph.input):
@ -382,6 +381,12 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
def push_output(self, name: str, tensor: backend.Tensor) -> str:
self.names[tensor] = name
if not tensor.has_target():
shape = tensor.shape()
dtype = backend.tensor_dtype(tensor)
value_info = make_tensor_value_info(name, dtype, shape)
check_value_info(value_info)
self.outputs.append(value_info)
return name
def push_input(self, tensor: backend.Tensor) -> str:

View File

@ -18,13 +18,13 @@ string TensorObj::toString() const {
std::to_string(fuid) + ", shape " + vecToString(shape) +
", dtype " + dtype.toString();
vector<UidBaseType> inputOfGuid;
for (const auto &op : inputOf)
for (const auto &op : targets)
inputOfGuid.emplace_back(op.lock()->getGuid());
if (auto o = outputOf.lock())
ret += ", outputOf " + std::to_string(o->getGuid());
if (auto o = source.lock())
ret += ", source " + std::to_string(o->getGuid());
else
ret += ", outputOf None";
ret += ", inputOf " + vecToString(inputOfGuid);
ret += ", source None";
ret += ", targets " + vecToString(inputOfGuid);
return ret;
}

View File

@ -147,6 +147,7 @@ void init_graph_builder(py::module &m) {
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
.def("shape", &TensorObj::getDims, policy::move)
.def("cloneFloats", &TensorObj::cloneFloats, policy::move)
.def("has_target", &TensorObj::hasTarget, policy::automatic)
.def("src", &TensorObj::getOutputOf, policy::move);
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
.def("op_type", &OperatorObj::getOpType, policy::automatic)