diff --git a/include/cuda/resize.cuh b/include/cuda/resize.cuh new file mode 100644 index 00000000..69e03cdc --- /dev/null +++ b/include/cuda/resize.cuh @@ -0,0 +1,17 @@ +#pragma once +#include "cuda/cuda_common.h" + +typedef struct { + int nDims; + int oDims[4]; + int inDims[4]; + int inStride[4]; + float scale[4]; +} MetaData; + +namespace infini { +void resize_kernel_nearest(float *in, float *out, const MetaData &metaData, + size_t num, int coordinateMode, int nearestMode); +void resize_kernel_linear(float *in, float *out, const MetaData &metaData, + size_t num, int coordinateMode); +} // namespace infini diff --git a/include/operators/resize.h b/include/operators/resize.h new file mode 100644 index 00000000..e9f5a357 --- /dev/null +++ b/include/operators/resize.h @@ -0,0 +1,83 @@ +#pragma once + +#include "core/operator.h" + +namespace infini { +class ResizeObj : public OperatorObj { + public: + enum class ECoordinateTransMode { + halfPixel, + pytorchHalfPixel, + alignCorners, + asymmetric, + tfCropAndResize + }; + enum class ENearestMode { roundPreferFloor, roundPreferCeil, floor, ceil }; + enum class EKeepAspectRatioPolicy { stretch, notLarger, notSmaller }; + enum class ECoeffMode { nearest, linear, cubic }; + + private: + vector axes; + vector scales; + ECoordinateTransMode coMode; // compute src coordinate from dst coordinate + ECoeffMode mode; // coeff mode,for computing dst value from coordinate src + // neighborhood . + ENearestMode nearestMode; // used in "nearest" mode, indicates how to get + // "nearest" pixel + EKeepAspectRatioPolicy + ratioPolicy; // used for computing shape when using "sizes" + + public: + // nearest mode, not tf_crop_and_resize + ResizeObj( + GraphObj *graph, Tensor input, Tensor output, + const std::optional> &axes, Tensor sizes, + EKeepAspectRatioPolicy ratioPolicy, + ENearestMode nearestMode = ENearestMode::roundPreferFloor, + ECoordinateTransMode coordTransMode = ECoordinateTransMode::halfPixel); + ResizeObj( + GraphObj *graph, Tensor input, Tensor output, + const std::optional> &axes, Tensor scales, + ENearestMode nearestMode = ENearestMode::roundPreferFloor, + ECoordinateTransMode coordTransMode = ECoordinateTransMode::halfPixel); + + // linear mode + ResizeObj( + GraphObj *graph, Tensor input, Tensor output, + const std::optional> &axes, Tensor sizes, + EKeepAspectRatioPolicy ratioPolicy, ECoeffMode mode, + ECoordinateTransMode coordTransMode = ECoordinateTransMode::halfPixel); + ResizeObj( + GraphObj *graph, Tensor input, Tensor output, + const std::optional> &axes, Tensor scales, ECoeffMode mode, + ECoordinateTransMode coordTransMode = ECoordinateTransMode::halfPixel); + + vector inferDataType(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) const override; + std::string toString() const override; + int numInputs() const override { return 4; } + int numOutputs() const override { return 1; } + + ECoeffMode getMode() const { return mode; } + int getNearestMode() const { return enum_to_underlying(nearestMode); } + int getKeepAxesRatioPolicy() const { + return enum_to_underlying(ratioPolicy); + } + int getCoordinateTransMode() const { return enum_to_underlying(coMode); } + float getScale(int i) const { + IT_ASSERT((size_t)i < scales.size()); + return scales[i]; + } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; + + float round_int(float x) const; + bool checkCoordinateTransValid(int resizedCo, int origiCo) const; + void InitBySizes(Tensor input, Tensor sizes, + const std::optional> &axes); + void InitByScales(Tensor input, Tensor sizes, + const std::optional> &axes); +}; +} // namespace infini diff --git a/src/kernels/cuda/resize.cc b/src/kernels/cuda/resize.cc new file mode 100644 index 00000000..9eda2824 --- /dev/null +++ b/src/kernels/cuda/resize.cc @@ -0,0 +1,47 @@ +#include "operators/resize.h" +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/resize.cuh" +namespace infini { +class ResizeCuda : public CudaKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto in = op->getInputs(0); + auto out = op->getOutputs()[0]; + + int nDims = in->getDims().size(); + if (nDims > 4) + IT_TODO_HALT(); + + MetaData metaData; + memset(&metaData, 0, sizeof(metaData)); + metaData.nDims = nDims; + for (int i = 0; i < nDims; ++i) { + metaData.inDims[i] = in->getDims()[i]; + metaData.oDims[i] = out->getDims()[i]; + metaData.inStride[i] = in->getStride()[i]; + metaData.scale[i] = op->getScale(i); + } + + switch (op->getMode()) { + case ResizeObj::ECoeffMode::nearest: + resize_kernel_nearest(in->getRawDataPtr(), + out->getRawDataPtr(), metaData, + out->size(), op->getCoordinateTransMode(), + op->getNearestMode()); + break; + case ResizeObj::ECoeffMode::linear: + resize_kernel_linear(in->getRawDataPtr(), + out->getRawDataPtr(), metaData, + out->size(), op->getCoordinateTransMode()); + break; + default: + IT_TODO_HALT(); + } + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::Resize, DataType::Float32, ResizeCuda, + "Resize_CUDA_Float32"); + +} // namespace infini diff --git a/src/kernels/cuda/resize.cu b/src/kernels/cuda/resize.cu new file mode 100644 index 00000000..8086d042 --- /dev/null +++ b/src/kernels/cuda/resize.cu @@ -0,0 +1,208 @@ +#include "cmath" +#include "cuda/cuda_common.h" +#include "cuda/resize.cuh" +#include + +#ifndef GPU_LAMBDA +#define GPU_LAMBDA __device__ +#endif + +// nearest mode +__device__ int round_prefer_ceil(float x) { + return (x > 0.0) ? floor(x + 0.5) : ceil(x - 0.5); +} + +__device__ int round_prefer_floor(float x) { + return (x > 0.0) ? floor(x + 0.4) : ceil(x - 0.4); +} + +__device__ int prefer_floor(float x) { return std::floor(x); } + +__device__ int prefer_ceil(float x) { return std::ceil(x); } + +// coordinate transform mode +__device__ float half_pixel(int idx, float scale, int, int) { + return (idx + 0.5) / scale - 0.5; +} + +__device__ float pytorch_half_pixel(int idx, float scale, int length_resized, + int) { + return length_resized > 1 ? (idx + 0.5) / scale - 0.5 : 0; +} + +__device__ float align_corners(int idx, float scale, int length_resized, + int length_original) { + if (length_resized == 1) + return 0; + return (float)idx * (float)(length_original - 1) / + (float)(length_resized - 1); +} + +__device__ float asymmetric(int idx, float scale, int length_resized, + int length_original) { + return idx / scale; +} +/* +__device__ float tf_crop_and_resize(int idx, float scale, int length_resized, + int length_original) { + +}*/ + +// ATTENTION:The order of device functions in array must be consistent with the +// order in the enums of ResizeObj. +using nearest_mod_func_t = int (*)(float); +__device__ nearest_mod_func_t p_nearest_mode_fun[] = { + round_prefer_floor, round_prefer_ceil, prefer_floor, prefer_ceil}; + +using coordinate_trans_mod_func_t = float (*)(int idxO, float scale, int lenO, + int lenR); +__device__ coordinate_trans_mod_func_t p_cooridnate_trans_mode_func[] = { + half_pixel, pytorch_half_pixel, align_corners, asymmetric}; + +template +__device__ int nearestCoordinateTrans(int dOffset, MetaData metaData, + T1 transModeFun, T2 nearestModeFun) { + int sOffset = 0; + for (int i = metaData.nDims - 1; i >= 0; --i) { + int dIdx = dOffset % metaData.oDims[i]; + dOffset = dOffset / metaData.oDims[i]; + + if (metaData.inDims[i] == metaData.oDims[i]) + sOffset += dIdx * metaData.inStride[i]; + else { + float scale = (float)metaData.oDims[i] / (float)metaData.inDims[i]; + int sIdx = nearestModeFun(transModeFun( + dIdx, scale, metaData.oDims[i], metaData.inDims[i])); + if (sIdx > metaData.inDims[i] - 1) + sIdx = metaData.inDims[i] - 1; + else if (sIdx < 0) + sIdx = 0; + sOffset += sIdx * metaData.inStride[i]; + } + } + return sOffset; +} + +__global__ void _resize_kernel_nearest(float *in, float *out, MetaData metaData, + size_t num, int coordinateMode, + int nearestMode) { + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + auto stride = blockDim.x * gridDim.x; + + while (tid < num) { + int offset = nearestCoordinateTrans( + tid, metaData, p_cooridnate_trans_mode_func[coordinateMode], + p_nearest_mode_fun[nearestMode]); + out[tid] = in[offset]; + tid += stride; + } +} + +// ATTENTION: Make sure dim <=4 +typedef struct { + int offset[16]; + float power[16]; +} NeighborList; + +int __device__ getLimitIdx(int idx, int limit) { + if (idx < 0) + return 0; + if (idx > limit) + return limit; + return idx; +} + +__global__ void _resize_kernel_linear(float *in, float *out, MetaData metaData, + size_t num, int coordinateMode) { + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + auto stride = blockDim.x * gridDim.x; + + while (tid < num) { + auto dOffset = tid; + auto neighborNum = 0; + NeighborList neighborList; + memset(&neighborList, 0, sizeof(neighborList)); + for (int i = metaData.nDims - 1; i >= 0; --i) { + int dIdx = dOffset % metaData.oDims[i]; + float scale = metaData.scale[i]; + float sIdx = p_cooridnate_trans_mode_func[coordinateMode]( + dIdx, scale, scale * metaData.inDims[i], metaData.inDims[i]); + + int idx = std::floor(sIdx); + float power = 1 - (sIdx - idx); + + // update neighborList + if (metaData.inDims[i] == 1) { + if (neighborNum == 0) { + neighborList.offset[0] = 0; + neighborList.power[0] = power; + neighborNum = 1; + } else { + for (int j = 0; j < neighborNum; j++) { + neighborList.power[j] *= power; + } + } + } else { + if (neighborNum == 0) { + neighborList.offset[0] = + getLimitIdx(idx, metaData.inDims[i] - 1) * + metaData.inStride[i]; + neighborList.power[0] = power; + neighborList.offset[1] = + getLimitIdx(idx + 1, metaData.inDims[i] - 1) * + metaData.inStride[i]; + neighborList.power[1] = 1 - power; + neighborNum = 2; + } else { + for (int j = 0; j < neighborNum; j++) { + neighborList.offset[j + neighborNum] = + neighborList.offset[j] + + getLimitIdx(idx + 1, metaData.inDims[i] - 1) * + metaData.inStride[i]; + neighborList.power[j + neighborNum] = + (neighborList.power[j]) * (1 - power); + + neighborList.offset[j] += + getLimitIdx(idx, metaData.inDims[i] - 1) * + metaData.inStride[i]; + neighborList.power[j] *= power; + } + neighborNum *= 2; + } + } + + dOffset = dOffset / metaData.oDims[i]; + } + + float val = 0; + for (int i = 0; i < neighborNum; ++i) { + val += in[neighborList.offset[i]] * neighborList.power[i]; + } + out[tid] = val; + tid += stride; + } +} + +namespace infini { +void resize_kernel_nearest(float *in, float *out, const MetaData &metaData, + size_t num, int coordinateMode, int nearestMode) { + int blocksize = 32 * 16; + auto gridsize = (num + blocksize - 1) / blocksize; + IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) / + sizeof(p_cooridnate_trans_mode_func[0])); + IT_ASSERT(nearestMode < + sizeof(p_nearest_mode_fun) / sizeof(p_nearest_mode_fun[0])); + _resize_kernel_nearest<<>>( + in, out, metaData, num, coordinateMode, nearestMode); +} + +void resize_kernel_linear(float *in, float *out, const MetaData &metaData, + size_t num, int coordinateMode) { + int blocksize = 32 * 16; + auto gridsize = (num + blocksize - 1) / blocksize; + IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) / + sizeof(p_cooridnate_trans_mode_func[0])); + _resize_kernel_linear<<>>(in, out, metaData, num, + coordinateMode); +} +} // namespace infini diff --git a/src/operators/gather.cc b/src/operators/gather.cc index 95e88126..a5bf9d1c 100644 --- a/src/operators/gather.cc +++ b/src/operators/gather.cc @@ -69,6 +69,7 @@ std::string GatherObj::toString() const { os << "output=" << outputs[0]->getGuid() << ")"; return os.str(); } + vector GatherObj::getWorkloadVector() const { vector ret = inputs[0]->getDims(); ret.emplace(ret.begin(), enum_to_underlying(type)); diff --git a/src/operators/resize.cc b/src/operators/resize.cc new file mode 100644 index 00000000..b7b2e3bc --- /dev/null +++ b/src/operators/resize.cc @@ -0,0 +1,248 @@ +#include "operators/resize.h" +#include +namespace infini { +ResizeObj::ResizeObj(GraphObj *graph, Tensor input, Tensor output, + const std::optional> &axes, Tensor sizes, + EKeepAspectRatioPolicy ratioPolicy, + ENearestMode nearestMode, + ECoordinateTransMode coordTransMode) + : OperatorObj(OpType::Resize, {input, nullptr, nullptr, sizes}, {output}), + coMode(coordTransMode), mode(ECoeffMode::nearest), + nearestMode(nearestMode), ratioPolicy(ratioPolicy) { + if (coordTransMode == ECoordinateTransMode::tfCropAndResize) + IT_TODO_HALT(); + InitBySizes(input, sizes, axes); + + IT_ASSERT(checkValid(graph)); +} + +ResizeObj::ResizeObj(GraphObj *graph, Tensor input, Tensor output, + const std::optional> &axes, Tensor scales, + ENearestMode nearestMode, + ECoordinateTransMode coordTransMode) + : OperatorObj(OpType::Resize, {input, nullptr, scales, nullptr}, {output}), + coMode(coordTransMode), mode(ECoeffMode::nearest), + nearestMode(nearestMode) { + InitByScales(input, scales, axes); + + IT_ASSERT(checkValid(graph)); +} + +ResizeObj::ResizeObj(GraphObj *graph, Tensor input, Tensor output, + const std::optional> &axes, Tensor sizes, + EKeepAspectRatioPolicy ratioPolicy, ECoeffMode mode, + ECoordinateTransMode coordTransMode) + : OperatorObj(OpType::Resize, {input, nullptr, nullptr, sizes}, {output}), + coMode(coordTransMode), mode(mode), ratioPolicy(ratioPolicy) { + if (coordTransMode == ECoordinateTransMode::tfCropAndResize) + IT_TODO_HALT(); + InitBySizes(input, sizes, axes); + + IT_ASSERT(checkValid(graph)); +} + +ResizeObj::ResizeObj(GraphObj *graph, Tensor input, Tensor output, + const std::optional> &axes, Tensor scales, + ECoeffMode mode, ECoordinateTransMode coordTransMode) + : OperatorObj(OpType::Resize, {input, nullptr, scales, nullptr}, {output}), + coMode(coordTransMode), mode(mode) { + if (coordTransMode == ECoordinateTransMode::tfCropAndResize) + IT_TODO_HALT(); + InitByScales(input, scales, axes); + + IT_ASSERT(checkValid(graph)); +} + +void ResizeObj::InitBySizes(Tensor input, Tensor sizes, + const std::optional> &axes) { + IT_ASSERT(sizes != nullptr); + size_t size = sizes->getDims()[0]; + IT_ASSERT(size == input->getDims().size() || + (axes != std::nullopt && size == (*axes).size())); + + if (axes == std::nullopt) + for (size_t i = 0; i < input->getDims().size(); ++i) + this->axes.emplace_back(i); + else + // check axes + for (size_t i = 0; i < (*axes).size(); ++i) { + auto val = (*axes)[i]; + if (val < 0) + IT_TODO_HALT(); + IT_ASSERT((size_t)val < inputs[0]->getDims().size()); + this->axes.emplace_back(val); + } + + // init this->scales + for (size_t i = 0; i < input->getDims().size(); ++i) { + this->scales.emplace_back(1); + } + + // copy sizes data to host. + IT_ASSERT(sizes->getDataBlob() != nullptr); + Runtime runtime = CpuRuntimeObj::getInstance(); + int *data = (int *)runtime->alloc(sizes->getBytes()); + sizes->getRuntime()->copyBlobToCPU( + (void *)data, sizes->getRawDataPtr(), sizes->getBytes()); + + auto inDims = input->getDims(); + int n = this->axes.size(); + switch (ratioPolicy) { + case EKeepAspectRatioPolicy::stretch: + for (int i = 0; i < n; ++i) + scales[this->axes[i]] = + (float)data[i] / (float)inDims[this->axes[i]]; + break; + case EKeepAspectRatioPolicy::notLarger: { + float scale = (float)data[0] / (float)inDims[this->axes[0]]; + for (int i = 1; i < n; ++i) { + auto tmp = (float)data[i] / (float)inDims[this->axes[i]]; + scale = scale < tmp ? scale : tmp; + } + for (int i = 0; i < n; ++i) + scales[this->axes[i]] = scale; + break; + } + case EKeepAspectRatioPolicy::notSmaller: { + float scale = (float)data[0] / (float)inDims[this->axes[0]]; + for (int i = 1; i < n; ++i) { + auto tmp = (float)data[i] / (float)inDims[this->axes[i]]; + scale = scale > tmp ? scale : tmp; + } + for (int i = 0; i < n; ++i) + scales[this->axes[i]] = scale; + break; + } + default: + IT_ASSERT(0); + } + + runtime->dealloc(data); +} + +void ResizeObj::InitByScales(Tensor input, Tensor scales, + const std::optional> &axes) { + IT_ASSERT(scales != nullptr); + size_t size = scales->getDims()[0]; + IT_ASSERT(size == input->getDims().size() || + (axes != std::nullopt && size == (*axes).size())); + + // copy scales data to host. + IT_ASSERT(scales->getDataBlob() != nullptr); + Runtime runtime = CpuRuntimeObj::getInstance(); + float *data = (float *)runtime->alloc(scales->getBytes()); + scales->getRuntime()->copyBlobToCPU( + (void *)data, scales->getRawDataPtr(), scales->getBytes()); + + // init this->scales + for (size_t i = 0; i < input->getDims().size(); ++i) { + this->scales.emplace_back(1); + } + + if (axes == std::nullopt) + for (size_t i = 0; i < input->getDims().size(); ++i) { + this->axes.emplace_back(i); + IT_ASSERT(data[i] > 0); + this->scales[i] = data[i]; + } + else + // check axes + for (size_t i = 0; i < (*axes).size(); ++i) { + auto val = (*axes)[i]; + if (val < 0) + IT_TODO_HALT(); + IT_ASSERT((size_t)val < inputs[0]->getDims().size()); + this->axes.emplace_back(val); + IT_ASSERT(data[i] > 0); + this->scales[val] = data[i]; + } + + runtime->dealloc(data); +} + +vector ResizeObj::inferDataType(const TensorVec &inputs) const { + IT_ASSERT(inputs.size() == 4); + auto roi = inputs[1]; + auto scales = inputs[2]; + auto sizes = inputs[3]; + IT_ASSERT(roi == nullptr || roi->getDType() == DataType::Float32); + IT_ASSERT(scales == nullptr || scales->getDType() == DataType::Float32); + IT_ASSERT(sizes == nullptr || sizes->getDType() == DataType::UInt32); + return {inputs[0]->getDType()}; +} + +bool ResizeObj::checkCoordinateTransValid(int resizedX, int origiX) const { + if (ECoordinateTransMode::alignCorners == coMode) { + return (!(resizedX <= 1 && origiX != resizedX)); + } + return true; +} + +float ResizeObj::round_int(float x) const { + return (x > 0.0) ? floor(x + 0.5) : ceil(x - 0.5); +} + +// output shape is related to sizes/scales value. +optional> ResizeObj::inferShape(const TensorVec &inputs) const { + auto inDims = inputs[0]->getDims(); + Shape ret = inDims; + int nDim = inDims.size(); + for (int i = 0; i < nDim; ++i) { + int size = round_int(scales[i] * inDims[i]); + IT_ASSERT(checkCoordinateTransValid(size, inDims[i])); + ret[i] = size; + } + + return {{ret}}; +} + +std::string ResizeObj::toString() const { + std::ostringstream os; + os << "Resize" + << "[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + if (inputs[1] != nullptr) + os << "roi=" << vecToString(inputs[1]->getDims()) << ","; + if (inputs[2] != nullptr) + os << "scales=" << vecToString(inputs[2]->getDims()) << ","; + if (inputs[3] != nullptr) + os << "sizes=" << vecToString(inputs[3]->getDims()) << ","; + os << "axes=" << vecToString(axes) << ","; + os << "coMode=" << enum_to_underlying(coMode) << ","; + os << "nearestMode=" << enum_to_underlying(nearestMode) << ","; + os << "ratioPolicy=" << enum_to_underlying(ratioPolicy) << ","; + + os << "input=" << inputs[0]->getGuid() << ","; + if (inputs[1] != nullptr) + os << inputs[1]->getGuid() << ","; + if (inputs[2] != nullptr) + os << inputs[2]->getGuid() << ","; + if (inputs[3] != nullptr) + os << inputs[3]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +vector ResizeObj::getWorkloadVector() const { + vector ret = inputs[0]->getDims(); + for (size_t i = 0; i < outputs[0]->getDims().size(); ++i) + ret.emplace_back(outputs[0]->getDims()[i]); + // ratioPolicy only effects output shape, so did not need + // here. + ret.emplace_back(enum_to_underlying(coMode)); + ret.emplace_back(enum_to_underlying(nearestMode)); + ret.emplace(ret.begin(), enum_to_underlying(type)); + return ret; +} + +vector ResizeObj::getOpAttrVector() const { + vector ret = axes; + ret.emplace_back(enum_to_underlying(coMode)); + ret.emplace_back(enum_to_underlying(nearestMode)); + ret.emplace_back(enum_to_underlying(ratioPolicy)); + ret.emplace(ret.begin(), enum_to_underlying(type)); + return ret; +} + +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_resize.cc b/test/kernels/cuda/test_cuda_resize.cc new file mode 100644 index 00000000..d0c2aff7 --- /dev/null +++ b/test/kernels/cuda/test_cuda_resize.cc @@ -0,0 +1,370 @@ +#include "cmath" +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/resize.h" +#include "test.h" +namespace infini { +TEST(Resize, Cuda_downsample_sizes_nearest) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32); + auto sizes = gCpu->addTensor({4}, DataType::UInt32); + gCpu->dataMalloc(); + input->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8}); + sizes->copyData(vector{1, 1, 1, 3}); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = gCuda->addOp( + gCuda->cloneTensor(input), nullptr, std::nullopt, + gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput(0)); + EXPECT_TRUE(oCpu->equalData(vector{1, 2, 4})); +} + +TEST(Resize, Cuda_upsample_sizes_nearest_notlarger) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32); + auto sizes = gCpu->addTensor({2}, DataType::UInt32); + gCpu->dataMalloc(); + input->copyData(vector{1, 2, 3, 4}); + sizes->copyData(vector{7, 8}); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = gCuda->addOp( + gCuda->cloneTensor(input), nullptr, vector{2, 3}, + gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::notLarger, + ResizeObj::ENearestMode::roundPreferFloor, + ResizeObj::ECoordinateTransMode::halfPixel); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput(0)); + EXPECT_TRUE(oCpu->equalData( + vector{1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, + 1, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, + 4, 3, 3, 3, 3, 4, 4, 4, 3, 3, 3, 3, 4, 4, 4})); +} + +TEST(Resize, Cuda_upsample_sizes_nearest_notsmaller) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32); + auto sizes = gCpu->addTensor({2}, DataType::UInt32); + gCpu->dataMalloc(); + input->copyData(vector{1, 2, 3, 4}); + sizes->copyData(vector{7, 8}); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = + gCuda->addOp(gCuda->cloneTensor(input), nullptr, + vector{2, 3}, gCuda->cloneTensor(sizes), + ResizeObj::EKeepAspectRatioPolicy::notSmaller, + ResizeObj::ENearestMode::roundPreferFloor, + ResizeObj::ECoordinateTransMode::halfPixel); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput(0)); + EXPECT_TRUE(oCpu->equalData(vector{ + 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, + 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 3, 3, + 4, 4, 4, 4, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 3, 3, 4, 4, 4, 4})); +} + +TEST(Resize, Cuda_upsample_sizes_nearest_ceil_half_pixel) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); + auto sizes = gCpu->addTensor({4}, DataType::UInt32); + gCpu->dataMalloc(); + input->copyData( + vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + sizes->copyData(vector{1, 1, 8, 8}); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = gCuda->addOp( + gCuda->cloneTensor(input), nullptr, std::nullopt, + gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch, + ResizeObj::ENearestMode::ceil, + ResizeObj::ECoordinateTransMode::halfPixel); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + auto o = op->getOutput(0); + // //cudaPrintTensor(o); + auto oCpu = gCpu->cloneTensor(o); + EXPECT_TRUE(oCpu->equalData(vector{ + 1, 2, 2, 3, 3, 4, 4, 4, 5, 6, 6, 7, 7, 8, 8, 8, + 5, 6, 6, 7, 7, 8, 8, 8, 9, 10, 10, 11, 11, 12, 12, 12, + 9, 10, 10, 11, 11, 12, 12, 12, 13, 14, 14, 15, 15, 16, 16, 16, + 13, 14, 14, 15, 15, 16, 16, 16, 13, 14, 14, 15, 15, 16, 16, 16})); +} + +TEST(Resize, Cuda_upsample_sizes_nearest_floor_align_corners) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); + auto sizes = gCpu->addTensor({2}, DataType::UInt32); + gCpu->dataMalloc(); + input->copyData( + vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + sizes->copyData(vector{8, 8}); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = gCuda->addOp( + gCuda->cloneTensor(input), nullptr, vector{3, 2}, + gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch, + ResizeObj::ENearestMode::floor, + ResizeObj::ECoordinateTransMode::alignCorners); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + auto o = op->getOutput(0); + // cudaPrintTensor(o); + auto oCpu = gCpu->cloneTensor(o); + EXPECT_TRUE(oCpu->equalData(vector{ + 1, 1, 1, 2, 2, 3, 3, 4, 1, 1, 1, 2, 2, 3, 3, 4, + 1, 1, 1, 2, 2, 3, 3, 4, 5, 5, 5, 6, 6, 7, 7, 8, + 5, 5, 5, 6, 6, 7, 7, 8, 9, 9, 9, 10, 10, 11, 11, 12, + 9, 9, 9, 10, 10, 11, 11, 12, 13, 13, 13, 14, 14, 15, 15, 16})); +} + +TEST(Resize, Cuda_upsample_sizes_nearest_round_prefer_ceil_asymmetri) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); + auto sizes = gCpu->addTensor({4}, DataType::UInt32); + gCpu->dataMalloc(); + input->copyData( + vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + sizes->copyData(vector{1, 1, 8, 8}); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = gCuda->addOp( + gCuda->cloneTensor(input), nullptr, std::nullopt, + gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch, + ResizeObj::ENearestMode::roundPreferCeil, + ResizeObj::ECoordinateTransMode::asymmetric); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + auto o = op->getOutput(0); + // cudaPrintTensor(o); + auto oCpu = gCpu->cloneTensor(o); + EXPECT_TRUE(oCpu->equalData(vector{ + 1, 2, 2, 3, 3, 4, 4, 4, 5, 6, 6, 7, 7, 8, 8, 8, + 5, 6, 6, 7, 7, 8, 8, 8, 9, 10, 10, 11, 11, 12, 12, 12, + 9, 10, 10, 11, 11, 12, 12, 12, 13, 14, 14, 15, 15, 16, 16, 16, + 13, 14, 14, 15, 15, 16, 16, 16, 13, 14, 14, 15, 15, 16, 16, 16})); +} + +TEST(Resize, Cuda_downsample_scales_nearest) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32); + auto scales = gCpu->addTensor({4}, DataType::Float32); + gCpu->dataMalloc(); + input->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8}); + scales->copyData(vector{1, 1, 0.6, 0.6}); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = gCuda->addOp(gCuda->cloneTensor(input), nullptr, + std::nullopt, gCuda->cloneTensor(scales)); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput(0)); + EXPECT_TRUE(oCpu->equalData(vector{1, 3})); +} + +TEST(Resize, Cuda_upsample_scales_nearest) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32); + auto scales = gCpu->addTensor({4}, DataType::Float32); + gCpu->dataMalloc(); + input->copyData(vector{1, 2, 3, 4}); + scales->copyData(vector{1, 1, 2, 3}); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = gCuda->addOp(gCuda->cloneTensor(input), nullptr, + std::nullopt, gCuda->cloneTensor(scales)); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput(0)); + EXPECT_TRUE( + oCpu->equalData(vector{1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, + 3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4})); +} + +TEST(Resize, Cuda_upsample_scales_nearest_axes_3_2) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32); + auto scales = gCpu->addTensor({2}, DataType::Float32); + gCpu->dataMalloc(); + input->copyData(vector{1, 2, 3, 4}); + scales->copyData(vector{3, 2}); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = + gCuda->addOp(gCuda->cloneTensor(input), nullptr, + vector{3, 2}, gCuda->cloneTensor(scales)); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput(0)); + EXPECT_TRUE( + oCpu->equalData(vector{1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, + 3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4})); +} + +TEST(Resize, Cuda_downsample_scales_linear) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32); + auto scales = gCpu->addTensor({4}, DataType::Float32); + gCpu->dataMalloc(); + input->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8}); + scales->copyData(vector{1, 1, 0.6, 0.6}); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = gCuda->addOp(gCuda->cloneTensor(input), nullptr, + std::nullopt, gCuda->cloneTensor(scales), + ResizeObj::ECoeffMode::linear); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput(0)); + EXPECT_TRUE(oCpu->equalData(vector{2.6666665, 4.3333331})); +} + +TEST(Resize, Cuda_upsample_scales_linear) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32); + auto scales = gCpu->addTensor({4}, DataType::Float32); + gCpu->dataMalloc(); + input->copyData(vector{1, 2, 3, 4}); + scales->copyData(vector{1, 1, 2, 2}); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = gCuda->addOp(gCuda->cloneTensor(input), nullptr, + std::nullopt, gCuda->cloneTensor(scales), + ResizeObj::ECoeffMode::linear); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput(0)); + EXPECT_TRUE( + oCpu->equalData(vector{1, 1.25, 1.75, 2, 1.5, 1.75, 2.25, 2.5, + 2.5, 2.75, 3.25, 3.5, 3, 3.25, 3.75, 4})); +} + +TEST(Resize, Cuda_upsample_scales_linear_align_corners) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32); + auto scales = gCpu->addTensor({4}, DataType::Float32); + gCpu->dataMalloc(); + input->copyData(vector{1, 2, 3, 4}); + scales->copyData(vector{1, 1, 2, 2}); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = gCuda->addOp( + gCuda->cloneTensor(input), nullptr, std::nullopt, + gCuda->cloneTensor(scales), ResizeObj::ECoeffMode::linear, + ResizeObj::ECoordinateTransMode::alignCorners); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + cudaPrintTensor(op->getOutput(0)); + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput(0)); + EXPECT_TRUE(oCpu->equalData(vector{ + 1, 1.333333, 1.666667, 2, 1.666667, 2, 2.333333, 2.666667, 2.333333, + 2.6666667, 3, 3.333333, 3, 3.333333, 3.6666667, 4})); +} + +TEST(Resize, Cuda_downsample_sizes_linear_pytorchhalfpixel) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32); + auto sizes = gCpu->addTensor({4}, DataType::UInt32); + gCpu->dataMalloc(); + input->copyData( + vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + sizes->copyData(vector{1, 1, 3, 1}); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto op = gCuda->addOp( + gCuda->cloneTensor(input), nullptr, std::nullopt, + gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch, + ResizeObj::ECoeffMode::linear, + ResizeObj::ECoordinateTransMode::pytorchHalfPixel); + gCuda->dataMalloc(); + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput(0)); + // cudaPrintTensor(op->getOutput(0)); + EXPECT_TRUE(oCpu->equalData(vector{1.666667, 7, 12.33333})); +} + +} // namespace infini diff --git a/test/operators/test_resize.cc b/test/operators/test_resize.cc new file mode 100644 index 00000000..94e28cb6 --- /dev/null +++ b/test/operators/test_resize.cc @@ -0,0 +1,79 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/resize.h" +#include "test.h" + +namespace infini { +TEST(Resize, ShapeInference) { + Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + // downsample_sizes_nearest no axes + { + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({1, 1, 2, 4}, DataType::UInt32); + Tensor sizes = g->addTensor({4}, DataType::UInt32); + sizes->dataMalloc(); + sizes->copyData(vector{1, 1, 1, 3}); + auto op = + g->addOp(i, nullptr, std::nullopt, sizes, + ResizeObj::EKeepAspectRatioPolicy::stretch); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 1, 3})); + } + // upsample_sizes_nearest with axes + { + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({1, 1, 2, 4}, DataType::UInt32); + Tensor sizes = g->addTensor({2}, DataType::UInt32); + sizes->dataMalloc(); + sizes->copyData(vector{1, 3}); + auto op = + g->addOp(i, nullptr, vector{2, 3}, sizes, + ResizeObj::EKeepAspectRatioPolicy::stretch); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 1, 3})); + } + // upsample_sizes_nearest_notlarger + { + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({1, 3, 2, 4}, DataType::UInt32); + Tensor sizes = g->addTensor({2}, DataType::UInt32); + sizes->dataMalloc(); + sizes->copyData(vector{7, 8}); + auto op = + g->addOp(i, nullptr, vector{2, 3}, sizes, + ResizeObj::EKeepAspectRatioPolicy::notLarger); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 4, 8})); + } + // upsample_sizes_nearest_notsmaller + { + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({1, 3, 2, 4}, DataType::UInt32); + Tensor sizes = g->addTensor({3}, DataType::UInt32); + sizes->dataMalloc(); + sizes->copyData(vector{2, 6, 8}); + auto op = + g->addOp(i, nullptr, vector{1, 2, 3}, sizes, + ResizeObj::EKeepAspectRatioPolicy::notSmaller); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 9, 6, 12})); + } + // downsample_scales + { + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({1, 1, 4, 4}, DataType::UInt32); + Tensor scales = g->addTensor({3}, DataType::Float32); + scales->dataMalloc(); + scales->copyData(vector{1, 0.8, 0.8}); + auto op = g->addOp(i, nullptr, vector{1, 2, 3}, scales); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 3, 3})); + } + // upsample_scales + { + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({1, 1, 2, 2}, DataType::UInt32); + Tensor scales = g->addTensor({4}, DataType::Float32); + scales->dataMalloc(); + scales->copyData(vector{1, 1, 2, 2}); + auto op = g->addOp(i, nullptr, std::nullopt, scales); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 4, 4})); + } +} + +} // namespace infini