feat: 前端支持 pad 及单元测试

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-15 11:41:06 +08:00
parent 7893ae0cca
commit 315763a83a
7 changed files with 51 additions and 10 deletions

View File

@ -70,6 +70,8 @@ class GraphHandlerObj {
Tensor slice(Tensor input, Tensor output, const vector<int> &starts, Tensor slice(Tensor input, Tensor output, const vector<int> &starts,
const vector<int> &ends, const optional<vector<int>> &axes, const vector<int> &ends, const optional<vector<int>> &axes,
const optional<vector<int>> &steps); const optional<vector<int>> &steps);
Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
const optional<vector<int>> &axes);
}; };
} // namespace infini } // namespace infini

View File

@ -21,10 +21,10 @@ class PadObj : public OperatorObj {
* @param pads Add padding elements at the begining and end of each axis. * @param pads Add padding elements at the begining and end of each axis.
* Suppose that padding axes are [x1, x2, ...], then pads's format is * Suppose that padding axes are [x1, x2, ...], then pads's format is
* [x1_begin, x2_begin, ..., x1_end, x2_end, ...] * [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
* @param axis Pad for appointed axes. If axis is empty, pad for all axes. * @param axes Pad for appointed axes. If axis is empty, pad for all axes.
*/ */
PadObj(GraphObj *graph, Tensor input, Tensor output, PadObj(GraphObj *graph, Tensor input, Tensor output,
const vector<int> &pads, const optional<const vector<int>> &axis); const vector<int> &pads, const optional<vector<int>> &axes);
OP_CLONE(PadObj); OP_CLONE(PadObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -202,6 +202,13 @@ def from_onnx(model: onnx.ModelProto):
_parse_data(data[node.input[3]]) if len(node.input) > 3 else None, _parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
_parse_data(data[node.input[4]]) if len(node.input) > 4 else None, _parse_data(data[node.input[4]]) if len(node.input) > 4 else None,
) )
elif node.op_type == "Pad":
tensors[node.output[0]] = handler.pad(
tensors[node.input[0]],
tensors.get(node.output[0]),
_parse_data(data[node.input[1]]),
_parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
)
else: else:
raise Exception('Unsupported operator "{}"'.format(node.op_type)) raise Exception('Unsupported operator "{}"'.format(node.op_type))

View File

@ -215,9 +215,7 @@ class TestStringMethods(unittest.TestCase):
starts_data = make_tensor("starts", TensorProto.INT64, [4], [2, 10, 1, 5]) starts_data = make_tensor("starts", TensorProto.INT64, [4], [2, 10, 1, 5])
ends = make_tensor_value_info("ends", TensorProto.INT64, [4]) ends = make_tensor_value_info("ends", TensorProto.INT64, [4])
ends_data = make_tensor("ends", TensorProto.INT64, [4], [3, 10, 100, 100]) ends_data = make_tensor("ends", TensorProto.INT64, [4], [3, 10, 100, 100])
slice = make_node( slice = make_node("Slice", ["data", "starts", "ends"], ["output"], name="slice")
"Slice", ["data", "starts", "ends"], ["output"], name="gather"
)
make_and_import_model( make_and_import_model(
make_graph( make_graph(
[slice], [slice],
@ -228,6 +226,24 @@ class TestStringMethods(unittest.TestCase):
) )
) )
def test_pad(self):
data = make_tensor_value_info("data", TensorProto.UINT32, [1, 64, 162, 162])
output = make_tensor_value_info("output", TensorProto.UINT32, [3, 84, 164, 172])
pads = make_tensor_value_info("pads", TensorProto.INT64, [8])
pads_data = make_tensor(
"pads", TensorProto.INT64, [8], [2, 10, 1, 5, 0, 10, 1, 5]
)
pad = make_node("Pad", ["data", "pads"], ["output"], name="pad")
make_and_import_model(
make_graph(
[pad],
"pad",
[data, pads],
[output],
[pads_data],
)
)
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression> # see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
def test_linear(self): def test_linear(self):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])

View File

@ -4,6 +4,7 @@
#include "operators/element_wise.h" #include "operators/element_wise.h"
#include "operators/gather.h" #include "operators/gather.h"
#include "operators/matmul.h" #include "operators/matmul.h"
#include "operators/pad.h"
#include "operators/pooling.h" #include "operators/pooling.h"
#include "operators/reduce_mean.h" #include "operators/reduce_mean.h"
#include "operators/reshape.h" #include "operators/reshape.h"
@ -180,6 +181,18 @@ Tensor GraphHandlerObj::slice(Tensor input, Tensor output,
} }
} }
Tensor GraphHandlerObj::pad(Tensor input, Tensor output,
const vector<int> &pads,
const optional<vector<int>> &axes) {
if (output) {
g->addOpWithOutputs<PadObj>(std::move(input), output, pads, axes);
return output;
} else {
return g->addOp<PadObj>(std::move(input), output, pads, axes)
->getOutput();
}
}
static DataType dtype_repr_convert(int dtype) { static DataType dtype_repr_convert(int dtype) {
switch ((OnnxDType)dtype) { switch ((OnnxDType)dtype) {
case OnnxDType::FLOAT: case OnnxDType::FLOAT:

View File

@ -100,6 +100,10 @@ void init_graph_builder(py::module &m) {
Tensor, Tensor, const vector<int> &, const vector<int> &, Tensor, Tensor, const vector<int> &, const vector<int> &,
const optional<vector<int>> &, const optional<vector<int>> &>( const optional<vector<int>> &, const optional<vector<int>> &>(
&Handler::slice), &Handler::slice),
policy::move)
.def("pad",
py::overload_cast<Tensor, Tensor, const vector<int> &,
const optional<vector<int>> &>(&Handler::pad),
policy::move); policy::move);
} }

View File

@ -2,19 +2,18 @@
namespace infini { namespace infini {
PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output, PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output,
const vector<int> &_pads, const vector<int> &_pads, const optional<vector<int>> &axes)
const optional<const vector<int>> &axis)
: OperatorObj(OpType::Pad, {input}, {output}) { : OperatorObj(OpType::Pad, {input}, {output}) {
if (!axis) if (!axes)
pads = _pads; pads = _pads;
else { else {
auto nAxis = (*axis).size(); auto nAxis = (*axes).size();
IT_ASSERT(_pads.size() == nAxis * 2); IT_ASSERT(_pads.size() == nAxis * 2);
auto nDims = input->getDims().size(); auto nDims = input->getDims().size();
pads = vector<int>(nDims * 2, 0); pads = vector<int>(nDims * 2, 0);
for (size_t i = 0; i < nAxis; ++i) { for (size_t i = 0; i < nAxis; ++i) {
auto j = (*axis)[i]; auto j = (*axes)[i];
pads[j] = _pads[i]; pads[j] = _pads[i];
pads[j + nDims] = _pads[i + nAxis]; pads[j + nDims] = _pads[i + nAxis];
} }