forked from jiuyuan/InfiniTensor
解除前端对onnx infershape功能的依赖 (#206)
* feat: SqueezeOp lift the dependency of onnx infershape. * feat: UnsqueezeOp lift the dependency of onnx infershape. * feat: lift the dependency of onnx infershape * fix: fix Makefile off nccl
This commit is contained in:
parent
46e61a5bd4
commit
58993d4339
1
Makefile
1
Makefile
|
@ -29,7 +29,6 @@ CMAKE_OPT += -DUSE_BANG=$(BANG)
|
|||
CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
|
||||
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
|
||||
CMAKE_OPT += -DBUILD_TEST=$(TEST)
|
||||
CMAKE_OPT += -DBUILD_DIST=ON
|
||||
CMAKE_OPT += -DBUILD_NNET=$(NNET)
|
||||
|
||||
ifeq ($(INTELCPU), ON)
|
||||
|
|
|
@ -71,6 +71,8 @@ class GraphHandlerObj {
|
|||
vector<float> scales_, vector<float> roi_, string mode,
|
||||
string ratioPolicy, string nearestMode,
|
||||
string coordTransMode);
|
||||
Tensor squeeze(Tensor input, Tensor output, Shape axes);
|
||||
Tensor unsqueeze(Tensor input, Tensor output, Shape axes);
|
||||
Tensor concat(TensorVec inputs, Tensor output, int dim);
|
||||
Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
|
||||
Tensor input_q, Tensor input_k, Tensor input_v,
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
/**
|
||||
* @brief Remove single-dimensional entries from the shape of a tensor.
|
||||
*
|
||||
*/
|
||||
class SqueezeObj : public OperatorObj {
|
||||
Shape axes;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Squeeze object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
* @param axes List of integers indicating the dimensions to squeeze.
|
||||
*/
|
||||
SqueezeObj(GraphObj *graph, Tensor input, Tensor output, Shape axes);
|
||||
OP_CLONE(SqueezeObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
inline Shape getAxes() const { return axes; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,38 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief nsert single-dimensional entries to the shape of an input tensor.
|
||||
*
|
||||
*/
|
||||
class UnsqueezeObj : public OperatorObj {
|
||||
Shape axes;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Unsqueeze object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
* @param axes List of integers indicating the dimensions to be inserted.
|
||||
*/
|
||||
UnsqueezeObj(GraphObj *graph, Tensor input, Tensor output, Shape axes);
|
||||
OP_CLONE(UnsqueezeObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
inline Shape getAxes() const { return axes; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -52,10 +52,10 @@ class OnnxStub:
|
|||
self.inputs: Dict[str, backend.Tensor] = {}
|
||||
self.outputs: Dict[str, backend.Tensor] = {}
|
||||
self.initializer: Dict[int, TensorProto] = {}
|
||||
try:
|
||||
model = infer_shapes(model)
|
||||
except:
|
||||
warnings.warn("infer_shapes failed.")
|
||||
# try:
|
||||
# model = infer_shapes(model)
|
||||
# except:
|
||||
# warnings.warn("infer_shapes failed.")
|
||||
self.handler = backend.GraphHandler(runtime)
|
||||
|
||||
tensors: Dict[str, backend.Tensor] = dict()
|
||||
|
@ -135,7 +135,7 @@ class OnnxStub:
|
|||
1,
|
||||
reduce(
|
||||
lambda acc, x: acc * x,
|
||||
_search_shape(model, node.input[2]),
|
||||
tensors[node.input[2]].shape(),
|
||||
),
|
||||
1,
|
||||
1,
|
||||
|
@ -357,7 +357,7 @@ class OnnxStub:
|
|||
ceil_mode,
|
||||
)
|
||||
elif node.op_type == "GlobalAveragePool":
|
||||
[_, _, h, w] = _search_shape(model, node.input[0])
|
||||
[_, _, h, w] = tensors[node.input[0]].shape()
|
||||
tensors[node.output[0]] = self.handler.avgPool(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
|
@ -595,35 +595,43 @@ class OnnxStub:
|
|||
coordinate_transformation_mode,
|
||||
)
|
||||
elif node.op_type == "Squeeze":
|
||||
input_shape = _search_shape(model, node.input[0])
|
||||
axes = set(
|
||||
[int(i) for i in data[node.input[1]].int64_data]
|
||||
axes = (
|
||||
_parse_data(data[node.input[1]])
|
||||
if len(node.input) > 1
|
||||
else _parse_attribute(node, {"axes": None})["axes"]
|
||||
else None
|
||||
)
|
||||
assert all(input_shape[d] == 1 for d in axes)
|
||||
output_shape = []
|
||||
for i, x in enumerate(input_shape):
|
||||
if i not in axes:
|
||||
output_shape.append(x)
|
||||
tensors[node.output[0]] = self.handler.reshape(
|
||||
if axes is None:
|
||||
axes = next(
|
||||
(
|
||||
attr.ints
|
||||
for attr in node.attribute
|
||||
if attr.name == "axes"
|
||||
),
|
||||
[],
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.squeeze(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
output_shape,
|
||||
axes,
|
||||
)
|
||||
elif node.op_type == "Unsqueeze":
|
||||
input_shape = _search_shape(model, node.input[0])
|
||||
axes = (
|
||||
[int(i) for i in data[node.input[1]].int64_data]
|
||||
_parse_data(data[node.input[1]])
|
||||
if len(node.input) > 1
|
||||
else _parse_attribute(node, {"axes": None})["axes"]
|
||||
else None
|
||||
)
|
||||
for i in axes:
|
||||
input_shape.insert(i, 1)
|
||||
tensors[node.output[0]] = self.handler.reshape(
|
||||
if axes is None:
|
||||
axes = next(
|
||||
(
|
||||
attr.ints
|
||||
for attr in node.attribute
|
||||
if attr.name == "axes"
|
||||
)
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.unsqueeze(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
input_shape,
|
||||
axes,
|
||||
)
|
||||
elif node.op_type == "Concat":
|
||||
tensors[node.output[0]] = self.handler.concat(
|
||||
|
@ -935,8 +943,7 @@ class OnnxStub:
|
|||
node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1}
|
||||
)
|
||||
(alpha, beta, bias, size) = (
|
||||
attributes[name]
|
||||
for name in ["alpha", "beta", "bias", "size"]
|
||||
attributes[name] for name in ["alpha", "beta", "bias", "size"]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.lrn(
|
||||
tensors[node.input[0]],
|
||||
|
@ -1207,6 +1214,30 @@ class OnnxStub:
|
|||
)
|
||||
)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
elif ty == backend.OpTypeId.Squeeze:
|
||||
axes = backend.squeeze_axes_of(op)
|
||||
inputs.append(
|
||||
ctx.push_data_input(
|
||||
name,
|
||||
"axes",
|
||||
TensorProto.INT64,
|
||||
[len(axes)],
|
||||
axes,
|
||||
)
|
||||
)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
elif ty == backend.OpTypeId.Unsqueeze:
|
||||
axes = backend.unsqueeze_axes_of(op)
|
||||
inputs.append(
|
||||
ctx.push_data_input(
|
||||
name,
|
||||
"axes",
|
||||
TensorProto.INT64,
|
||||
[len(axes)],
|
||||
axes,
|
||||
)
|
||||
)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
elif ty == backend.OpTypeId.Concat:
|
||||
axis = backend.concat_axis_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||
|
@ -1344,50 +1375,6 @@ def from_onnx(model: ModelProto, runtime):
|
|||
return stub.inputs, stub.outputs, stub.handler
|
||||
|
||||
|
||||
def _search_shape(model: ModelProto, name: str) -> List[int]:
|
||||
ans = (
|
||||
next(
|
||||
(
|
||||
[
|
||||
(d.dim_value if d.dim_value > 0 else 1)
|
||||
for d in tensor.type.tensor_type.shape.dim
|
||||
]
|
||||
for tensor in model.graph.value_info
|
||||
if tensor.name == name
|
||||
),
|
||||
None,
|
||||
)
|
||||
or next(
|
||||
(
|
||||
[
|
||||
(d.dim_value if d.dim_value > 0 else 1)
|
||||
for d in tensor.type.tensor_type.shape.dim
|
||||
]
|
||||
for tensor in model.graph.input
|
||||
if tensor.name == name
|
||||
),
|
||||
None,
|
||||
)
|
||||
or next(
|
||||
(
|
||||
[
|
||||
(d.dim_value if d.dim_value > 0 else 1)
|
||||
for d in tensor.type.tensor_type.shape.dim
|
||||
]
|
||||
for tensor in model.graph.output
|
||||
if tensor.name == name
|
||||
),
|
||||
None,
|
||||
)
|
||||
or next(
|
||||
[int(d) for d in tensor.dims]
|
||||
for tensor in model.graph.initializer
|
||||
if tensor.name == name
|
||||
)
|
||||
)
|
||||
return ans
|
||||
|
||||
|
||||
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
|
||||
for attr in node.attribute:
|
||||
if attr.type == AttributeProto.INT:
|
||||
|
|
|
@ -303,6 +303,28 @@ class TestStringMethods(unittest.TestCase):
|
|||
reshape = make_node("Resize", ["x", "roi", "scales"], ["y"], name="resize")
|
||||
make_and_import_model(make_graph([reshape], "resize", [x], [y], [roi, scales]))
|
||||
|
||||
def test_squeeze(self):
|
||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 1, 5])
|
||||
axes = make_tensor_value_info("axes", TensorProto.INT64, [2])
|
||||
axes_data = make_tensor("axes", TensorProto.INT64, [2], [0, 2])
|
||||
output = make_tensor_value_info("output", TensorProto.FLOAT, [3, 5])
|
||||
squeeze = make_node("Squeeze", ["input", "axes"], ["output"], name="squeeze")
|
||||
make_and_import_model(
|
||||
make_graph([squeeze], "squeeze", [input, axes], [output], [axes_data])
|
||||
)
|
||||
|
||||
def test_unsqueeze(self):
|
||||
input = make_tensor_value_info("input", TensorProto.FLOAT, [2, 3, 4, 5])
|
||||
axes = make_tensor_value_info("axes", TensorProto.INT64, [2])
|
||||
axes_data = make_tensor("axes", TensorProto.INT64, [2], [0, 2])
|
||||
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 2, 1, 3, 4, 5])
|
||||
unsqueeze = make_node(
|
||||
"Unsqueeze", ["input", "axes"], ["output"], name="unsqueeze"
|
||||
)
|
||||
make_and_import_model(
|
||||
make_graph([unsqueeze], "unsqueeze", [input, axes], [output], [axes_data])
|
||||
)
|
||||
|
||||
def test_concat(self):
|
||||
input1 = make_tensor_value_info("input1", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||
input2 = make_tensor_value_info("input2", TensorProto.FLOAT, [1, 3, 2, 5])
|
||||
|
|
|
@ -22,8 +22,10 @@
|
|||
#include "operators/slice.h"
|
||||
#include "operators/softmax.h"
|
||||
#include "operators/split.h"
|
||||
#include "operators/squeeze.h"
|
||||
#include "operators/transpose.h"
|
||||
#include "operators/unary.h"
|
||||
#include "operators/unsqueeze.h"
|
||||
#include "operators/where.h"
|
||||
#include <numeric>
|
||||
#include <variant>
|
||||
|
@ -608,6 +610,28 @@ Tensor GraphHandlerObj::lrn(Tensor input, Tensor output, float alpha,
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::squeeze(Tensor input, Tensor output, Shape axes) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<SqueezeObj>(std::move(input), output,
|
||||
std::move(axes));
|
||||
return output;
|
||||
} else {
|
||||
return g->addOp<SqueezeObj>(std::move(input), output, std::move(axes))
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::unsqueeze(Tensor input, Tensor output, Shape axes) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<UnsqueezeObj>(std::move(input), output,
|
||||
std::move(axes));
|
||||
return output;
|
||||
} else {
|
||||
return g->addOp<UnsqueezeObj>(std::move(input), output, std::move(axes))
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
static CastType inferCastType(Tensor input, int to) {
|
||||
auto iType = input->getDType();
|
||||
auto oType = DataType(to);
|
||||
|
|
|
@ -12,8 +12,10 @@
|
|||
#include "operators/reduce.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/split.h"
|
||||
#include "operators/squeeze.h"
|
||||
#include "operators/transpose.h"
|
||||
#include "operators/unary.h"
|
||||
#include "operators/unsqueeze.h"
|
||||
#include <algorithm>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
|
@ -93,6 +95,8 @@ void export_values(py::module &m) {
|
|||
.VALUE(OpType, ReduceMean)
|
||||
.VALUE(OpType, ReduceSum)
|
||||
.VALUE(OpType, Reshape)
|
||||
.VALUE(OpType, Squeeze)
|
||||
.VALUE(OpType, Unsqueeze)
|
||||
.VALUE(OpType, Flatten)
|
||||
.VALUE(OpType, Identity)
|
||||
.VALUE(OpType, BatchNormalization)
|
||||
|
@ -256,6 +260,24 @@ static vector<int64_t> reshape_shape_of(Operator op) {
|
|||
return ans;
|
||||
}
|
||||
|
||||
static vector<int64_t> squeeze_axes_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Squeeze);
|
||||
auto axes = dynamic_cast<const SqueezeObj *>(op.get())->getAxes();
|
||||
vector<int64_t> ans(axes.size());
|
||||
std::transform(axes.begin(), axes.end(), ans.begin(),
|
||||
[](auto x) { return static_cast<int64_t>(x); });
|
||||
return ans;
|
||||
}
|
||||
|
||||
static vector<int64_t> unsqueeze_axes_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Unsqueeze);
|
||||
auto axes = dynamic_cast<const UnsqueezeObj *>(op.get())->getAxes();
|
||||
vector<int64_t> ans(axes.size());
|
||||
std::transform(axes.begin(), axes.end(), ans.begin(),
|
||||
[](auto x) { return static_cast<int64_t>(x); });
|
||||
return ans;
|
||||
}
|
||||
|
||||
static vector<int64_t> expand_shape_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Expand);
|
||||
auto shape = dynamic_cast<const ExpandObj *>(op.get())->getShape();
|
||||
|
@ -343,6 +365,8 @@ void export_functions(py::module &m) {
|
|||
.FUNCTION(flatten_axis_of)
|
||||
.FUNCTION(cast_to_of)
|
||||
.FUNCTION(depth_to_space_attrs_of)
|
||||
.FUNCTION(squeeze_axes_of)
|
||||
.FUNCTION(unsqueeze_axes_of)
|
||||
.FUNCTION(lrn_attrs_of);
|
||||
#undef FUNCTION
|
||||
}
|
||||
|
@ -509,6 +533,8 @@ void init_graph_builder(py::module &m) {
|
|||
.def("depthToSpace", &Handler::depthToSpace, policy::move)
|
||||
.def("reshape", &Handler::reshape, policy::move)
|
||||
.def("resize", &Handler::resize, policy::move)
|
||||
.def("squeeze", &Handler::squeeze, policy::move)
|
||||
.def("unsqueeze", &Handler::unsqueeze, policy::move)
|
||||
.def("concat", &Handler::concat, policy::move)
|
||||
.def("attentionKVCache", &Handler::attentionKVCache, policy::move)
|
||||
.def("split", &Handler::split, policy::move)
|
||||
|
|
|
@ -19,6 +19,10 @@ REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int32, CopyCuda,
|
|||
"Reshape_CUDA_Int32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda,
|
||||
"Flatten_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Squeeze, DataType::Float32, CopyCuda,
|
||||
"Squeeze_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Unsqueeze, DataType::Float32, CopyCuda,
|
||||
"Unsqueeze_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda,
|
||||
"Identity_CUDA_Float32");
|
||||
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
#include "operators/squeeze.h"
|
||||
#include "utils/operator_utils.h"
|
||||
|
||||
namespace infini {
|
||||
SqueezeObj::SqueezeObj(GraphObj *graph, Tensor input, Tensor output, Shape axes)
|
||||
: OperatorObj(OpType::Squeeze, {input}, {output}), axes(std::move(axes)) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> SqueezeObj::inferShape(const TensorVec &inputs) {
|
||||
Shape inputDim = inputs[0]->getDims();
|
||||
Shape outputShape;
|
||||
auto rank = inputs[0]->getRank();
|
||||
if (axes.size() == 0) {
|
||||
for (int i = 0; i < (int)rank; ++i) {
|
||||
if (inputDim[i] == 1) {
|
||||
axes.emplace_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto new_axes = axes;
|
||||
std::transform(axes.begin(), axes.end(), new_axes.begin(),
|
||||
[inputDim, rank](auto x) {
|
||||
x = get_real_axis(x, rank);
|
||||
IT_ASSERT(inputDim[x] == 1);
|
||||
return x;
|
||||
});
|
||||
for (int i = 0; i < (int)rank; ++i) {
|
||||
auto it = std::find(new_axes.begin(), new_axes.end(), i);
|
||||
if (it == new_axes.end()) {
|
||||
outputShape.emplace_back(inputDim[i]);
|
||||
}
|
||||
}
|
||||
return {{outputShape}};
|
||||
}
|
||||
|
||||
std::string SqueezeObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << "Squeeze[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
os << "axes=" << vecToString(axes) << ",";
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> SqueezeObj::getWorkloadVector() const {
|
||||
vector<int> ret = inputs[0]->getDims();
|
||||
ret.insert(ret.end(), axes.begin(), axes.end());
|
||||
ret.emplace(ret.begin(), type.underlying());
|
||||
return ret;
|
||||
}
|
||||
vector<int> SqueezeObj::getOpAttrVector() const {
|
||||
vector<int> ret = axes;
|
||||
ret.emplace(ret.begin(), type.underlying());
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,52 @@
|
|||
#include "operators/unsqueeze.h"
|
||||
#include "utils/operator_utils.h"
|
||||
|
||||
namespace infini {
|
||||
UnsqueezeObj::UnsqueezeObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
Shape axes)
|
||||
: OperatorObj(OpType::Unsqueeze, {input}, {output}), axes(std::move(axes)) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> UnsqueezeObj::inferShape(const TensorVec &inputs) {
|
||||
Shape inputDim = inputs[0]->getDims();
|
||||
auto rank = inputs[0]->getRank() + axes.size();
|
||||
Shape outputShape(rank, -1);
|
||||
for (size_t i = 0; i < axes.size(); ++i) {
|
||||
axes[i] = get_real_axis(axes[i], rank);
|
||||
IT_ASSERT(outputShape[axes[i]] == -1, "Axes have duplicate");
|
||||
outputShape[axes[i]] = 1;
|
||||
}
|
||||
auto it = inputDim.begin();
|
||||
for (size_t i = 0; i < outputShape.size(); ++i) {
|
||||
if (outputShape[i] == -1) {
|
||||
outputShape[i] = *it++;
|
||||
}
|
||||
}
|
||||
return {{outputShape}};
|
||||
}
|
||||
|
||||
std::string UnsqueezeObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << "Unsqueeze[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
os << "axes=" << vecToString(axes) << ",";
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> UnsqueezeObj::getWorkloadVector() const {
|
||||
vector<int> ret = inputs[0]->getDims();
|
||||
ret.insert(ret.end(), axes.begin(), axes.end());
|
||||
ret.emplace(ret.begin(), type.underlying());
|
||||
return ret;
|
||||
}
|
||||
vector<int> UnsqueezeObj::getOpAttrVector() const {
|
||||
vector<int> ret = axes;
|
||||
ret.emplace(ret.begin(), type.underlying());
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -2,6 +2,8 @@
|
|||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/squeeze.h"
|
||||
#include "operators/unsqueeze.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
|
@ -54,4 +56,36 @@ TEST(Identity, ShapeInference) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(Squeeze, ShapeInference) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 1, 4}, DataType::Float32);
|
||||
auto op = g->addOp<SqueezeObj>(i, nullptr, Shape{-2});
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 4}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({1, 1, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<SqueezeObj>(i, nullptr, Shape{});
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{3, 4}));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Unsqueeze, ShapeInference) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<UnsqueezeObj>(i, nullptr, Shape{0, 1});
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 2, 3, 4}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<UnsqueezeObj>(i, nullptr, Shape{-1, -2});
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 4, 1, 1}));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue