forked from jiuyuan/InfiniTensor
feat: 前端支持 reshape
- 无法测试,因为后端不支持 shape 的 INT64 类型 opt: ReshapeObj 构造改为全部传值并在内部 move Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
ee0a562006
commit
7626efbfa8
|
@ -57,6 +57,7 @@ class GraphHandlerObj {
|
||||||
Tensor abs(Tensor x, Tensor y);
|
Tensor abs(Tensor x, Tensor y);
|
||||||
Tensor identity(Tensor x, Tensor y);
|
Tensor identity(Tensor x, Tensor y);
|
||||||
Tensor flatten(Tensor s, Tensor y);
|
Tensor flatten(Tensor s, Tensor y);
|
||||||
|
Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -19,7 +19,7 @@ class ReshapeObj : public OperatorObj {
|
||||||
* @param output The output tensor.
|
* @param output The output tensor.
|
||||||
* @param dims The shape of the output tensor.
|
* @param dims The shape of the output tensor.
|
||||||
*/
|
*/
|
||||||
ReshapeObj(GraphObj *graph, Tensor input, Tensor output, const Shape &dims);
|
ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
|
||||||
OP_CLONE(ReshapeObj);
|
OP_CLONE(ReshapeObj);
|
||||||
|
|
||||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import typing, onnx, backend
|
import onnx, backend
|
||||||
|
|
||||||
runtime = backend.cpu_runtime()
|
runtime = backend.cpu_runtime()
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
tensors[node.output[0]] = handler.matmul(
|
tensors[node.output[0]] = handler.matmul(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors[node.input[1]],
|
tensors[node.input[1]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0]),
|
||||||
False,
|
False,
|
||||||
False,
|
False,
|
||||||
None,
|
None,
|
||||||
|
@ -31,7 +31,7 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
(input, mean, var, scale, bias) = (
|
(input, mean, var, scale, bias) = (
|
||||||
tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
|
tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
|
||||||
)
|
)
|
||||||
output = tensors.get(node.output[0], None)
|
output = tensors.get(node.output[0])
|
||||||
attributes = _parse_attribute(
|
attributes = _parse_attribute(
|
||||||
node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
|
node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
|
||||||
)
|
)
|
||||||
|
@ -45,61 +45,61 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
tensors[node.output[0]] = handler.add(
|
tensors[node.output[0]] = handler.add(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors[node.input[1]],
|
tensors[node.input[1]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
elif node.op_type == "Sub":
|
elif node.op_type == "Sub":
|
||||||
tensors[node.output[0]] = handler.sub(
|
tensors[node.output[0]] = handler.sub(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors[node.input[1]],
|
tensors[node.input[1]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
elif node.op_type == "Mul":
|
elif node.op_type == "Mul":
|
||||||
tensors[node.output[0]] = handler.mul(
|
tensors[node.output[0]] = handler.mul(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors[node.input[1]],
|
tensors[node.input[1]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
elif node.op_type == "Div":
|
elif node.op_type == "Div":
|
||||||
tensors[node.output[0]] = handler.div(
|
tensors[node.output[0]] = handler.div(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors[node.input[1]],
|
tensors[node.input[1]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
elif node.op_type == "Pow":
|
elif node.op_type == "Pow":
|
||||||
tensors[node.output[0]] = handler.pow(
|
tensors[node.output[0]] = handler.pow(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors[node.input[1]],
|
tensors[node.input[1]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
elif node.op_type == "Relu":
|
elif node.op_type == "Relu":
|
||||||
tensors[node.output[0]] = handler.relu(
|
tensors[node.output[0]] = handler.relu(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
elif node.op_type == "Sigmoid":
|
elif node.op_type == "Sigmoid":
|
||||||
tensors[node.output[0]] = handler.sigmoid(
|
tensors[node.output[0]] = handler.sigmoid(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
elif node.op_type == "Tanh":
|
elif node.op_type == "Tanh":
|
||||||
tensors[node.output[0]] = handler.tanh(
|
tensors[node.output[0]] = handler.tanh(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
elif node.op_type == "Softmax":
|
elif node.op_type == "Softmax":
|
||||||
tensors[node.output[0]] = handler.softmax(
|
tensors[node.output[0]] = handler.softmax(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
elif node.op_type == "Abs":
|
elif node.op_type == "Abs":
|
||||||
tensors[node.output[0]] = handler.abs(
|
tensors[node.output[0]] = handler.abs(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
elif node.op_type == "Identity":
|
elif node.op_type == "Identity":
|
||||||
tensors[node.output[0]] = handler.identity(
|
tensors[node.output[0]] = handler.identity(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
elif node.op_type == "Flatten":
|
elif node.op_type == "Flatten":
|
||||||
# TODO 后端算子不支持沿任意轴展开
|
# TODO 后端算子不支持沿任意轴展开
|
||||||
|
@ -109,7 +109,13 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
assert axis == None or axis == 1
|
assert axis == None or axis == 1
|
||||||
tensors[node.output[0]] = handler.flatten(
|
tensors[node.output[0]] = handler.flatten(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0], None),
|
tensors.get(node.output[0]),
|
||||||
|
)
|
||||||
|
elif node.op_type == "Reshape":
|
||||||
|
tensors[node.output[0]] = handler.reshape(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
[int(i) for i in tensors[node.input[1]]],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||||
|
|
|
@ -129,6 +129,16 @@ class TestStringMethods(unittest.TestCase):
|
||||||
flatten = make_node("Flatten", ["x"], ["y"], name="flatten")
|
flatten = make_node("Flatten", ["x"], ["y"], name="flatten")
|
||||||
make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
|
make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
|
||||||
|
|
||||||
|
# FIXME INT64 类型不支持
|
||||||
|
# def test_reshape(self):
|
||||||
|
# data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
|
||||||
|
# shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8])
|
||||||
|
# reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8])
|
||||||
|
# reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
|
||||||
|
# make_and_import_model(
|
||||||
|
# make_graph([reshape], "reshape", [data, shape], [reshaped])
|
||||||
|
# )
|
||||||
|
|
||||||
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
||||||
def test_linear(self):
|
def test_linear(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
|
|
|
@ -82,6 +82,19 @@ DEFINE_UNARY_METHOD(abs, Abs)
|
||||||
DEFINE_UNARY_METHOD(identity, Identity)
|
DEFINE_UNARY_METHOD(identity, Identity)
|
||||||
DEFINE_UNARY_METHOD(flatten, Flatten)
|
DEFINE_UNARY_METHOD(flatten, Flatten)
|
||||||
|
|
||||||
|
Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) {
|
||||||
|
if (reshaped) {
|
||||||
|
g->addOpWithOutputs<ReshapeObj>(std::move(data), reshaped,
|
||||||
|
std::move(shape));
|
||||||
|
return reshaped;
|
||||||
|
} else {
|
||||||
|
return g
|
||||||
|
->addOpWithOutputs<ReshapeObj>(std::move(data), reshaped,
|
||||||
|
std::move(shape))
|
||||||
|
->getOutput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static DataType dtype_repr_convert(int dtype) {
|
static DataType dtype_repr_convert(int dtype) {
|
||||||
switch ((OnnxDType)dtype) {
|
switch ((OnnxDType)dtype) {
|
||||||
case OnnxDType::FLOAT:
|
case OnnxDType::FLOAT:
|
||||||
|
|
|
@ -73,6 +73,9 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
|
.def("identity", py::overload_cast<Tensor, Tensor>(&Handler::identity),
|
||||||
policy::move)
|
policy::move)
|
||||||
.def("flatten", py::overload_cast<Tensor, Tensor>(&Handler::flatten),
|
.def("flatten", py::overload_cast<Tensor, Tensor>(&Handler::flatten),
|
||||||
|
policy::move)
|
||||||
|
.def("reshape",
|
||||||
|
py::overload_cast<Tensor, Tensor, Shape>(&Handler::reshape),
|
||||||
policy::move);
|
policy::move);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
#include "operators/reshape.h"
|
#include "operators/reshape.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output,
|
ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
|
||||||
const Shape &dims)
|
: OperatorObj(OpType::Reshape, {input}, {output}), dims(std::move(dims)) {
|
||||||
: OperatorObj(OpType::Reshape, {input}, {output}), dims(dims) {
|
|
||||||
IT_ASSERT(checkValid(graph));
|
IT_ASSERT(checkValid(graph));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue