feat: 导出 Reshape 到 onnx

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-21 14:30:06 +08:00
parent ffd0473bd2
commit 2a23669394
4 changed files with 26 additions and 19 deletions

View File

@ -28,6 +28,8 @@ class ReshapeObj : public OperatorObj {
int numInputs() const override { return 1; }
int numOutputs() const override { return 1; }
inline Shape getShape() const { return dims; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;

View File

@ -450,16 +450,17 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
elif ty == backend.OpType.Flatten:
raise Exception("TODO")
elif ty == backend.OpType.Reshape:
data = ctx.push_input(inputs[0])
# shape = context.push_data_input(
# name,
# "shape",
# TensorProto.INT32,
# shape=[len(vals)],
# vals=1,
# )
# context.push_node(make_node(ty.name, [data, shape], [name], name))
raise Exception("TODO")
shape = backend.reshape_shape_of(op)
inputs.append(
ctx.push_data_input(
name,
"shape",
TensorProto.INT32,
shape=[len(shape)],
vals=shape,
)
)
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpType.Concat:
axis = backend.concat_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))

View File

@ -296,20 +296,17 @@ class TestStringMethods(unittest.TestCase):
handler = backend.GraphHandler(runtime)
a = handler.tensor([1, 2, 3], 12)
b = handler.tensor([1, 2, 3], 12)
ab = handler.tensor([1, 2, 3], 12)
c = handler.tensor([1, 2, 3], 12)
abc = handler.tensor([1, 2, 3], 12)
d = handler.tensor([1, 2, 3], 12)
abcd = handler.tensor([1, 2, 3], 12)
e = handler.tensor([1, 2, 3], 12)
abcde = handler.tensor([1, 2, 3], 12)
handler.add(a, b, ab)
handler.add(ab, c, abc)
handler.add(abc, d, abcd)
handler.add(abcd, e, abcde)
x = handler.add(
handler.add(handler.add(handler.add(a, b, None), c, None), d, None), e, None
)
y = handler.tensor([3, 2, 1], 12)
handler.reshape(x, y, [3, 2, 1])
to_onnx(handler, "add")
to_onnx(handler, "test_frontend")
if __name__ == "__main__":

View File

@ -1,6 +1,7 @@
#include "core/graph_handler.h"
#include "operators/concat.h"
#include "operators/gather.h"
#include "operators/reshape.h"
#include <pybind11/stl.h>
#ifdef USE_CUDA
@ -102,12 +103,18 @@ static int gather_axis_of(Operator op) {
return reinterpret_cast<const GatherObj *>(op.get())->getAxis();
}
static Shape reshape_shape_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Reshape);
return reinterpret_cast<const ReshapeObj *>(op.get())->getShape();
}
void init_graph_builder(py::module &m) {
using Handler = GraphHandlerObj;
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
.def("tensor_dtype", &tensor_dtype)
.def("reshape_shape_of", &reshape_shape_of)
.def("concat_axis_of", &concat_axis_of)
.def("gather_axis_of", &gather_axis_of);
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");