feat: 前端支持 reshape

- 无法测试,因为后端不支持 shape 的 INT64 类型

opt: ReshapeObj 构造改为全部传值并在内部 move
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-14 09:50:32 +08:00
parent ee0a562006
commit 7626efbfa8
7 changed files with 51 additions and 19 deletions

View File

@ -57,6 +57,7 @@ class GraphHandlerObj {
Tensor abs(Tensor x, Tensor y); Tensor abs(Tensor x, Tensor y);
Tensor identity(Tensor x, Tensor y); Tensor identity(Tensor x, Tensor y);
Tensor flatten(Tensor s, Tensor y); Tensor flatten(Tensor s, Tensor y);
Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
}; };
} // namespace infini } // namespace infini

View File

@ -19,7 +19,7 @@ class ReshapeObj : public OperatorObj {
* @param output The output tensor. * @param output The output tensor.
* @param dims The shape of the output tensor. * @param dims The shape of the output tensor.
*/ */
ReshapeObj(GraphObj *graph, Tensor input, Tensor output, const Shape &dims); ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
OP_CLONE(ReshapeObj); OP_CLONE(ReshapeObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -1,4 +1,4 @@
import typing, onnx, backend import onnx, backend
runtime = backend.cpu_runtime() runtime = backend.cpu_runtime()
@ -21,7 +21,7 @@ def from_onnx(model: onnx.ModelProto):
tensors[node.output[0]] = handler.matmul( tensors[node.output[0]] = handler.matmul(
tensors[node.input[0]], tensors[node.input[0]],
tensors[node.input[1]], tensors[node.input[1]],
tensors.get(node.output[0], None), tensors.get(node.output[0]),
False, False,
False, False,
None, None,
@ -31,7 +31,7 @@ def from_onnx(model: onnx.ModelProto):
(input, mean, var, scale, bias) = ( (input, mean, var, scale, bias) = (
tensors[node.input[i]] for i in [0, 3, 4, 1, 2] tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
) )
output = tensors.get(node.output[0], None) output = tensors.get(node.output[0])
attributes = _parse_attribute( attributes = _parse_attribute(
node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0} node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
) )
@ -45,61 +45,61 @@ def from_onnx(model: onnx.ModelProto):
tensors[node.output[0]] = handler.add( tensors[node.output[0]] = handler.add(
tensors[node.input[0]], tensors[node.input[0]],
tensors[node.input[1]], tensors[node.input[1]],
tensors.get(node.output[0], None), tensors.get(node.output[0]),
) )
elif node.op_type == "Sub": elif node.op_type == "Sub":
tensors[node.output[0]] = handler.sub( tensors[node.output[0]] = handler.sub(
tensors[node.input[0]], tensors[node.input[0]],
tensors[node.input[1]], tensors[node.input[1]],
tensors.get(node.output[0], None), tensors.get(node.output[0]),
) )
elif node.op_type == "Mul": elif node.op_type == "Mul":
tensors[node.output[0]] = handler.mul( tensors[node.output[0]] = handler.mul(
tensors[node.input[0]], tensors[node.input[0]],
tensors[node.input[1]], tensors[node.input[1]],
tensors.get(node.output[0], None), tensors.get(node.output[0]),
) )
elif node.op_type == "Div": elif node.op_type == "Div":
tensors[node.output[0]] = handler.div( tensors[node.output[0]] = handler.div(
tensors[node.input[0]], tensors[node.input[0]],
tensors[node.input[1]], tensors[node.input[1]],
tensors.get(node.output[0], None), tensors.get(node.output[0]),
) )
elif node.op_type == "Pow": elif node.op_type == "Pow":
tensors[node.output[0]] = handler.pow( tensors[node.output[0]] = handler.pow(
tensors[node.input[0]], tensors[node.input[0]],
tensors[node.input[1]], tensors[node.input[1]],
tensors.get(node.output[0], None), tensors.get(node.output[0]),
) )
elif node.op_type == "Relu": elif node.op_type == "Relu":
tensors[node.output[0]] = handler.relu( tensors[node.output[0]] = handler.relu(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0], None), tensors.get(node.output[0]),
) )
elif node.op_type == "Sigmoid": elif node.op_type == "Sigmoid":
tensors[node.output[0]] = handler.sigmoid( tensors[node.output[0]] = handler.sigmoid(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0], None), tensors.get(node.output[0]),
) )
elif node.op_type == "Tanh": elif node.op_type == "Tanh":
tensors[node.output[0]] = handler.tanh( tensors[node.output[0]] = handler.tanh(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0], None), tensors.get(node.output[0]),
) )
elif node.op_type == "Softmax": elif node.op_type == "Softmax":
tensors[node.output[0]] = handler.softmax( tensors[node.output[0]] = handler.softmax(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0], None), tensors.get(node.output[0]),
) )
elif node.op_type == "Abs": elif node.op_type == "Abs":
tensors[node.output[0]] = handler.abs( tensors[node.output[0]] = handler.abs(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0], None), tensors.get(node.output[0]),
) )
elif node.op_type == "Identity": elif node.op_type == "Identity":
tensors[node.output[0]] = handler.identity( tensors[node.output[0]] = handler.identity(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0], None), tensors.get(node.output[0]),
) )
elif node.op_type == "Flatten": elif node.op_type == "Flatten":
# TODO 后端算子不支持沿任意轴展开 # TODO 后端算子不支持沿任意轴展开
@ -109,7 +109,13 @@ def from_onnx(model: onnx.ModelProto):
assert axis == None or axis == 1 assert axis == None or axis == 1
tensors[node.output[0]] = handler.flatten( tensors[node.output[0]] = handler.flatten(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0], None), tensors.get(node.output[0]),
)
elif node.op_type == "Reshape":
tensors[node.output[0]] = handler.reshape(
tensors[node.input[0]],
tensors.get(node.output[0]),
[int(i) for i in tensors[node.input[1]]],
) )
else: else:
raise Exception('Unsupported operator "{}"'.format(node.op_type)) raise Exception('Unsupported operator "{}"'.format(node.op_type))

View File

@ -129,6 +129,16 @@ class TestStringMethods(unittest.TestCase):
flatten = make_node("Flatten", ["x"], ["y"], name="flatten") flatten = make_node("Flatten", ["x"], ["y"], name="flatten")
make_and_import_model(make_graph([flatten], "flatten", [x], [y])) make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
# FIXME INT64 类型不支持
# def test_reshape(self):
# data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
# shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8])
# reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8])
# reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
# make_and_import_model(
# make_graph([reshape], "reshape", [data, shape], [reshaped])
# )
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression> # see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
def test_linear(self): def test_linear(self):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])

View File

@ -82,6 +82,19 @@ DEFINE_UNARY_METHOD(abs, Abs)
DEFINE_UNARY_METHOD(identity, Identity) DEFINE_UNARY_METHOD(identity, Identity)
DEFINE_UNARY_METHOD(flatten, Flatten) DEFINE_UNARY_METHOD(flatten, Flatten)
Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) {
if (reshaped) {
g->addOpWithOutputs<ReshapeObj>(std::move(data), reshaped,
std::move(shape));
return reshaped;
} else {
return g
->addOpWithOutputs<ReshapeObj>(std::move(data), reshaped,
std::move(shape))
->getOutput();
}
}
static DataType dtype_repr_convert(int dtype) { static DataType dtype_repr_convert(int dtype) {
switch ((OnnxDType)dtype) { switch ((OnnxDType)dtype) {
case OnnxDType::FLOAT: case OnnxDType::FLOAT:

View File

@ -73,6 +73,9 @@ void init_graph_builder(py::module &m) {
.def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity), .def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
policy::move) policy::move)
.def("flatten", py::overload_cast<Tensor, Tensor>(&Handler::flatten), .def("flatten", py::overload_cast<Tensor, Tensor>(&Handler::flatten),
policy::move)
.def("reshape",
py::overload_cast<Tensor, Tensor, Shape>(&Handler::reshape),
policy::move); policy::move);
} }

View File

@ -1,9 +1,8 @@
#include "operators/reshape.h" #include "operators/reshape.h"
namespace infini { namespace infini {
ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
const Shape &dims) : OperatorObj(OpType::Reshape, {input}, {output}), dims(std::move(dims)) {
: OperatorObj(OpType::Reshape, {input}, {output}), dims(dims) {
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }