forked from jiuyuan/InfiniTensor
feat: 导出 Reshape 到 onnx
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
ffd0473bd2
commit
2a23669394
|
@ -28,6 +28,8 @@ class ReshapeObj : public OperatorObj {
|
||||||
int numInputs() const override { return 1; }
|
int numInputs() const override { return 1; }
|
||||||
int numOutputs() const override { return 1; }
|
int numOutputs() const override { return 1; }
|
||||||
|
|
||||||
|
inline Shape getShape() const { return dims; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
vector<int> getWorkloadVector() const override;
|
vector<int> getWorkloadVector() const override;
|
||||||
vector<int> getOpAttrVector() const override;
|
vector<int> getOpAttrVector() const override;
|
||||||
|
|
|
@ -450,16 +450,17 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
||||||
elif ty == backend.OpType.Flatten:
|
elif ty == backend.OpType.Flatten:
|
||||||
raise Exception("TODO")
|
raise Exception("TODO")
|
||||||
elif ty == backend.OpType.Reshape:
|
elif ty == backend.OpType.Reshape:
|
||||||
data = ctx.push_input(inputs[0])
|
shape = backend.reshape_shape_of(op)
|
||||||
# shape = context.push_data_input(
|
inputs.append(
|
||||||
# name,
|
ctx.push_data_input(
|
||||||
# "shape",
|
name,
|
||||||
# TensorProto.INT32,
|
"shape",
|
||||||
# shape=[len(vals)],
|
TensorProto.INT32,
|
||||||
# vals=1,
|
shape=[len(shape)],
|
||||||
# )
|
vals=shape,
|
||||||
# context.push_node(make_node(ty.name, [data, shape], [name], name))
|
)
|
||||||
raise Exception("TODO")
|
)
|
||||||
|
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||||
elif ty == backend.OpType.Concat:
|
elif ty == backend.OpType.Concat:
|
||||||
axis = backend.concat_axis_of(op)
|
axis = backend.concat_axis_of(op)
|
||||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||||
|
|
|
@ -296,20 +296,17 @@ class TestStringMethods(unittest.TestCase):
|
||||||
handler = backend.GraphHandler(runtime)
|
handler = backend.GraphHandler(runtime)
|
||||||
a = handler.tensor([1, 2, 3], 12)
|
a = handler.tensor([1, 2, 3], 12)
|
||||||
b = handler.tensor([1, 2, 3], 12)
|
b = handler.tensor([1, 2, 3], 12)
|
||||||
ab = handler.tensor([1, 2, 3], 12)
|
|
||||||
c = handler.tensor([1, 2, 3], 12)
|
c = handler.tensor([1, 2, 3], 12)
|
||||||
abc = handler.tensor([1, 2, 3], 12)
|
|
||||||
d = handler.tensor([1, 2, 3], 12)
|
d = handler.tensor([1, 2, 3], 12)
|
||||||
abcd = handler.tensor([1, 2, 3], 12)
|
|
||||||
e = handler.tensor([1, 2, 3], 12)
|
e = handler.tensor([1, 2, 3], 12)
|
||||||
abcde = handler.tensor([1, 2, 3], 12)
|
|
||||||
|
|
||||||
handler.add(a, b, ab)
|
x = handler.add(
|
||||||
handler.add(ab, c, abc)
|
handler.add(handler.add(handler.add(a, b, None), c, None), d, None), e, None
|
||||||
handler.add(abc, d, abcd)
|
)
|
||||||
handler.add(abcd, e, abcde)
|
y = handler.tensor([3, 2, 1], 12)
|
||||||
|
handler.reshape(x, y, [3, 2, 1])
|
||||||
|
|
||||||
to_onnx(handler, "add")
|
to_onnx(handler, "test_frontend")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#include "core/graph_handler.h"
|
#include "core/graph_handler.h"
|
||||||
#include "operators/concat.h"
|
#include "operators/concat.h"
|
||||||
#include "operators/gather.h"
|
#include "operators/gather.h"
|
||||||
|
#include "operators/reshape.h"
|
||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
#ifdef USE_CUDA
|
#ifdef USE_CUDA
|
||||||
|
@ -102,12 +103,18 @@ static int gather_axis_of(Operator op) {
|
||||||
return reinterpret_cast<const GatherObj *>(op.get())->getAxis();
|
return reinterpret_cast<const GatherObj *>(op.get())->getAxis();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Shape reshape_shape_of(Operator op) {
|
||||||
|
IT_ASSERT(op->getOpType() == OpType::Reshape);
|
||||||
|
return reinterpret_cast<const ReshapeObj *>(op.get())->getShape();
|
||||||
|
}
|
||||||
|
|
||||||
void init_graph_builder(py::module &m) {
|
void init_graph_builder(py::module &m) {
|
||||||
|
|
||||||
using Handler = GraphHandlerObj;
|
using Handler = GraphHandlerObj;
|
||||||
|
|
||||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
||||||
.def("tensor_dtype", &tensor_dtype)
|
.def("tensor_dtype", &tensor_dtype)
|
||||||
|
.def("reshape_shape_of", &reshape_shape_of)
|
||||||
.def("concat_axis_of", &concat_axis_of)
|
.def("concat_axis_of", &concat_axis_of)
|
||||||
.def("gather_axis_of", &gather_axis_of);
|
.def("gather_axis_of", &gather_axis_of);
|
||||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||||
|
|
Loading…
Reference in New Issue