feat: 导出 Reshape 到 onnx

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-21 14:30:06 +08:00
parent ffd0473bd2
commit 2a23669394
4 changed files with 26 additions and 19 deletions

View File

@ -28,6 +28,8 @@ class ReshapeObj : public OperatorObj {
int numInputs() const override { return 1; } int numInputs() const override { return 1; }
int numOutputs() const override { return 1; } int numOutputs() const override { return 1; }
inline Shape getShape() const { return dims; }
private: private:
vector<int> getWorkloadVector() const override; vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override; vector<int> getOpAttrVector() const override;

View File

@ -450,16 +450,17 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
elif ty == backend.OpType.Flatten: elif ty == backend.OpType.Flatten:
raise Exception("TODO") raise Exception("TODO")
elif ty == backend.OpType.Reshape: elif ty == backend.OpType.Reshape:
data = ctx.push_input(inputs[0]) shape = backend.reshape_shape_of(op)
# shape = context.push_data_input( inputs.append(
# name, ctx.push_data_input(
# "shape", name,
# TensorProto.INT32, "shape",
# shape=[len(vals)], TensorProto.INT32,
# vals=1, shape=[len(shape)],
# ) vals=shape,
# context.push_node(make_node(ty.name, [data, shape], [name], name)) )
raise Exception("TODO") )
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpType.Concat: elif ty == backend.OpType.Concat:
axis = backend.concat_axis_of(op) axis = backend.concat_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis)) ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))

View File

@ -296,20 +296,17 @@ class TestStringMethods(unittest.TestCase):
handler = backend.GraphHandler(runtime) handler = backend.GraphHandler(runtime)
a = handler.tensor([1, 2, 3], 12) a = handler.tensor([1, 2, 3], 12)
b = handler.tensor([1, 2, 3], 12) b = handler.tensor([1, 2, 3], 12)
ab = handler.tensor([1, 2, 3], 12)
c = handler.tensor([1, 2, 3], 12) c = handler.tensor([1, 2, 3], 12)
abc = handler.tensor([1, 2, 3], 12)
d = handler.tensor([1, 2, 3], 12) d = handler.tensor([1, 2, 3], 12)
abcd = handler.tensor([1, 2, 3], 12)
e = handler.tensor([1, 2, 3], 12) e = handler.tensor([1, 2, 3], 12)
abcde = handler.tensor([1, 2, 3], 12)
handler.add(a, b, ab) x = handler.add(
handler.add(ab, c, abc) handler.add(handler.add(handler.add(a, b, None), c, None), d, None), e, None
handler.add(abc, d, abcd) )
handler.add(abcd, e, abcde) y = handler.tensor([3, 2, 1], 12)
handler.reshape(x, y, [3, 2, 1])
to_onnx(handler, "add") to_onnx(handler, "test_frontend")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,6 +1,7 @@
#include "core/graph_handler.h" #include "core/graph_handler.h"
#include "operators/concat.h" #include "operators/concat.h"
#include "operators/gather.h" #include "operators/gather.h"
#include "operators/reshape.h"
#include <pybind11/stl.h> #include <pybind11/stl.h>
#ifdef USE_CUDA #ifdef USE_CUDA
@ -102,12 +103,18 @@ static int gather_axis_of(Operator op) {
return reinterpret_cast<const GatherObj *>(op.get())->getAxis(); return reinterpret_cast<const GatherObj *>(op.get())->getAxis();
} }
// Returns the target-shape attribute (`dims`) of a Reshape operator, for
// export to ONNX (the Python side calls this as backend.reshape_shape_of).
// The IT_ASSERT guards the downcast below: the cast is only valid when the
// operator really is a Reshape.
static Shape reshape_shape_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Reshape);
// NOTE(review): reinterpret_cast to a sibling/derived type — presumably
// ReshapeObj derives from the pointee type of Operator; static_cast would
// express that more safely if so. Verify against the class hierarchy.
return reinterpret_cast<const ReshapeObj *>(op.get())->getShape();
}
void init_graph_builder(py::module &m) { void init_graph_builder(py::module &m) {
using Handler = GraphHandlerObj; using Handler = GraphHandlerObj;
m.def("cpu_runtime", &CpuRuntimeObj::getInstance) m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
.def("tensor_dtype", &tensor_dtype) .def("tensor_dtype", &tensor_dtype)
.def("reshape_shape_of", &reshape_shape_of)
.def("concat_axis_of", &concat_axis_of) .def("concat_axis_of", &concat_axis_of)
.def("gather_axis_of", &gather_axis_of); .def("gather_axis_of", &gather_axis_of);
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime"); py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");