diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index caec8ca2..313a1f79 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -65,6 +65,12 @@ class GraphHandlerObj {
                      std::optional<float> max);
     Tensor transpose(Tensor data, Tensor transposed, Shape perm);
     Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
+    Tensor resize(Tensor input, Tensor output,
+                  const std::optional<vector<int>> &axes, Tensor sizes,
+                  Tensor scales, Tensor roi, vector<uint32_t> sizes_,
+                  vector<float> scales_, vector<float> roi_, string mode,
+                  string ratioPolicy, string nearestMode,
+                  string coordTransMode);
     Tensor concat(TensorVec inputs, Tensor output, int dim);
     Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
                             Tensor input_q, Tensor input_k, Tensor input_v,
diff --git a/include/operators/resize.h b/include/operators/resize.h
index 96283c12..220ef719 100644
--- a/include/operators/resize.h
+++ b/include/operators/resize.h
@@ -27,6 +27,60 @@ class ResizeObj : public OperatorObj {
     enum class EKeepAspectRatioPolicy { stretch, notLarger, notSmaller, none };
     enum class ECoeffMode { nearest, linear, cubic };
 
+    static ECoordinateTransMode fromECoordinateTransModeStr(string mode) {
+        if (mode == "half_pixel") {
+            return ECoordinateTransMode::halfPixel;
+        } else if (mode == "asymmetric") {
+            return ECoordinateTransMode::asymmetric;
+        } else if (mode == "align_corners") {
+            return ECoordinateTransMode::alignCorners;
+        } else if (mode == "pytorch_half_pixel") {
+            return ECoordinateTransMode::pytorchHalfPixel;
+        } else if (mode == "tf_crop_and_resize") {
+            return ECoordinateTransMode::tfCropAndResize;
+        } else {
+            IT_TODO_HALT();
+        }
+    }
+
+    static ENearestMode fromENearestModeStr(string mode) {
+        if (mode == "round_prefer_floor") {
+            return ENearestMode::roundPreferFloor;
+        } else if (mode == "round_prefer_ceil") {
+            return ENearestMode::roundPreferCeil;
+        } else if (mode == "floor") {
+            return ENearestMode::floor;
+        } else if (mode == "ceil") {
+            return ENearestMode::ceil;
+        } else {
+            return ENearestMode::none;
+        }
+    }
+
+    static EKeepAspectRatioPolicy fromRatioPolicyStr(string ratioPolicyStr) {
+        if (ratioPolicyStr == "stretch") {
+            return EKeepAspectRatioPolicy::stretch;
+        } else if (ratioPolicyStr == "not_larger") {
+            return EKeepAspectRatioPolicy::notLarger;
+        } else if (ratioPolicyStr == "not_smaller") {
+            return EKeepAspectRatioPolicy::notSmaller;
+        } else {
+            return EKeepAspectRatioPolicy::none;
+        }
+    }
+
+    static ECoeffMode fromECoeffModeStr(string mode) {
+        if (mode == "nearest") {
+            return ECoeffMode::nearest;
+        } else if (mode == "linear") {
+            return ECoeffMode::linear;
+        } else if (mode == "cubic") {
+            return ECoeffMode::cubic;
+        } else {
+            IT_TODO_HALT();
+        }
+    }
+
   private:
     vector<int> axes;
     vector<float> scales;
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 80457ecd..c63746af 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -535,6 +535,65 @@ class OnnxStub:
                     tensors.get(node.output[0]),
                     shape,
                 )
+            elif node.op_type == "Resize":
+                output = tensors.get(node.output[0])
+                attributes = _parse_attribute(
+                    node,
+                    {
+                        "antialias": 0,
+                        "axes": None,
+                        "coordinate_transformation_mode": "half_pixel",
+                        "cubic_coeff_a": -0.75,
+                        "exclude_outside": 0,
+                        "extrapolation_value": 0.0,
+                        "keep_aspect_ratio_policy": "none",
+                        "mode": "nearest",
+                        "nearest_mode": "none",
+                    },
+                )
+                (
+                    axes,
+                    keep_aspect_ratio_policy,
+                    coordinate_transformation_mode,
+                    mode,
+                    nearest_mode,
+                ) = (
+                    attributes[name]
+                    for name in [
+                        "axes",
+                        "keep_aspect_ratio_policy",
+                        "coordinate_transformation_mode",
+                        "mode",
+                        "nearest_mode",
+                    ]
+                )
+                if len(node.input) > 1:
+                    roiVal = _parse_data(data[node.input[1]])
+                else:
+                    roiVal = []
+                if len(node.input) > 2:
+                    scalesVal = _parse_data(data[node.input[2]])
+                else:
+                    scalesVal = []
+                if len(node.input) > 3:
+                    sizesVal = _parse_data(data[node.input[3]])
+                else:
+                    sizesVal = []
+                tensors[node.output[0]] = self.handler.resize(
+                    tensors[node.input[0]],
+                    output,
+                    axes,
+                    tensors[node.input[3]] if len(node.input) > 3 else None,
+                    tensors[node.input[2]] if len(node.input) > 2 else None,
+                    tensors[node.input[1]] if len(node.input) > 1 else None,
+                    sizesVal,
+                    scalesVal,
+                    roiVal,
+                    mode,
+                    keep_aspect_ratio_policy,
+                    nearest_mode,
+                    coordinate_transformation_mode,
+                )
             elif node.op_type == "Squeeze":
                 input_shape = _search_shape(model, node.input[0])
                 axes = set(
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 02911b50..f5d5a426 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -295,6 +295,14 @@ class TestStringMethods(unittest.TestCase):
             make_graph([reshape], "reshape", [data, shape], [reshaped], [shape_data])
         )
 
+    def test_resize(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 128, 40, 40])
+        roi = make_tensor("roi", TensorProto.FLOAT, [0], [])
+        scales = make_tensor("scales", TensorProto.FLOAT, [4], [1, 1, 2, 2])
+        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 128, 80, 80])
+        resize = make_node("Resize", ["x", "roi", "scales"], ["y"], name="resize")
+        make_and_import_model(make_graph([resize], "resize", [x], [y], [roi, scales]))
+
     def test_concat(self):
         input1 = make_tensor_value_info("input1", TensorProto.FLOAT, [1, 3, 2, 4])
         input2 = make_tensor_value_info("input2", TensorProto.FLOAT, [1, 3, 2, 5])
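A quick sanity check of the shapes `test_resize` uses (a sketch for the reader, not part of the patch): when `scales` is supplied, ONNX Resize produces `floor(input_dim * scale)` along each axis, so `[1, 128, 40, 40]` scaled by `[1, 1, 2, 2]` yields `[1, 128, 80, 80]`:

```python
# Shape arithmetic behind test_resize: output_dim = floor(input_dim * scale)
# per axis when the "scales" input is used (roi only matters for the
# tf_crop_and_resize coordinate transform).
import math

input_shape = [1, 128, 40, 40]
scales = [1.0, 1.0, 2.0, 2.0]
output_shape = [math.floor(d * s) for d, s in zip(input_shape, scales)]
assert output_shape == [1, 128, 80, 80]
```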
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index e323538a..7fc6f977 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -17,6 +17,7 @@
 #include "operators/recv.h"
 #include "operators/reduce.h"
 #include "operators/reshape.h"
+#include "operators/resize.h"
 #include "operators/send.h"
 #include "operators/slice.h"
 #include "operators/softmax.h"
@@ -254,6 +255,64 @@ Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) {
     }
 }
 
+Tensor GraphHandlerObj::resize(Tensor input, Tensor output,
+                               const std::optional<vector<int>> &axes,
+                               Tensor sizes, Tensor scales, Tensor roi,
+                               vector<uint32_t> sizes_, vector<float> scales_,
+                               vector<float> roi_, string mode,
+                               string ratioPolicy, string nearestMode,
+                               string coordTransMode) {
+    if (sizes_.size() > 0) {
+        sizes->dataMalloc();
+        sizes->copyin<uint32_t>(sizes_);
+    }
+    if (scales_.size() > 0) {
+        scales->dataMalloc();
+        scales->copyin<float>(scales_);
+    }
+    if (roi_.size() > 0) {
+        roi->dataMalloc();
+        roi->copyin<float>(roi_);
+    }
+    ResizeObj::EKeepAspectRatioPolicy ratioPolicy_ =
+        ResizeObj::fromRatioPolicyStr(ratioPolicy);
+    ResizeObj::ENearestMode nearestMode_ =
+        ResizeObj::fromENearestModeStr(nearestMode);
+    ResizeObj::ECoordinateTransMode coordTransMode_ =
+        ResizeObj::fromECoordinateTransModeStr(coordTransMode);
+    ResizeObj::ECoeffMode mode_ = ResizeObj::fromECoeffModeStr(mode);
+    if (output) {
+        if (mode == "nearest") {
+            g->addOpWithOutputs<ResizeObj>(
+                std::move(input), output, std::move(axes), std::move(sizes),
+                std::move(scales), std::move(roi), ratioPolicy_, nearestMode_,
+                coordTransMode_);
+        } else {
+            g->addOpWithOutputs<ResizeObj>(
+                std::move(input), output, std::move(axes), std::move(sizes),
+                std::move(scales), std::move(roi), mode_, ratioPolicy_,
+                coordTransMode_);
+        }
+        return output;
+    } else {
+        if (mode == "nearest") {
+            return g
+                ->addOp<ResizeObj>(std::move(input), output, std::move(axes),
+                                   std::move(sizes), std::move(scales),
+                                   std::move(roi), ratioPolicy_, nearestMode_,
+                                   coordTransMode_)
+                ->getOutput();
+        } else {
+            return g
+                ->addOp<ResizeObj>(std::move(input), output, std::move(axes),
+                                   std::move(sizes), std::move(scales),
+                                   std::move(roi), mode_, ratioPolicy_,
+                                   coordTransMode_)
+                ->getOutput();
+        }
+    }
+}
+
 Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
     if (output) {
         g->addOpWithOutputs<ConcatObj>(std::move(inputs), output, dim);
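For reference, a hedged sketch of the Python-visible contract of this function once bound (the binding follows below). The `backend.GraphHandler(runtime)` constructor and `handler.tensor(shape, dtype)` factory are assumptions drawn from how `OnnxStub` builds graphs elsewhere, not part of this patch. Note the tensor arguments arrive in reverse ONNX input order: `sizes`, `scales`, `roi`, followed by their raw values.

```python
# A minimal sketch, assuming backend.GraphHandler(runtime) and
# handler.tensor(shape, dtype) are bound as used by OnnxStub;
# dtype 1 is TensorProto.FLOAT.
from pyinfinitensor import backend

handler = backend.GraphHandler(backend.cpu_runtime())
x = handler.tensor([1, 128, 40, 40], 1)
scales = handler.tensor([4], 1)

y = handler.resize(
    x,
    None,                  # output: None lets the handler infer and create it
    None,                  # axes: None means all axes
    None,                  # sizes tensor (ONNX input #3), unused here
    scales,                # scales tensor (ONNX input #2)
    None,                  # roi tensor (ONNX input #1), unused here
    [],                    # sizes_ raw values
    [1.0, 1.0, 2.0, 2.0],  # scales_ raw values, copied into `scales`
    [],                    # roi_ raw values
    "nearest",             # mode
    "none",                # keep_aspect_ratio_policy
    "round_prefer_floor",  # nearest_mode
    "half_pixel",          # coordinate_transformation_mode
)
```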
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 99b18172..c23009b5 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -506,6 +506,7 @@ void init_graph_builder(py::module &m) {
         .def("transpose", &Handler::transpose, policy::move)
         .def("depthToSpace", &Handler::depthToSpace, policy::move)
         .def("reshape", &Handler::reshape, policy::move)
+        .def("resize", &Handler::resize, policy::move)
         .def("concat", &Handler::concat, policy::move)
         .def("attentionKVCache", &Handler::attentionKVCache, policy::move)
         .def("split", &Handler::split, policy::move)
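With the binding registered, the whole path (ONNX node → `OnnxStub` importer → `GraphHandlerObj::resize`) can be exercised from Python. A minimal sketch mirroring `test_resize`; it assumes `OnnxStub`'s constructor takes `(model, runtime)` and stops at import, since running also requires a Resize kernel for the chosen runtime:

```python
from onnx import TensorProto
from onnx.helper import (
    make_graph,
    make_model,
    make_node,
    make_tensor,
    make_tensor_value_info,
)
from pyinfinitensor import backend
from pyinfinitensor.onnx import OnnxStub

# Same graph as test_resize: nearest-neighbor 2x upsample on the two
# spatial axes, driven by the "scales" input.
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 128, 40, 40])
roi = make_tensor("roi", TensorProto.FLOAT, [0], [])
scales = make_tensor("scales", TensorProto.FLOAT, [4], [1, 1, 2, 2])
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 128, 80, 80])
node = make_node("Resize", ["x", "roi", "scales"], ["y"], name="resize")
model = make_model(make_graph([node], "resize", [x], [y], [roi, scales]))

# Routes through the new Resize branch in OnnxStub and lands in
# GraphHandlerObj::resize via the "resize" binding above.
stub = OnnxStub(model, backend.cpu_runtime())
```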