forked from jiuyuan/InfiniTensor
feat: 导出全图的输出张量到 onnx
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
59bf59c10b
commit
5b6698bac7
|
@ -56,16 +56,16 @@ class TensorObj : public TensorBaseObj {
|
|||
Tensor clone() const {
|
||||
auto obj = make_ref<TensorObj>(*this);
|
||||
obj->freeData();
|
||||
obj->inputOf.clear();
|
||||
obj->outputOf.reset();
|
||||
obj->targets.clear();
|
||||
obj->source.reset();
|
||||
return obj;
|
||||
}
|
||||
Tensor clone(Runtime runtime) const {
|
||||
auto obj = make_ref<TensorObj>(*this);
|
||||
obj->runtime = runtime;
|
||||
obj->freeData();
|
||||
obj->inputOf.clear();
|
||||
obj->outputOf.reset();
|
||||
obj->targets.clear();
|
||||
obj->source.reset();
|
||||
if (hasData()) {
|
||||
obj->dataMalloc();
|
||||
obj->copyData(this);
|
||||
|
|
|
@ -19,8 +19,8 @@ class TensorBaseObj : public Object {
|
|||
int dim;
|
||||
|
||||
DataType dtype;
|
||||
vector<WRef<OperatorObj>> inputOf;
|
||||
WRef<OperatorObj> outputOf;
|
||||
vector<WRef<OperatorObj>> targets;
|
||||
WRef<OperatorObj> source;
|
||||
Blob data;
|
||||
Runtime runtime;
|
||||
|
||||
|
@ -46,10 +46,13 @@ class TensorBaseObj : public Object {
|
|||
DataType getDType() const { return dtype; }
|
||||
Runtime getRuntime() const { return runtime; }
|
||||
|
||||
void addInputOf(const Operator &op) { inputOf.emplace_back(op); }
|
||||
void setOutputOf(const Operator &op) { outputOf = op; }
|
||||
OpVec getInputOf() { return wrefs_to_refs(inputOf); }
|
||||
Operator getOutputOf() { return outputOf.lock(); }
|
||||
void addInputOf(const Operator &op) { targets.emplace_back(op); }
|
||||
void setOutputOf(const Operator &op) { source = op; }
|
||||
|
||||
bool hasTarget() const { return !targets.empty(); }
|
||||
|
||||
OpVec getInputOf() const { return wrefs_to_refs(targets); }
|
||||
Operator getOutputOf() const { return source.lock(); }
|
||||
// std::pair<Operator *, int> getOutputOfWithIndex();
|
||||
|
||||
// bool setScalar(VType val) {
|
||||
|
|
|
@ -334,7 +334,6 @@ def from_onnx(
|
|||
|
||||
inputs: Dict[str, backend.Tensor] = {}
|
||||
for name, obj in tensors.items():
|
||||
print("{}: {}".format(name, obj))
|
||||
tensor = data.get(name)
|
||||
if tensor == None:
|
||||
if any(input.name == name for input in model.graph.input):
|
||||
|
@ -382,6 +381,12 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
|||
|
||||
def push_output(self, name: str, tensor: backend.Tensor) -> str:
|
||||
self.names[tensor] = name
|
||||
if not tensor.has_target():
|
||||
shape = tensor.shape()
|
||||
dtype = backend.tensor_dtype(tensor)
|
||||
value_info = make_tensor_value_info(name, dtype, shape)
|
||||
check_value_info(value_info)
|
||||
self.outputs.append(value_info)
|
||||
return name
|
||||
|
||||
def push_input(self, tensor: backend.Tensor) -> str:
|
||||
|
|
|
@ -18,13 +18,13 @@ string TensorObj::toString() const {
|
|||
std::to_string(fuid) + ", shape " + vecToString(shape) +
|
||||
", dtype " + dtype.toString();
|
||||
vector<UidBaseType> inputOfGuid;
|
||||
for (const auto &op : inputOf)
|
||||
for (const auto &op : targets)
|
||||
inputOfGuid.emplace_back(op.lock()->getGuid());
|
||||
if (auto o = outputOf.lock())
|
||||
ret += ", outputOf " + std::to_string(o->getGuid());
|
||||
if (auto o = source.lock())
|
||||
ret += ", source " + std::to_string(o->getGuid());
|
||||
else
|
||||
ret += ", outputOf None";
|
||||
ret += ", inputOf " + vecToString(inputOfGuid);
|
||||
ret += ", source None";
|
||||
ret += ", targets " + vecToString(inputOfGuid);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
@ -147,6 +147,7 @@ void init_graph_builder(py::module &m) {
|
|||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
||||
.def("shape", &TensorObj::getDims, policy::move)
|
||||
.def("cloneFloats", &TensorObj::cloneFloats, policy::move)
|
||||
.def("has_target", &TensorObj::hasTarget, policy::automatic)
|
||||
.def("src", &TensorObj::getOutputOf, policy::move);
|
||||
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
||||
.def("op_type", &OperatorObj::getOpType, policy::automatic)
|
||||
|
|
Loading…
Reference in New Issue