解除前端对onnx infershape功能的依赖 (#206)

* feat: SqueezeOp lift the dependency of onnx infershape.

* feat: UnsqueezeOp lift the dependency of onnx infershape.

* feat: lift the dependency of onnx infershape

* fix: fix Makefile off nccl
This commit is contained in:
zhangyunze 2024-01-12 14:54:27 +08:00 committed by GitHub
parent 46e61a5bd4
commit 58993d4339
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 358 additions and 71 deletions

View File

@ -29,7 +29,6 @@ CMAKE_OPT += -DUSE_BANG=$(BANG)
CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN) CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE) CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
CMAKE_OPT += -DBUILD_TEST=$(TEST) CMAKE_OPT += -DBUILD_TEST=$(TEST)
CMAKE_OPT += -DBUILD_DIST=ON
CMAKE_OPT += -DBUILD_NNET=$(NNET) CMAKE_OPT += -DBUILD_NNET=$(NNET)
ifeq ($(INTELCPU), ON) ifeq ($(INTELCPU), ON)

View File

@ -71,6 +71,8 @@ class GraphHandlerObj {
vector<float> scales_, vector<float> roi_, string mode, vector<float> scales_, vector<float> roi_, string mode,
string ratioPolicy, string nearestMode, string ratioPolicy, string nearestMode,
string coordTransMode); string coordTransMode);
Tensor squeeze(Tensor input, Tensor output, Shape axes);
Tensor unsqueeze(Tensor input, Tensor output, Shape axes);
Tensor concat(TensorVec inputs, Tensor output, int dim); Tensor concat(TensorVec inputs, Tensor output, int dim);
Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache, Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
Tensor input_q, Tensor input_k, Tensor input_v, Tensor input_q, Tensor input_k, Tensor input_v,

View File

@ -0,0 +1,39 @@
#pragma once
#include "core/operator.h"
namespace infini {
/**
* @brief Remove single-dimensional entries from the shape of a tensor.
*
*/
class SqueezeObj : public OperatorObj {
Shape axes;
public:
/**
* @brief Construct a new Squeeze object.
*
* @param graph The computation graph that this operator belongs to.
* @param input The input tensor.
* @param output The output tensor.
* @param axes List of integers indicating the dimensions to squeeze.
*/
SqueezeObj(GraphObj *graph, Tensor input, Tensor output, Shape axes);
OP_CLONE(SqueezeObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override;
int numInputs() const override { return 1; }
int numOutputs() const override { return 1; }
inline Shape getAxes() const { return axes; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini

View File

@ -0,0 +1,38 @@
#pragma once
#include "core/operator.h"
namespace infini {
/**
* @brief nsert single-dimensional entries to the shape of an input tensor.
*
*/
class UnsqueezeObj : public OperatorObj {
Shape axes;
public:
/**
* @brief Construct a new Unsqueeze object.
*
* @param graph The computation graph that this operator belongs to.
* @param input The input tensor.
* @param output The output tensor.
* @param axes List of integers indicating the dimensions to be inserted.
*/
UnsqueezeObj(GraphObj *graph, Tensor input, Tensor output, Shape axes);
OP_CLONE(UnsqueezeObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override;
int numInputs() const override { return 1; }
int numOutputs() const override { return 1; }
inline Shape getAxes() const { return axes; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini

View File

@ -52,10 +52,10 @@ class OnnxStub:
self.inputs: Dict[str, backend.Tensor] = {} self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {} self.outputs: Dict[str, backend.Tensor] = {}
self.initializer: Dict[int, TensorProto] = {} self.initializer: Dict[int, TensorProto] = {}
try: # try:
model = infer_shapes(model) # model = infer_shapes(model)
except: # except:
warnings.warn("infer_shapes failed.") # warnings.warn("infer_shapes failed.")
self.handler = backend.GraphHandler(runtime) self.handler = backend.GraphHandler(runtime)
tensors: Dict[str, backend.Tensor] = dict() tensors: Dict[str, backend.Tensor] = dict()
@ -135,7 +135,7 @@ class OnnxStub:
1, 1,
reduce( reduce(
lambda acc, x: acc * x, lambda acc, x: acc * x,
_search_shape(model, node.input[2]), tensors[node.input[2]].shape(),
), ),
1, 1,
1, 1,
@ -357,7 +357,7 @@ class OnnxStub:
ceil_mode, ceil_mode,
) )
elif node.op_type == "GlobalAveragePool": elif node.op_type == "GlobalAveragePool":
[_, _, h, w] = _search_shape(model, node.input[0]) [_, _, h, w] = tensors[node.input[0]].shape()
tensors[node.output[0]] = self.handler.avgPool( tensors[node.output[0]] = self.handler.avgPool(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0]), tensors.get(node.output[0]),
@ -595,35 +595,43 @@ class OnnxStub:
coordinate_transformation_mode, coordinate_transformation_mode,
) )
elif node.op_type == "Squeeze": elif node.op_type == "Squeeze":
input_shape = _search_shape(model, node.input[0]) axes = (
axes = set( _parse_data(data[node.input[1]])
[int(i) for i in data[node.input[1]].int64_data]
if len(node.input) > 1 if len(node.input) > 1
else _parse_attribute(node, {"axes": None})["axes"] else None
) )
assert all(input_shape[d] == 1 for d in axes) if axes is None:
output_shape = [] axes = next(
for i, x in enumerate(input_shape): (
if i not in axes: attr.ints
output_shape.append(x) for attr in node.attribute
tensors[node.output[0]] = self.handler.reshape( if attr.name == "axes"
),
[],
)
tensors[node.output[0]] = self.handler.squeeze(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0]), tensors.get(node.output[0]),
output_shape, axes,
) )
elif node.op_type == "Unsqueeze": elif node.op_type == "Unsqueeze":
input_shape = _search_shape(model, node.input[0])
axes = ( axes = (
[int(i) for i in data[node.input[1]].int64_data] _parse_data(data[node.input[1]])
if len(node.input) > 1 if len(node.input) > 1
else _parse_attribute(node, {"axes": None})["axes"] else None
) )
for i in axes: if axes is None:
input_shape.insert(i, 1) axes = next(
tensors[node.output[0]] = self.handler.reshape( (
attr.ints
for attr in node.attribute
if attr.name == "axes"
)
)
tensors[node.output[0]] = self.handler.unsqueeze(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0]), tensors.get(node.output[0]),
input_shape, axes,
) )
elif node.op_type == "Concat": elif node.op_type == "Concat":
tensors[node.output[0]] = self.handler.concat( tensors[node.output[0]] = self.handler.concat(
@ -935,8 +943,7 @@ class OnnxStub:
node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1} node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1}
) )
(alpha, beta, bias, size) = ( (alpha, beta, bias, size) = (
attributes[name] attributes[name] for name in ["alpha", "beta", "bias", "size"]
for name in ["alpha", "beta", "bias", "size"]
) )
tensors[node.output[0]] = self.handler.lrn( tensors[node.output[0]] = self.handler.lrn(
tensors[node.input[0]], tensors[node.input[0]],
@ -1207,6 +1214,30 @@ class OnnxStub:
) )
) )
ctx.push_node(make_node(ty.name, inputs, outputs, name)) ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpTypeId.Squeeze:
axes = backend.squeeze_axes_of(op)
inputs.append(
ctx.push_data_input(
name,
"axes",
TensorProto.INT64,
[len(axes)],
axes,
)
)
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpTypeId.Unsqueeze:
axes = backend.unsqueeze_axes_of(op)
inputs.append(
ctx.push_data_input(
name,
"axes",
TensorProto.INT64,
[len(axes)],
axes,
)
)
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpTypeId.Concat: elif ty == backend.OpTypeId.Concat:
axis = backend.concat_axis_of(op) axis = backend.concat_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis)) ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
@ -1344,50 +1375,6 @@ def from_onnx(model: ModelProto, runtime):
return stub.inputs, stub.outputs, stub.handler return stub.inputs, stub.outputs, stub.handler
def _search_shape(model: ModelProto, name: str) -> List[int]:
ans = (
next(
(
[
(d.dim_value if d.dim_value > 0 else 1)
for d in tensor.type.tensor_type.shape.dim
]
for tensor in model.graph.value_info
if tensor.name == name
),
None,
)
or next(
(
[
(d.dim_value if d.dim_value > 0 else 1)
for d in tensor.type.tensor_type.shape.dim
]
for tensor in model.graph.input
if tensor.name == name
),
None,
)
or next(
(
[
(d.dim_value if d.dim_value > 0 else 1)
for d in tensor.type.tensor_type.shape.dim
]
for tensor in model.graph.output
if tensor.name == name
),
None,
)
or next(
[int(d) for d in tensor.dims]
for tensor in model.graph.initializer
if tensor.name == name
)
)
return ans
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]: def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
for attr in node.attribute: for attr in node.attribute:
if attr.type == AttributeProto.INT: if attr.type == AttributeProto.INT:

View File

@ -303,6 +303,28 @@ class TestStringMethods(unittest.TestCase):
reshape = make_node("Resize", ["x", "roi", "scales"], ["y"], name="resize") reshape = make_node("Resize", ["x", "roi", "scales"], ["y"], name="resize")
make_and_import_model(make_graph([reshape], "resize", [x], [y], [roi, scales])) make_and_import_model(make_graph([reshape], "resize", [x], [y], [roi, scales]))
def test_squeeze(self):
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 1, 5])
axes = make_tensor_value_info("axes", TensorProto.INT64, [2])
axes_data = make_tensor("axes", TensorProto.INT64, [2], [0, 2])
output = make_tensor_value_info("output", TensorProto.FLOAT, [3, 5])
squeeze = make_node("Squeeze", ["input", "axes"], ["output"], name="squeeze")
make_and_import_model(
make_graph([squeeze], "squeeze", [input, axes], [output], [axes_data])
)
def test_unsqueeze(self):
input = make_tensor_value_info("input", TensorProto.FLOAT, [2, 3, 4, 5])
axes = make_tensor_value_info("axes", TensorProto.INT64, [2])
axes_data = make_tensor("axes", TensorProto.INT64, [2], [0, 2])
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 2, 1, 3, 4, 5])
unsqueeze = make_node(
"Unsqueeze", ["input", "axes"], ["output"], name="unsqueeze"
)
make_and_import_model(
make_graph([unsqueeze], "unsqueeze", [input, axes], [output], [axes_data])
)
def test_concat(self): def test_concat(self):
input1 = make_tensor_value_info("input1", TensorProto.FLOAT, [1, 3, 2, 4]) input1 = make_tensor_value_info("input1", TensorProto.FLOAT, [1, 3, 2, 4])
input2 = make_tensor_value_info("input2", TensorProto.FLOAT, [1, 3, 2, 5]) input2 = make_tensor_value_info("input2", TensorProto.FLOAT, [1, 3, 2, 5])

View File

@ -22,8 +22,10 @@
#include "operators/slice.h" #include "operators/slice.h"
#include "operators/softmax.h" #include "operators/softmax.h"
#include "operators/split.h" #include "operators/split.h"
#include "operators/squeeze.h"
#include "operators/transpose.h" #include "operators/transpose.h"
#include "operators/unary.h" #include "operators/unary.h"
#include "operators/unsqueeze.h"
#include "operators/where.h" #include "operators/where.h"
#include <numeric> #include <numeric>
#include <variant> #include <variant>
@ -608,6 +610,28 @@ Tensor GraphHandlerObj::lrn(Tensor input, Tensor output, float alpha,
} }
} }
Tensor GraphHandlerObj::squeeze(Tensor input, Tensor output, Shape axes) {
if (output) {
g->addOpWithOutputs<SqueezeObj>(std::move(input), output,
std::move(axes));
return output;
} else {
return g->addOp<SqueezeObj>(std::move(input), output, std::move(axes))
->getOutput();
}
}
Tensor GraphHandlerObj::unsqueeze(Tensor input, Tensor output, Shape axes) {
if (output) {
g->addOpWithOutputs<UnsqueezeObj>(std::move(input), output,
std::move(axes));
return output;
} else {
return g->addOp<UnsqueezeObj>(std::move(input), output, std::move(axes))
->getOutput();
}
}
static CastType inferCastType(Tensor input, int to) { static CastType inferCastType(Tensor input, int to) {
auto iType = input->getDType(); auto iType = input->getDType();
auto oType = DataType(to); auto oType = DataType(to);

View File

@ -12,8 +12,10 @@
#include "operators/reduce.h" #include "operators/reduce.h"
#include "operators/reshape.h" #include "operators/reshape.h"
#include "operators/split.h" #include "operators/split.h"
#include "operators/squeeze.h"
#include "operators/transpose.h" #include "operators/transpose.h"
#include "operators/unary.h" #include "operators/unary.h"
#include "operators/unsqueeze.h"
#include <algorithm> #include <algorithm>
#include <pybind11/numpy.h> #include <pybind11/numpy.h>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
@ -93,6 +95,8 @@ void export_values(py::module &m) {
.VALUE(OpType, ReduceMean) .VALUE(OpType, ReduceMean)
.VALUE(OpType, ReduceSum) .VALUE(OpType, ReduceSum)
.VALUE(OpType, Reshape) .VALUE(OpType, Reshape)
.VALUE(OpType, Squeeze)
.VALUE(OpType, Unsqueeze)
.VALUE(OpType, Flatten) .VALUE(OpType, Flatten)
.VALUE(OpType, Identity) .VALUE(OpType, Identity)
.VALUE(OpType, BatchNormalization) .VALUE(OpType, BatchNormalization)
@ -256,6 +260,24 @@ static vector<int64_t> reshape_shape_of(Operator op) {
return ans; return ans;
} }
static vector<int64_t> squeeze_axes_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Squeeze);
auto axes = dynamic_cast<const SqueezeObj *>(op.get())->getAxes();
vector<int64_t> ans(axes.size());
std::transform(axes.begin(), axes.end(), ans.begin(),
[](auto x) { return static_cast<int64_t>(x); });
return ans;
}
static vector<int64_t> unsqueeze_axes_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Unsqueeze);
auto axes = dynamic_cast<const UnsqueezeObj *>(op.get())->getAxes();
vector<int64_t> ans(axes.size());
std::transform(axes.begin(), axes.end(), ans.begin(),
[](auto x) { return static_cast<int64_t>(x); });
return ans;
}
static vector<int64_t> expand_shape_of(Operator op) { static vector<int64_t> expand_shape_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Expand); IT_ASSERT(op->getOpType() == OpType::Expand);
auto shape = dynamic_cast<const ExpandObj *>(op.get())->getShape(); auto shape = dynamic_cast<const ExpandObj *>(op.get())->getShape();
@ -343,6 +365,8 @@ void export_functions(py::module &m) {
.FUNCTION(flatten_axis_of) .FUNCTION(flatten_axis_of)
.FUNCTION(cast_to_of) .FUNCTION(cast_to_of)
.FUNCTION(depth_to_space_attrs_of) .FUNCTION(depth_to_space_attrs_of)
.FUNCTION(squeeze_axes_of)
.FUNCTION(unsqueeze_axes_of)
.FUNCTION(lrn_attrs_of); .FUNCTION(lrn_attrs_of);
#undef FUNCTION #undef FUNCTION
} }
@ -509,6 +533,8 @@ void init_graph_builder(py::module &m) {
.def("depthToSpace", &Handler::depthToSpace, policy::move) .def("depthToSpace", &Handler::depthToSpace, policy::move)
.def("reshape", &Handler::reshape, policy::move) .def("reshape", &Handler::reshape, policy::move)
.def("resize", &Handler::resize, policy::move) .def("resize", &Handler::resize, policy::move)
.def("squeeze", &Handler::squeeze, policy::move)
.def("unsqueeze", &Handler::unsqueeze, policy::move)
.def("concat", &Handler::concat, policy::move) .def("concat", &Handler::concat, policy::move)
.def("attentionKVCache", &Handler::attentionKVCache, policy::move) .def("attentionKVCache", &Handler::attentionKVCache, policy::move)
.def("split", &Handler::split, policy::move) .def("split", &Handler::split, policy::move)

View File

@ -19,6 +19,10 @@ REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int32, CopyCuda,
"Reshape_CUDA_Int32"); "Reshape_CUDA_Int32");
REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda, REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda,
"Flatten_CUDA_Float32"); "Flatten_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Squeeze, DataType::Float32, CopyCuda,
"Squeeze_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Unsqueeze, DataType::Float32, CopyCuda,
"Unsqueeze_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda, REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda,
"Identity_CUDA_Float32"); "Identity_CUDA_Float32");

60
src/operators/squeeze.cc Normal file
View File

@ -0,0 +1,60 @@
#include "operators/squeeze.h"
#include "utils/operator_utils.h"
namespace infini {
SqueezeObj::SqueezeObj(GraphObj *graph, Tensor input, Tensor output, Shape axes)
: OperatorObj(OpType::Squeeze, {input}, {output}), axes(std::move(axes)) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> SqueezeObj::inferShape(const TensorVec &inputs) {
Shape inputDim = inputs[0]->getDims();
Shape outputShape;
auto rank = inputs[0]->getRank();
if (axes.size() == 0) {
for (int i = 0; i < (int)rank; ++i) {
if (inputDim[i] == 1) {
axes.emplace_back(i);
}
}
}
auto new_axes = axes;
std::transform(axes.begin(), axes.end(), new_axes.begin(),
[inputDim, rank](auto x) {
x = get_real_axis(x, rank);
IT_ASSERT(inputDim[x] == 1);
return x;
});
for (int i = 0; i < (int)rank; ++i) {
auto it = std::find(new_axes.begin(), new_axes.end(), i);
if (it == new_axes.end()) {
outputShape.emplace_back(inputDim[i]);
}
}
return {{outputShape}};
}
std::string SqueezeObj::toString() const {
std::ostringstream os;
os << "Squeeze[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "axes=" << vecToString(axes) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> SqueezeObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.insert(ret.end(), axes.begin(), axes.end());
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> SqueezeObj::getOpAttrVector() const {
vector<int> ret = axes;
ret.emplace(ret.begin(), type.underlying());
return ret;
}
} // namespace infini

View File

@ -0,0 +1,52 @@
#include "operators/unsqueeze.h"
#include "utils/operator_utils.h"
namespace infini {
UnsqueezeObj::UnsqueezeObj(GraphObj *graph, Tensor input, Tensor output,
Shape axes)
: OperatorObj(OpType::Unsqueeze, {input}, {output}), axes(std::move(axes)) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> UnsqueezeObj::inferShape(const TensorVec &inputs) {
Shape inputDim = inputs[0]->getDims();
auto rank = inputs[0]->getRank() + axes.size();
Shape outputShape(rank, -1);
for (size_t i = 0; i < axes.size(); ++i) {
axes[i] = get_real_axis(axes[i], rank);
IT_ASSERT(outputShape[axes[i]] == -1, "Axes have duplicate");
outputShape[axes[i]] = 1;
}
auto it = inputDim.begin();
for (size_t i = 0; i < outputShape.size(); ++i) {
if (outputShape[i] == -1) {
outputShape[i] = *it++;
}
}
return {{outputShape}};
}
std::string UnsqueezeObj::toString() const {
std::ostringstream os;
os << "Unsqueeze[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "axes=" << vecToString(axes) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> UnsqueezeObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.insert(ret.end(), axes.begin(), axes.end());
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> UnsqueezeObj::getOpAttrVector() const {
vector<int> ret = axes;
ret.emplace(ret.begin(), type.underlying());
return ret;
}
} // namespace infini

View File

@ -2,6 +2,8 @@
#include "core/kernel.h" #include "core/kernel.h"
#include "core/runtime.h" #include "core/runtime.h"
#include "operators/reshape.h" #include "operators/reshape.h"
#include "operators/squeeze.h"
#include "operators/unsqueeze.h"
#include "test.h" #include "test.h"
@ -54,4 +56,36 @@ TEST(Identity, ShapeInference) {
} }
} }
TEST(Squeeze, ShapeInference) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({2, 3, 1, 4}, DataType::Float32);
auto op = g->addOp<SqueezeObj>(i, nullptr, Shape{-2});
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 4}));
}
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({1, 1, 3, 4}, DataType::Float32);
auto op = g->addOp<SqueezeObj>(i, nullptr, Shape{});
EXPECT_EQ(op->getOutput()->getDims(), (Shape{3, 4}));
}
}
TEST(Unsqueeze, ShapeInference) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({2, 3, 4}, DataType::Float32);
auto op = g->addOp<UnsqueezeObj>(i, nullptr, Shape{0, 1});
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 2, 3, 4}));
}
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({2, 3, 4}, DataType::Float32);
auto op = g->addOp<UnsqueezeObj>(i, nullptr, Shape{-1, -2});
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 4, 1, 1}));
}
}
} // namespace infini } // namespace infini