forked from jiuyuan/InfiniTensor
解除前端对onnx infershape功能的依赖 (#206)
* feat: SqueezeOp lift the dependency of onnx infershape. * feat: UnsqueezeOp lift the dependency of onnx infershape. * feat: lift the dependency of onnx infershape * fix: fix Makefile off nccl
This commit is contained in:
parent
46e61a5bd4
commit
58993d4339
1
Makefile
1
Makefile
|
@ -29,7 +29,6 @@ CMAKE_OPT += -DUSE_BANG=$(BANG)
|
||||||
CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
|
CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN)
|
||||||
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
|
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
|
||||||
CMAKE_OPT += -DBUILD_TEST=$(TEST)
|
CMAKE_OPT += -DBUILD_TEST=$(TEST)
|
||||||
CMAKE_OPT += -DBUILD_DIST=ON
|
|
||||||
CMAKE_OPT += -DBUILD_NNET=$(NNET)
|
CMAKE_OPT += -DBUILD_NNET=$(NNET)
|
||||||
|
|
||||||
ifeq ($(INTELCPU), ON)
|
ifeq ($(INTELCPU), ON)
|
||||||
|
|
|
@ -71,6 +71,8 @@ class GraphHandlerObj {
|
||||||
vector<float> scales_, vector<float> roi_, string mode,
|
vector<float> scales_, vector<float> roi_, string mode,
|
||||||
string ratioPolicy, string nearestMode,
|
string ratioPolicy, string nearestMode,
|
||||||
string coordTransMode);
|
string coordTransMode);
|
||||||
|
Tensor squeeze(Tensor input, Tensor output, Shape axes);
|
||||||
|
Tensor unsqueeze(Tensor input, Tensor output, Shape axes);
|
||||||
Tensor concat(TensorVec inputs, Tensor output, int dim);
|
Tensor concat(TensorVec inputs, Tensor output, int dim);
|
||||||
Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
|
Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
|
||||||
Tensor input_q, Tensor input_k, Tensor input_v,
|
Tensor input_q, Tensor input_k, Tensor input_v,
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Remove single-dimensional entries from the shape of a tensor.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
class SqueezeObj : public OperatorObj {
|
||||||
|
Shape axes;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Squeeze object.
|
||||||
|
*
|
||||||
|
* @param graph The computation graph that this operator belongs to.
|
||||||
|
* @param input The input tensor.
|
||||||
|
* @param output The output tensor.
|
||||||
|
* @param axes List of integers indicating the dimensions to squeeze.
|
||||||
|
*/
|
||||||
|
SqueezeObj(GraphObj *graph, Tensor input, Tensor output, Shape axes);
|
||||||
|
OP_CLONE(SqueezeObj);
|
||||||
|
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||||
|
|
||||||
|
std::string toString() const override;
|
||||||
|
int numInputs() const override { return 1; }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
|
||||||
|
inline Shape getAxes() const { return axes; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,38 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
/**
|
||||||
|
* @brief Insert single-dimensional entries into the shape of an input tensor.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
class UnsqueezeObj : public OperatorObj {
|
||||||
|
Shape axes;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Unsqueeze object.
|
||||||
|
*
|
||||||
|
* @param graph The computation graph that this operator belongs to.
|
||||||
|
* @param input The input tensor.
|
||||||
|
* @param output The output tensor.
|
||||||
|
* @param axes List of integers indicating the dimensions to be inserted.
|
||||||
|
*/
|
||||||
|
UnsqueezeObj(GraphObj *graph, Tensor input, Tensor output, Shape axes);
|
||||||
|
OP_CLONE(UnsqueezeObj);
|
||||||
|
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||||
|
|
||||||
|
std::string toString() const override;
|
||||||
|
int numInputs() const override { return 1; }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
|
||||||
|
inline Shape getAxes() const { return axes; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -52,10 +52,10 @@ class OnnxStub:
|
||||||
self.inputs: Dict[str, backend.Tensor] = {}
|
self.inputs: Dict[str, backend.Tensor] = {}
|
||||||
self.outputs: Dict[str, backend.Tensor] = {}
|
self.outputs: Dict[str, backend.Tensor] = {}
|
||||||
self.initializer: Dict[int, TensorProto] = {}
|
self.initializer: Dict[int, TensorProto] = {}
|
||||||
try:
|
# try:
|
||||||
model = infer_shapes(model)
|
# model = infer_shapes(model)
|
||||||
except:
|
# except:
|
||||||
warnings.warn("infer_shapes failed.")
|
# warnings.warn("infer_shapes failed.")
|
||||||
self.handler = backend.GraphHandler(runtime)
|
self.handler = backend.GraphHandler(runtime)
|
||||||
|
|
||||||
tensors: Dict[str, backend.Tensor] = dict()
|
tensors: Dict[str, backend.Tensor] = dict()
|
||||||
|
@ -135,7 +135,7 @@ class OnnxStub:
|
||||||
1,
|
1,
|
||||||
reduce(
|
reduce(
|
||||||
lambda acc, x: acc * x,
|
lambda acc, x: acc * x,
|
||||||
_search_shape(model, node.input[2]),
|
tensors[node.input[2]].shape(),
|
||||||
),
|
),
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
|
@ -357,7 +357,7 @@ class OnnxStub:
|
||||||
ceil_mode,
|
ceil_mode,
|
||||||
)
|
)
|
||||||
elif node.op_type == "GlobalAveragePool":
|
elif node.op_type == "GlobalAveragePool":
|
||||||
[_, _, h, w] = _search_shape(model, node.input[0])
|
[_, _, h, w] = tensors[node.input[0]].shape()
|
||||||
tensors[node.output[0]] = self.handler.avgPool(
|
tensors[node.output[0]] = self.handler.avgPool(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
|
@ -595,35 +595,43 @@ class OnnxStub:
|
||||||
coordinate_transformation_mode,
|
coordinate_transformation_mode,
|
||||||
)
|
)
|
||||||
elif node.op_type == "Squeeze":
|
elif node.op_type == "Squeeze":
|
||||||
input_shape = _search_shape(model, node.input[0])
|
axes = (
|
||||||
axes = set(
|
_parse_data(data[node.input[1]])
|
||||||
[int(i) for i in data[node.input[1]].int64_data]
|
|
||||||
if len(node.input) > 1
|
if len(node.input) > 1
|
||||||
else _parse_attribute(node, {"axes": None})["axes"]
|
else None
|
||||||
)
|
)
|
||||||
assert all(input_shape[d] == 1 for d in axes)
|
if axes is None:
|
||||||
output_shape = []
|
axes = next(
|
||||||
for i, x in enumerate(input_shape):
|
(
|
||||||
if i not in axes:
|
attr.ints
|
||||||
output_shape.append(x)
|
for attr in node.attribute
|
||||||
tensors[node.output[0]] = self.handler.reshape(
|
if attr.name == "axes"
|
||||||
|
),
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
tensors[node.output[0]] = self.handler.squeeze(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
output_shape,
|
axes,
|
||||||
)
|
)
|
||||||
elif node.op_type == "Unsqueeze":
|
elif node.op_type == "Unsqueeze":
|
||||||
input_shape = _search_shape(model, node.input[0])
|
|
||||||
axes = (
|
axes = (
|
||||||
[int(i) for i in data[node.input[1]].int64_data]
|
_parse_data(data[node.input[1]])
|
||||||
if len(node.input) > 1
|
if len(node.input) > 1
|
||||||
else _parse_attribute(node, {"axes": None})["axes"]
|
else None
|
||||||
)
|
)
|
||||||
for i in axes:
|
if axes is None:
|
||||||
input_shape.insert(i, 1)
|
axes = next(
|
||||||
tensors[node.output[0]] = self.handler.reshape(
|
(
|
||||||
|
attr.ints
|
||||||
|
for attr in node.attribute
|
||||||
|
if attr.name == "axes"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tensors[node.output[0]] = self.handler.unsqueeze(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
input_shape,
|
axes,
|
||||||
)
|
)
|
||||||
elif node.op_type == "Concat":
|
elif node.op_type == "Concat":
|
||||||
tensors[node.output[0]] = self.handler.concat(
|
tensors[node.output[0]] = self.handler.concat(
|
||||||
|
@ -935,8 +943,7 @@ class OnnxStub:
|
||||||
node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1}
|
node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1}
|
||||||
)
|
)
|
||||||
(alpha, beta, bias, size) = (
|
(alpha, beta, bias, size) = (
|
||||||
attributes[name]
|
attributes[name] for name in ["alpha", "beta", "bias", "size"]
|
||||||
for name in ["alpha", "beta", "bias", "size"]
|
|
||||||
)
|
)
|
||||||
tensors[node.output[0]] = self.handler.lrn(
|
tensors[node.output[0]] = self.handler.lrn(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
|
@ -1207,6 +1214,30 @@ class OnnxStub:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||||
|
elif ty == backend.OpTypeId.Squeeze:
|
||||||
|
axes = backend.squeeze_axes_of(op)
|
||||||
|
inputs.append(
|
||||||
|
ctx.push_data_input(
|
||||||
|
name,
|
||||||
|
"axes",
|
||||||
|
TensorProto.INT64,
|
||||||
|
[len(axes)],
|
||||||
|
axes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||||
|
elif ty == backend.OpTypeId.Unsqueeze:
|
||||||
|
axes = backend.unsqueeze_axes_of(op)
|
||||||
|
inputs.append(
|
||||||
|
ctx.push_data_input(
|
||||||
|
name,
|
||||||
|
"axes",
|
||||||
|
TensorProto.INT64,
|
||||||
|
[len(axes)],
|
||||||
|
axes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||||
elif ty == backend.OpTypeId.Concat:
|
elif ty == backend.OpTypeId.Concat:
|
||||||
axis = backend.concat_axis_of(op)
|
axis = backend.concat_axis_of(op)
|
||||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||||
|
@ -1344,50 +1375,6 @@ def from_onnx(model: ModelProto, runtime):
|
||||||
return stub.inputs, stub.outputs, stub.handler
|
return stub.inputs, stub.outputs, stub.handler
|
||||||
|
|
||||||
|
|
||||||
def _search_shape(model: ModelProto, name: str) -> List[int]:
|
|
||||||
ans = (
|
|
||||||
next(
|
|
||||||
(
|
|
||||||
[
|
|
||||||
(d.dim_value if d.dim_value > 0 else 1)
|
|
||||||
for d in tensor.type.tensor_type.shape.dim
|
|
||||||
]
|
|
||||||
for tensor in model.graph.value_info
|
|
||||||
if tensor.name == name
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
or next(
|
|
||||||
(
|
|
||||||
[
|
|
||||||
(d.dim_value if d.dim_value > 0 else 1)
|
|
||||||
for d in tensor.type.tensor_type.shape.dim
|
|
||||||
]
|
|
||||||
for tensor in model.graph.input
|
|
||||||
if tensor.name == name
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
or next(
|
|
||||||
(
|
|
||||||
[
|
|
||||||
(d.dim_value if d.dim_value > 0 else 1)
|
|
||||||
for d in tensor.type.tensor_type.shape.dim
|
|
||||||
]
|
|
||||||
for tensor in model.graph.output
|
|
||||||
if tensor.name == name
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
or next(
|
|
||||||
[int(d) for d in tensor.dims]
|
|
||||||
for tensor in model.graph.initializer
|
|
||||||
if tensor.name == name
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return ans
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
|
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
|
||||||
for attr in node.attribute:
|
for attr in node.attribute:
|
||||||
if attr.type == AttributeProto.INT:
|
if attr.type == AttributeProto.INT:
|
||||||
|
|
|
@ -303,6 +303,28 @@ class TestStringMethods(unittest.TestCase):
|
||||||
reshape = make_node("Resize", ["x", "roi", "scales"], ["y"], name="resize")
|
reshape = make_node("Resize", ["x", "roi", "scales"], ["y"], name="resize")
|
||||||
make_and_import_model(make_graph([reshape], "resize", [x], [y], [roi, scales]))
|
make_and_import_model(make_graph([reshape], "resize", [x], [y], [roi, scales]))
|
||||||
|
|
||||||
|
def test_squeeze(self):
    """Import a Squeeze node whose axes are supplied as an INT64 initializer."""
    squeeze_axes = [0, 2]
    data_in = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 1, 5])
    axes_info = make_tensor_value_info(
        "axes", TensorProto.INT64, [len(squeeze_axes)]
    )
    axes_init = make_tensor(
        "axes", TensorProto.INT64, [len(squeeze_axes)], squeeze_axes
    )
    data_out = make_tensor_value_info("output", TensorProto.FLOAT, [3, 5])
    node = make_node("Squeeze", ["input", "axes"], ["output"], name="squeeze")
    graph = make_graph(
        [node], "squeeze", [data_in, axes_info], [data_out], [axes_init]
    )
    make_and_import_model(graph)
|
||||||
|
|
||||||
|
def test_unsqueeze(self):
    """Import an Unsqueeze node whose axes are supplied as an INT64 initializer."""
    new_axes = [0, 2]
    data_in = make_tensor_value_info("input", TensorProto.FLOAT, [2, 3, 4, 5])
    axes_info = make_tensor_value_info("axes", TensorProto.INT64, [len(new_axes)])
    axes_init = make_tensor("axes", TensorProto.INT64, [len(new_axes)], new_axes)
    data_out = make_tensor_value_info(
        "output", TensorProto.FLOAT, [1, 2, 1, 3, 4, 5]
    )
    node = make_node("Unsqueeze", ["input", "axes"], ["output"], name="unsqueeze")
    graph = make_graph(
        [node], "unsqueeze", [data_in, axes_info], [data_out], [axes_init]
    )
    make_and_import_model(graph)
|
||||||
|
|
||||||
def test_concat(self):
|
def test_concat(self):
|
||||||
input1 = make_tensor_value_info("input1", TensorProto.FLOAT, [1, 3, 2, 4])
|
input1 = make_tensor_value_info("input1", TensorProto.FLOAT, [1, 3, 2, 4])
|
||||||
input2 = make_tensor_value_info("input2", TensorProto.FLOAT, [1, 3, 2, 5])
|
input2 = make_tensor_value_info("input2", TensorProto.FLOAT, [1, 3, 2, 5])
|
||||||
|
|
|
@ -22,8 +22,10 @@
|
||||||
#include "operators/slice.h"
|
#include "operators/slice.h"
|
||||||
#include "operators/softmax.h"
|
#include "operators/softmax.h"
|
||||||
#include "operators/split.h"
|
#include "operators/split.h"
|
||||||
|
#include "operators/squeeze.h"
|
||||||
#include "operators/transpose.h"
|
#include "operators/transpose.h"
|
||||||
#include "operators/unary.h"
|
#include "operators/unary.h"
|
||||||
|
#include "operators/unsqueeze.h"
|
||||||
#include "operators/where.h"
|
#include "operators/where.h"
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
|
@ -608,6 +610,28 @@ Tensor GraphHandlerObj::lrn(Tensor input, Tensor output, float alpha,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Append a Squeeze operator to the handler's graph.
 *
 * @param input  Tensor to squeeze.
 * @param output Optional pre-existing result tensor; when null the graph
 *               creates the output itself.
 * @param axes   Dimensions to remove.
 * @return The caller's `output` when one was given, otherwise the output
 *         tensor of the newly added operator.
 */
Tensor GraphHandlerObj::squeeze(Tensor input, Tensor output, Shape axes) {
    // No output supplied: let the graph allocate it and hand it back.
    if (!output) {
        return g->addOp<SqueezeObj>(std::move(input), output, std::move(axes))
            ->getOutput();
    }
    // Output supplied: bind the operator to it.
    g->addOpWithOutputs<SqueezeObj>(std::move(input), output, std::move(axes));
    return output;
}
|
||||||
|
|
||||||
|
/**
 * @brief Append an Unsqueeze operator to the handler's graph.
 *
 * @param input  Tensor to expand.
 * @param output Optional pre-existing result tensor; when null the graph
 *               creates the output itself.
 * @param axes   Positions of the inserted size-1 dimensions.
 * @return The caller's `output` when one was given, otherwise the output
 *         tensor of the newly added operator.
 */
Tensor GraphHandlerObj::unsqueeze(Tensor input, Tensor output, Shape axes) {
    // No output supplied: let the graph allocate it and hand it back.
    if (!output) {
        return g->addOp<UnsqueezeObj>(std::move(input), output, std::move(axes))
            ->getOutput();
    }
    // Output supplied: bind the operator to it.
    g->addOpWithOutputs<UnsqueezeObj>(std::move(input), output,
                                      std::move(axes));
    return output;
}
|
||||||
|
|
||||||
static CastType inferCastType(Tensor input, int to) {
|
static CastType inferCastType(Tensor input, int to) {
|
||||||
auto iType = input->getDType();
|
auto iType = input->getDType();
|
||||||
auto oType = DataType(to);
|
auto oType = DataType(to);
|
||||||
|
|
|
@ -12,8 +12,10 @@
|
||||||
#include "operators/reduce.h"
|
#include "operators/reduce.h"
|
||||||
#include "operators/reshape.h"
|
#include "operators/reshape.h"
|
||||||
#include "operators/split.h"
|
#include "operators/split.h"
|
||||||
|
#include "operators/squeeze.h"
|
||||||
#include "operators/transpose.h"
|
#include "operators/transpose.h"
|
||||||
#include "operators/unary.h"
|
#include "operators/unary.h"
|
||||||
|
#include "operators/unsqueeze.h"
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <pybind11/numpy.h>
|
#include <pybind11/numpy.h>
|
||||||
#include <pybind11/pybind11.h>
|
#include <pybind11/pybind11.h>
|
||||||
|
@ -93,6 +95,8 @@ void export_values(py::module &m) {
|
||||||
.VALUE(OpType, ReduceMean)
|
.VALUE(OpType, ReduceMean)
|
||||||
.VALUE(OpType, ReduceSum)
|
.VALUE(OpType, ReduceSum)
|
||||||
.VALUE(OpType, Reshape)
|
.VALUE(OpType, Reshape)
|
||||||
|
.VALUE(OpType, Squeeze)
|
||||||
|
.VALUE(OpType, Unsqueeze)
|
||||||
.VALUE(OpType, Flatten)
|
.VALUE(OpType, Flatten)
|
||||||
.VALUE(OpType, Identity)
|
.VALUE(OpType, Identity)
|
||||||
.VALUE(OpType, BatchNormalization)
|
.VALUE(OpType, BatchNormalization)
|
||||||
|
@ -256,6 +260,24 @@ static vector<int64_t> reshape_shape_of(Operator op) {
|
||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Return the axes of a Squeeze operator widened to 64-bit integers.
 *
 * Asserts that `op` really is a Squeeze before downcasting it.
 */
static vector<int64_t> squeeze_axes_of(Operator op) {
    IT_ASSERT(op->getOpType() == OpType::Squeeze);
    const auto axes = dynamic_cast<const SqueezeObj *>(op.get())->getAxes();
    // The iterator-range constructor converts each element to int64_t.
    return vector<int64_t>(axes.begin(), axes.end());
}
|
||||||
|
|
||||||
|
/**
 * @brief Return the axes of an Unsqueeze operator widened to 64-bit integers.
 *
 * Asserts that `op` really is an Unsqueeze before downcasting it.
 */
static vector<int64_t> unsqueeze_axes_of(Operator op) {
    IT_ASSERT(op->getOpType() == OpType::Unsqueeze);
    const auto axes = dynamic_cast<const UnsqueezeObj *>(op.get())->getAxes();
    // The iterator-range constructor converts each element to int64_t.
    return vector<int64_t>(axes.begin(), axes.end());
}
|
||||||
|
|
||||||
static vector<int64_t> expand_shape_of(Operator op) {
|
static vector<int64_t> expand_shape_of(Operator op) {
|
||||||
IT_ASSERT(op->getOpType() == OpType::Expand);
|
IT_ASSERT(op->getOpType() == OpType::Expand);
|
||||||
auto shape = dynamic_cast<const ExpandObj *>(op.get())->getShape();
|
auto shape = dynamic_cast<const ExpandObj *>(op.get())->getShape();
|
||||||
|
@ -343,6 +365,8 @@ void export_functions(py::module &m) {
|
||||||
.FUNCTION(flatten_axis_of)
|
.FUNCTION(flatten_axis_of)
|
||||||
.FUNCTION(cast_to_of)
|
.FUNCTION(cast_to_of)
|
||||||
.FUNCTION(depth_to_space_attrs_of)
|
.FUNCTION(depth_to_space_attrs_of)
|
||||||
|
.FUNCTION(squeeze_axes_of)
|
||||||
|
.FUNCTION(unsqueeze_axes_of)
|
||||||
.FUNCTION(lrn_attrs_of);
|
.FUNCTION(lrn_attrs_of);
|
||||||
#undef FUNCTION
|
#undef FUNCTION
|
||||||
}
|
}
|
||||||
|
@ -509,6 +533,8 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("depthToSpace", &Handler::depthToSpace, policy::move)
|
.def("depthToSpace", &Handler::depthToSpace, policy::move)
|
||||||
.def("reshape", &Handler::reshape, policy::move)
|
.def("reshape", &Handler::reshape, policy::move)
|
||||||
.def("resize", &Handler::resize, policy::move)
|
.def("resize", &Handler::resize, policy::move)
|
||||||
|
.def("squeeze", &Handler::squeeze, policy::move)
|
||||||
|
.def("unsqueeze", &Handler::unsqueeze, policy::move)
|
||||||
.def("concat", &Handler::concat, policy::move)
|
.def("concat", &Handler::concat, policy::move)
|
||||||
.def("attentionKVCache", &Handler::attentionKVCache, policy::move)
|
.def("attentionKVCache", &Handler::attentionKVCache, policy::move)
|
||||||
.def("split", &Handler::split, policy::move)
|
.def("split", &Handler::split, policy::move)
|
||||||
|
|
|
@ -19,6 +19,10 @@ REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int32, CopyCuda,
|
||||||
"Reshape_CUDA_Int32");
|
"Reshape_CUDA_Int32");
|
||||||
REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda,
|
REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda,
|
||||||
"Flatten_CUDA_Float32");
|
"Flatten_CUDA_Float32");
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::Squeeze, DataType::Float32, CopyCuda,
|
||||||
|
"Squeeze_CUDA_Float32");
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::Unsqueeze, DataType::Float32, CopyCuda,
|
||||||
|
"Unsqueeze_CUDA_Float32");
|
||||||
REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda,
|
REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda,
|
||||||
"Identity_CUDA_Float32");
|
"Identity_CUDA_Float32");
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
#include "operators/squeeze.h"
|
||||||
|
#include "utils/operator_utils.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
/**
 * @brief Construct a Squeeze operator and validate it against `graph`.
 *
 * The axes list is taken by value and moved into the member; validation is
 * delegated to checkValid(), and a failed check trips IT_ASSERT.
 */
SqueezeObj::SqueezeObj(GraphObj *graph, Tensor input, Tensor output,
                       Shape axes)
    : OperatorObj(OpType::Squeeze, {input}, {output}),
      axes(std::move(axes)) {
    IT_ASSERT(checkValid(graph));
}
|
||||||
|
|
||||||
|
/**
 * @brief Derive the output shape of the squeeze.
 *
 * When no axes were given, every dimension of extent 1 is selected. Each
 * selected axis is resolved through get_real_axis() and must have extent 1
 * (checked with IT_ASSERT). The result keeps the remaining dimensions in
 * their original order.
 *
 * NOTE(review): the "no axes" case writes the selected axes back into the
 * member, so later calls reuse them — confirm this is intended when the
 * input shape can change between calls.
 */
optional<vector<Shape>> SqueezeObj::inferShape(const TensorVec &inputs) {
    const Shape inputDim = inputs[0]->getDims();
    const int rank = static_cast<int>(inputs[0]->getRank());

    // Default: squeeze every unit dimension.
    if (axes.empty()) {
        for (int d = 0; d < rank; ++d) {
            if (inputDim[d] == 1) {
                axes.emplace_back(d);
            }
        }
    }

    // Mark which input dimensions are removed.
    vector<char> dropped(rank, 0);
    for (const auto axis : axes) {
        const auto real = get_real_axis(axis, rank);
        IT_ASSERT(inputDim[real] == 1);
        dropped[real] = 1;
    }

    // Keep everything that was not marked.
    Shape outputShape;
    for (int d = 0; d < rank; ++d) {
        if (!dropped[d]) {
            outputShape.emplace_back(inputDim[d]);
        }
    }
    return {{outputShape}};
}
|
||||||
|
|
||||||
|
/// Render the operator as "Squeeze[guid](dims,axes=...,input=...,output=...)".
std::string SqueezeObj::toString() const {
    std::ostringstream os;
    os << "Squeeze[" << getGuid() << "]"
       << "(" << vecToString(inputs[0]->getDims()) << ","
       << "axes=" << vecToString(axes) << ","
       << "input=" << inputs[0]->getGuid() << ","
       << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
}
|
||||||
|
|
||||||
|
/// Workload key: operator type id, then the input dimensions, then the axes.
vector<int> SqueezeObj::getWorkloadVector() const {
    const auto dims = inputs[0]->getDims();
    vector<int> ret;
    ret.reserve(1 + dims.size() + axes.size());
    ret.emplace_back(type.underlying());
    ret.insert(ret.end(), dims.begin(), dims.end());
    ret.insert(ret.end(), axes.begin(), axes.end());
    return ret;
}
|
||||||
|
/// Attribute key: operator type id followed by the axes.
vector<int> SqueezeObj::getOpAttrVector() const {
    vector<int> ret;
    ret.reserve(1 + axes.size());
    ret.emplace_back(type.underlying());
    ret.insert(ret.end(), axes.begin(), axes.end());
    return ret;
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,52 @@
|
||||||
|
#include "operators/unsqueeze.h"
|
||||||
|
#include "utils/operator_utils.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
/**
 * @brief Construct an Unsqueeze operator and validate it against `graph`.
 *
 * The axes list is taken by value and moved into the member; validation is
 * delegated to checkValid(), and a failed check trips IT_ASSERT.
 */
UnsqueezeObj::UnsqueezeObj(GraphObj *graph, Tensor input, Tensor output,
                           Shape axes)
    : OperatorObj(OpType::Unsqueeze, {input}, {output}),
      axes(std::move(axes)) {
    IT_ASSERT(checkValid(graph));
}
|
||||||
|
|
||||||
|
/**
 * @brief Derive the output shape of the unsqueeze.
 *
 * The output rank is the input rank plus the number of axes. Each axis is
 * resolved through get_real_axis() against the output rank and stored back
 * into the member; a repeated axis trips IT_ASSERT. Slots not claimed by an
 * axis are then filled with the input dimensions in order.
 */
optional<vector<Shape>> UnsqueezeObj::inferShape(const TensorVec &inputs) {
    const Shape inputDim = inputs[0]->getDims();
    const auto outRank = inputs[0]->getRank() + axes.size();

    // -1 marks a slot that has not been assigned yet.
    Shape outputShape(outRank, -1);
    for (auto &axis : axes) {
        axis = get_real_axis(axis, outRank);
        IT_ASSERT(outputShape[axis] == -1, "Axes have duplicate");
        outputShape[axis] = 1;
    }

    // Pour the input dimensions into the remaining slots, left to right.
    auto src = inputDim.begin();
    for (auto &dim : outputShape) {
        if (dim == -1) {
            dim = *src++;
        }
    }
    return {{outputShape}};
}
|
||||||
|
|
||||||
|
/// Render the operator as "Unsqueeze[guid](dims,axes=...,input=...,output=...)".
std::string UnsqueezeObj::toString() const {
    std::ostringstream os;
    os << "Unsqueeze[" << getGuid() << "]"
       << "(" << vecToString(inputs[0]->getDims()) << ","
       << "axes=" << vecToString(axes) << ","
       << "input=" << inputs[0]->getGuid() << ","
       << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
}
|
||||||
|
|
||||||
|
/// Workload key: operator type id, then the input dimensions, then the axes.
vector<int> UnsqueezeObj::getWorkloadVector() const {
    const auto dims = inputs[0]->getDims();
    vector<int> ret;
    ret.reserve(1 + dims.size() + axes.size());
    ret.emplace_back(type.underlying());
    ret.insert(ret.end(), dims.begin(), dims.end());
    ret.insert(ret.end(), axes.begin(), axes.end());
    return ret;
}
|
||||||
|
/// Attribute key: operator type id followed by the axes.
vector<int> UnsqueezeObj::getOpAttrVector() const {
    vector<int> ret;
    ret.reserve(1 + axes.size());
    ret.emplace_back(type.underlying());
    ret.insert(ret.end(), axes.begin(), axes.end());
    return ret;
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -2,6 +2,8 @@
|
||||||
#include "core/kernel.h"
|
#include "core/kernel.h"
|
||||||
#include "core/runtime.h"
|
#include "core/runtime.h"
|
||||||
#include "operators/reshape.h"
|
#include "operators/reshape.h"
|
||||||
|
#include "operators/squeeze.h"
|
||||||
|
#include "operators/unsqueeze.h"
|
||||||
|
|
||||||
#include "test.h"
|
#include "test.h"
|
||||||
|
|
||||||
|
@ -54,4 +56,36 @@ TEST(Identity, ShapeInference) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Shape inference for Squeeze: an explicit negative axis, then the
// "no axes given" default.
TEST(Squeeze, ShapeInference) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    {
        // Axis -2 refers to the size-1 dimension of {2, 3, 1, 4}.
        Graph graph = make_ref<GraphObj>(runtime);
        Tensor input = graph->addTensor({2, 3, 1, 4}, DataType::Float32);
        auto squeezeOp = graph->addOp<SqueezeObj>(input, nullptr, Shape{-2});
        EXPECT_EQ(squeezeOp->getOutput()->getDims(), (Shape{2, 3, 4}));
    }
    {
        // An empty axes list removes both leading unit dimensions.
        Graph graph = make_ref<GraphObj>(runtime);
        Tensor input = graph->addTensor({1, 1, 3, 4}, DataType::Float32);
        auto squeezeOp = graph->addOp<SqueezeObj>(input, nullptr, Shape{});
        EXPECT_EQ(squeezeOp->getOutput()->getDims(), (Shape{3, 4}));
    }
}
|
||||||
|
|
||||||
|
// Shape inference for Unsqueeze: leading axes, then trailing axes given as
// negative indices.
TEST(Unsqueeze, ShapeInference) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    {
        // Axes 0 and 1 prepend two unit dimensions.
        Graph graph = make_ref<GraphObj>(runtime);
        Tensor input = graph->addTensor({2, 3, 4}, DataType::Float32);
        auto unsqueezeOp =
            graph->addOp<UnsqueezeObj>(input, nullptr, Shape{0, 1});
        EXPECT_EQ(unsqueezeOp->getOutput()->getDims(), (Shape{1, 1, 2, 3, 4}));
    }
    {
        // Axes -1 and -2 append two unit dimensions.
        Graph graph = make_ref<GraphObj>(runtime);
        Tensor input = graph->addTensor({2, 3, 4}, DataType::Float32);
        auto unsqueezeOp =
            graph->addOp<UnsqueezeObj>(input, nullptr, Shape{-1, -2});
        EXPECT_EQ(unsqueezeOp->getOutput()->getDims(), (Shape{2, 3, 4, 1, 1}));
    }
}
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
Loading…
Reference in New Issue