forked from jiuyuan/InfiniTensor
feat: 导出 Reshape 到 onnx
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
ffd0473bd2
commit
2a23669394
|
@ -28,6 +28,8 @@ class ReshapeObj : public OperatorObj {
|
|||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
inline Shape getShape() const { return dims; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
|
|
|
@ -450,16 +450,17 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
|||
elif ty == backend.OpType.Flatten:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Reshape:
|
||||
data = ctx.push_input(inputs[0])
|
||||
# shape = context.push_data_input(
|
||||
# name,
|
||||
# "shape",
|
||||
# TensorProto.INT32,
|
||||
# shape=[len(vals)],
|
||||
# vals=1,
|
||||
# )
|
||||
# context.push_node(make_node(ty.name, [data, shape], [name], name))
|
||||
raise Exception("TODO")
|
||||
shape = backend.reshape_shape_of(op)
|
||||
inputs.append(
|
||||
ctx.push_data_input(
|
||||
name,
|
||||
"shape",
|
||||
TensorProto.INT32,
|
||||
shape=[len(shape)],
|
||||
vals=shape,
|
||||
)
|
||||
)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
elif ty == backend.OpType.Concat:
|
||||
axis = backend.concat_axis_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||
|
|
|
@ -296,20 +296,17 @@ class TestStringMethods(unittest.TestCase):
|
|||
handler = backend.GraphHandler(runtime)
|
||||
a = handler.tensor([1, 2, 3], 12)
|
||||
b = handler.tensor([1, 2, 3], 12)
|
||||
ab = handler.tensor([1, 2, 3], 12)
|
||||
c = handler.tensor([1, 2, 3], 12)
|
||||
abc = handler.tensor([1, 2, 3], 12)
|
||||
d = handler.tensor([1, 2, 3], 12)
|
||||
abcd = handler.tensor([1, 2, 3], 12)
|
||||
e = handler.tensor([1, 2, 3], 12)
|
||||
abcde = handler.tensor([1, 2, 3], 12)
|
||||
|
||||
handler.add(a, b, ab)
|
||||
handler.add(ab, c, abc)
|
||||
handler.add(abc, d, abcd)
|
||||
handler.add(abcd, e, abcde)
|
||||
x = handler.add(
|
||||
handler.add(handler.add(handler.add(a, b, None), c, None), d, None), e, None
|
||||
)
|
||||
y = handler.tensor([3, 2, 1], 12)
|
||||
handler.reshape(x, y, [3, 2, 1])
|
||||
|
||||
to_onnx(handler, "add")
|
||||
to_onnx(handler, "test_frontend")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include "operators/concat.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/reshape.h"
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
@ -102,12 +103,18 @@ static int gather_axis_of(Operator op) {
|
|||
return reinterpret_cast<const GatherObj *>(op.get())->getAxis();
|
||||
}
|
||||
|
||||
static Shape reshape_shape_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Reshape);
|
||||
return reinterpret_cast<const ReshapeObj *>(op.get())->getShape();
|
||||
}
|
||||
|
||||
void init_graph_builder(py::module &m) {
|
||||
|
||||
using Handler = GraphHandlerObj;
|
||||
|
||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
||||
.def("tensor_dtype", &tensor_dtype)
|
||||
.def("reshape_shape_of", &reshape_shape_of)
|
||||
.def("concat_axis_of", &concat_axis_of)
|
||||
.def("gather_axis_of", &gather_axis_of);
|
||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||
|
|
Loading…
Reference in New Issue