forked from jiuyuan/InfiniTensor
ADD: reconfig ResizeObj, support "tf_crop_and_resize " and cubic coeff kernel. (#59)
add cubic coef add tf_crop_and_resize
This commit is contained in:
parent
c5966f8d81
commit
d780f687fc
|
@ -7,6 +7,8 @@ typedef struct {
|
|||
int inDims[4];
|
||||
int inStride[4];
|
||||
float scale[4];
|
||||
float roiS[4];
|
||||
float roiE[4];
|
||||
} MetaData;
|
||||
|
||||
namespace infini {
|
||||
|
@ -14,4 +16,6 @@ void resize_kernel_nearest(float *in, float *out, const MetaData &metaData,
|
|||
size_t num, int coordinateMode, int nearestMode);
|
||||
void resize_kernel_linear(float *in, float *out, const MetaData &metaData,
|
||||
size_t num, int coordinateMode);
|
||||
void resize_kernel_cubic(float *in, float *out, const MetaData &metaData,
|
||||
size_t num, int coordinateMode);
|
||||
} // namespace infini
|
||||
|
|
|
@ -12,13 +12,21 @@ class ResizeObj : public OperatorObj {
|
|||
asymmetric,
|
||||
tfCropAndResize
|
||||
};
|
||||
enum class ENearestMode { roundPreferFloor, roundPreferCeil, floor, ceil };
|
||||
enum class EKeepAspectRatioPolicy { stretch, notLarger, notSmaller };
|
||||
enum class ENearestMode {
|
||||
roundPreferFloor,
|
||||
roundPreferCeil,
|
||||
floor,
|
||||
ceil,
|
||||
none
|
||||
};
|
||||
enum class EKeepAspectRatioPolicy { stretch, notLarger, notSmaller, none };
|
||||
enum class ECoeffMode { nearest, linear, cubic };
|
||||
|
||||
private:
|
||||
vector<int> axes;
|
||||
vector<float> scales;
|
||||
vector<float> roi;
|
||||
|
||||
ECoordinateTransMode coMode; // compute src coordinate from dst coordinate
|
||||
ECoeffMode mode; // coeff mode,for computing dst value from coordinate src
|
||||
// neighborhood .
|
||||
|
@ -28,34 +36,27 @@ class ResizeObj : public OperatorObj {
|
|||
ratioPolicy; // used for computing shape when using "sizes"
|
||||
|
||||
public:
|
||||
// nearest mode, not tf_crop_and_resize
|
||||
// nearest mode
|
||||
ResizeObj(
|
||||
GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor sizes,
|
||||
EKeepAspectRatioPolicy ratioPolicy,
|
||||
ENearestMode nearestMode = ENearestMode::roundPreferFloor,
|
||||
ECoordinateTransMode coordTransMode = ECoordinateTransMode::halfPixel);
|
||||
ResizeObj(
|
||||
GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor scales,
|
||||
const std::optional<vector<int>> &axes, Tensor sizes, Tensor scales,
|
||||
Tensor roi,
|
||||
EKeepAspectRatioPolicy ratioPolicy = EKeepAspectRatioPolicy::none,
|
||||
ENearestMode nearestMode = ENearestMode::roundPreferFloor,
|
||||
ECoordinateTransMode coordTransMode = ECoordinateTransMode::halfPixel);
|
||||
|
||||
// linear mode
|
||||
ResizeObj(
|
||||
GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor sizes,
|
||||
EKeepAspectRatioPolicy ratioPolicy, ECoeffMode mode,
|
||||
ECoordinateTransMode coordTransMode = ECoordinateTransMode::halfPixel);
|
||||
ResizeObj(
|
||||
GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor scales, ECoeffMode mode,
|
||||
const std::optional<vector<int>> &axes, Tensor sizes, Tensor scales,
|
||||
Tensor roi, ECoeffMode mode,
|
||||
EKeepAspectRatioPolicy ratioPolicy = EKeepAspectRatioPolicy::none,
|
||||
ECoordinateTransMode coordTransMode = ECoordinateTransMode::halfPixel);
|
||||
|
||||
// Operator clone(TensorVec inputs, TensorVec outputs) override;
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 4; }
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
ECoeffMode getMode() const { return mode; }
|
||||
|
@ -66,7 +67,17 @@ class ResizeObj : public OperatorObj {
|
|||
int getCoordinateTransMode() const { return enum_to_underlying(coMode); }
|
||||
float getScale(int i) const {
|
||||
IT_ASSERT((size_t)i < scales.size());
|
||||
return scales[i];
|
||||
return scales.at(i);
|
||||
}
|
||||
float getRoi(int i) const {
|
||||
if (coMode == ECoordinateTransMode::tfCropAndResize) {
|
||||
IT_ASSERT(size_t(i) < roi.size());
|
||||
return roi.at(i);
|
||||
} else
|
||||
return 0;
|
||||
}
|
||||
bool isResizeBySizes() const {
|
||||
return ratioPolicy != EKeepAspectRatioPolicy::none;
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -74,7 +85,8 @@ class ResizeObj : public OperatorObj {
|
|||
vector<int> getOpAttrVector() const override;
|
||||
|
||||
float round_int(float x) const;
|
||||
bool checkCoordinateTransValid(int resizedCo, int origiCo) const;
|
||||
void init(const Tensor &input, const Tensor &sizes, const Tensor &scales,
|
||||
const Tensor &roi, const std::optional<vector<int>> &axes);
|
||||
void InitBySizes(Tensor input, Tensor sizes,
|
||||
const std::optional<vector<int>> &axes);
|
||||
void InitByScales(Tensor input, Tensor sizes,
|
||||
|
|
|
@ -21,6 +21,8 @@ class ResizeCuda : public CudaKernelWithoutConfig {
|
|||
metaData.oDims[i] = out->getDims()[i];
|
||||
metaData.inStride[i] = in->getStride()[i];
|
||||
metaData.scale[i] = op->getScale(i);
|
||||
metaData.roiS[i] = op->getRoi(i);
|
||||
metaData.roiE[i] = op->getRoi(i + nDims);
|
||||
}
|
||||
|
||||
switch (op->getMode()) {
|
||||
|
@ -35,6 +37,11 @@ class ResizeCuda : public CudaKernelWithoutConfig {
|
|||
out->getRawDataPtr<float *>(), metaData,
|
||||
out->size(), op->getCoordinateTransMode());
|
||||
break;
|
||||
case ResizeObj::ECoeffMode::cubic:
|
||||
resize_kernel_cubic(in->getRawDataPtr<float *>(),
|
||||
out->getRawDataPtr<float *>(), metaData,
|
||||
out->size(), op->getCoordinateTransMode());
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
|
|
@ -3,10 +3,6 @@
|
|||
#include "cuda/resize.cuh"
|
||||
#include <functional>
|
||||
|
||||
#ifndef GPU_LAMBDA
|
||||
#define GPU_LAMBDA __device__
|
||||
#endif
|
||||
|
||||
// nearest mode
|
||||
__device__ int round_prefer_ceil(float x) {
|
||||
return (x > 0.0) ? floor(x + 0.5) : ceil(x - 0.5);
|
||||
|
@ -21,32 +17,36 @@ __device__ int prefer_floor(float x) { return std::floor(x); }
|
|||
__device__ int prefer_ceil(float x) { return std::ceil(x); }
|
||||
|
||||
// coordinate transform mode
|
||||
__device__ float half_pixel(int idx, float scale, int, int) {
|
||||
return (idx + 0.5) / scale - 0.5;
|
||||
__device__ float half_pixel(int idx, MetaData metaData, int dim) {
|
||||
return (idx + 0.5) / metaData.scale[dim] - 0.5;
|
||||
}
|
||||
|
||||
__device__ float pytorch_half_pixel(int idx, float scale, int length_resized,
|
||||
int) {
|
||||
return length_resized > 1 ? (idx + 0.5) / scale - 0.5 : 0;
|
||||
__device__ float pytorch_half_pixel(int idx, MetaData metaData, int dim) {
|
||||
float resizedLen = metaData.scale[dim] * metaData.inDims[dim];
|
||||
return resizedLen > 1 ? (idx + 0.5) / metaData.scale[dim] - 0.5 : 0;
|
||||
}
|
||||
|
||||
__device__ float align_corners(int idx, float scale, int length_resized,
|
||||
int length_original) {
|
||||
if (length_resized == 1)
|
||||
__device__ float align_corners(int idx, MetaData metaData, int dim) {
|
||||
float resizedLen = metaData.scale[dim] * metaData.inDims[dim];
|
||||
if (resizedLen == 1)
|
||||
return 0;
|
||||
return (float)idx * (float)(length_original - 1) /
|
||||
(float)(length_resized - 1);
|
||||
return (float)idx * (float)(metaData.inDims[dim] - 1) /
|
||||
(float)(resizedLen - 1);
|
||||
}
|
||||
|
||||
__device__ float asymmetric(int idx, float scale, int length_resized,
|
||||
int length_original) {
|
||||
return idx / scale;
|
||||
__device__ float asymmetric(int idx, MetaData metaData, int dim) {
|
||||
return idx / metaData.scale[dim];
|
||||
}
|
||||
/*
|
||||
__device__ float tf_crop_and_resize(int idx, float scale, int length_resized,
|
||||
int length_original) {
|
||||
|
||||
}*/
|
||||
__device__ float tf_crop_and_resize(int idx, MetaData metaData, int dim) {
|
||||
int resizedLen = metaData.scale[dim] * metaData.inDims[dim];
|
||||
return resizedLen > 1
|
||||
? metaData.roiS[dim] * (metaData.inDims[dim] - 1) +
|
||||
idx * (metaData.roiE[dim] - metaData.roiS[dim]) *
|
||||
(metaData.inDims[dim] - 1) / (resizedLen - 1)
|
||||
: 0.5 * (metaData.roiS[dim] + metaData.roiE[dim]) *
|
||||
(metaData.inDims[dim] - 1);
|
||||
}
|
||||
|
||||
// ATTENTION:The order of device functions in array must be consistent with the
|
||||
// order in the enums of ResizeObj.
|
||||
|
@ -54,10 +54,11 @@ using nearest_mod_func_t = int (*)(float);
|
|||
__device__ nearest_mod_func_t p_nearest_mode_fun[] = {
|
||||
round_prefer_floor, round_prefer_ceil, prefer_floor, prefer_ceil};
|
||||
|
||||
using coordinate_trans_mod_func_t = float (*)(int idxO, float scale, int lenO,
|
||||
int lenR);
|
||||
using coordinate_trans_mod_func_t = float (*)(int idxO, MetaData metaData,
|
||||
int dim);
|
||||
__device__ coordinate_trans_mod_func_t p_cooridnate_trans_mode_func[] = {
|
||||
half_pixel, pytorch_half_pixel, align_corners, asymmetric};
|
||||
half_pixel, pytorch_half_pixel, align_corners, asymmetric,
|
||||
tf_crop_and_resize};
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__device__ int nearestCoordinateTrans(int dOffset, MetaData metaData,
|
||||
|
@ -70,9 +71,8 @@ __device__ int nearestCoordinateTrans(int dOffset, MetaData metaData,
|
|||
if (metaData.inDims[i] == metaData.oDims[i])
|
||||
sOffset += dIdx * metaData.inStride[i];
|
||||
else {
|
||||
float scale = (float)metaData.oDims[i] / (float)metaData.inDims[i];
|
||||
int sIdx = nearestModeFun(transModeFun(
|
||||
dIdx, scale, metaData.oDims[i], metaData.inDims[i]));
|
||||
int sIdx = nearestModeFun(transModeFun(dIdx, metaData, i));
|
||||
|
||||
if (sIdx > metaData.inDims[i] - 1)
|
||||
sIdx = metaData.inDims[i] - 1;
|
||||
else if (sIdx < 0)
|
||||
|
@ -98,12 +98,6 @@ __global__ void _resize_kernel_nearest(float *in, float *out, MetaData metaData,
|
|||
}
|
||||
}
|
||||
|
||||
// ATTENTION: Make sure dim <=4
|
||||
typedef struct {
|
||||
int offset[16];
|
||||
float power[16];
|
||||
} NeighborList;
|
||||
|
||||
int __device__ getLimitIdx(int idx, int limit) {
|
||||
if (idx < 0)
|
||||
return 0;
|
||||
|
@ -112,77 +106,104 @@ int __device__ getLimitIdx(int idx, int limit) {
|
|||
return idx;
|
||||
}
|
||||
|
||||
__global__ void _resize_kernel_linear(float *in, float *out, MetaData metaData,
|
||||
size_t num, int coordinateMode) {
|
||||
template <int N>
|
||||
__device__ void getEvenNeighbors(float idx, int limit, int *neighbors) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
neighbors[i] = getLimitIdx(std::floor(idx) - N / 2 + 1 + i, limit);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void getLinearCoef(float ratio, float *coeffs) {
|
||||
coeffs[0] = 1 - ratio;
|
||||
coeffs[1] = ratio;
|
||||
}
|
||||
|
||||
__device__ void getCubicCoef(float ratio, float *coeffs) {
|
||||
float A = -0.75;
|
||||
coeffs[0] =
|
||||
((A * (ratio + 1) - 5 * A) * (ratio + 1) + 8 * A) * (ratio + 1) - 4 * A;
|
||||
coeffs[1] = ((A + 2) * ratio - (A + 3)) * ratio * ratio + 1;
|
||||
coeffs[2] =
|
||||
((A + 2) * (1 - ratio) - (A + 3)) * (1 - ratio) * (1 - ratio) + 1;
|
||||
coeffs[3] = ((A * ((1 - ratio) + 1) - 5 * A) * ((1 - ratio) + 1) + 8 * A) *
|
||||
((1 - ratio) + 1) -
|
||||
4 * A;
|
||||
}
|
||||
|
||||
using get_coef_func_t = void (*)(float, float *);
|
||||
|
||||
// N is neighbor number at each dim
|
||||
template <int N, int totalNeighborNum>
|
||||
__device__ void _resize_kernel_coeff(float *in, float *out, MetaData metaData,
|
||||
size_t num,
|
||||
coordinate_trans_mod_func_t coTransFunc,
|
||||
get_coef_func_t getCoefFunc) {
|
||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
auto stride = blockDim.x * gridDim.x;
|
||||
|
||||
while (tid < num) {
|
||||
auto dOffset = tid;
|
||||
auto neighborNum = 0;
|
||||
NeighborList neighborList;
|
||||
memset(&neighborList, 0, sizeof(neighborList));
|
||||
auto neighborCnt = 1;
|
||||
int offsetList[totalNeighborNum], offsetListOld[totalNeighborNum];
|
||||
float powerList[totalNeighborNum], powerListOld[totalNeighborNum];
|
||||
|
||||
for (size_t i = 0; i < totalNeighborNum; ++i) {
|
||||
offsetList[i] = 0;
|
||||
powerList[i] = 1;
|
||||
}
|
||||
for (int i = metaData.nDims - 1; i >= 0; --i) {
|
||||
int dIdx = dOffset % metaData.oDims[i];
|
||||
float scale = metaData.scale[i];
|
||||
float sIdx = p_cooridnate_trans_mode_func[coordinateMode](
|
||||
dIdx, scale, scale * metaData.inDims[i], metaData.inDims[i]);
|
||||
float sIdx = coTransFunc(dIdx, metaData, i);
|
||||
|
||||
int idx = std::floor(sIdx);
|
||||
float power = 1 - (sIdx - idx);
|
||||
float power[N];
|
||||
int neighbors[N];
|
||||
getCoefFunc(sIdx - idx, power);
|
||||
getEvenNeighbors<N>(sIdx, metaData.inDims[i] - 1, neighbors);
|
||||
|
||||
// update neighborList
|
||||
if (metaData.inDims[i] == 1) {
|
||||
if (neighborNum == 0) {
|
||||
neighborList.offset[0] = 0;
|
||||
neighborList.power[0] = power;
|
||||
neighborNum = 1;
|
||||
} else {
|
||||
for (int j = 0; j < neighborNum; j++) {
|
||||
neighborList.power[j] *= power;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (neighborNum == 0) {
|
||||
neighborList.offset[0] =
|
||||
getLimitIdx(idx, metaData.inDims[i] - 1) *
|
||||
metaData.inStride[i];
|
||||
neighborList.power[0] = power;
|
||||
neighborList.offset[1] =
|
||||
getLimitIdx(idx + 1, metaData.inDims[i] - 1) *
|
||||
metaData.inStride[i];
|
||||
neighborList.power[1] = 1 - power;
|
||||
neighborNum = 2;
|
||||
} else {
|
||||
for (int j = 0; j < neighborNum; j++) {
|
||||
neighborList.offset[j + neighborNum] =
|
||||
neighborList.offset[j] +
|
||||
getLimitIdx(idx + 1, metaData.inDims[i] - 1) *
|
||||
metaData.inStride[i];
|
||||
neighborList.power[j + neighborNum] =
|
||||
(neighborList.power[j]) * (1 - power);
|
||||
for (int n = 0; n < neighborCnt; ++n) {
|
||||
offsetListOld[n] = offsetList[n];
|
||||
powerListOld[n] = powerList[n];
|
||||
}
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int idx = 0; idx < neighborCnt; ++idx) {
|
||||
offsetList[idx + n * neighborCnt] =
|
||||
offsetListOld[idx] +
|
||||
neighbors[n] * metaData.inStride[i];
|
||||
|
||||
neighborList.offset[j] +=
|
||||
getLimitIdx(idx, metaData.inDims[i] - 1) *
|
||||
metaData.inStride[i];
|
||||
neighborList.power[j] *= power;
|
||||
}
|
||||
neighborNum *= 2;
|
||||
powerList[idx + n * neighborCnt] =
|
||||
powerListOld[idx] * power[n];
|
||||
}
|
||||
}
|
||||
|
||||
neighborCnt = neighborCnt * N;
|
||||
dOffset = dOffset / metaData.oDims[i];
|
||||
}
|
||||
|
||||
float val = 0;
|
||||
for (int i = 0; i < neighborNum; ++i) {
|
||||
val += in[neighborList.offset[i]] * neighborList.power[i];
|
||||
for (int i = 0; i < neighborCnt; ++i) {
|
||||
val += in[offsetList[i]] * powerList[i];
|
||||
}
|
||||
out[tid] = val;
|
||||
tid += stride;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void _resize_kernel_linear_coeff(float *in, float *out,
|
||||
MetaData metaData, size_t num,
|
||||
int coordinateMode) {
|
||||
_resize_kernel_coeff<2, 16>(in, out, metaData, num,
|
||||
p_cooridnate_trans_mode_func[coordinateMode],
|
||||
getLinearCoef);
|
||||
}
|
||||
|
||||
__global__ void _resize_kernel_cubic_coeff(float *in, float *out,
|
||||
MetaData metaData, size_t num,
|
||||
int coordinateMode) {
|
||||
_resize_kernel_coeff<4, 256>(in, out, metaData, num,
|
||||
p_cooridnate_trans_mode_func[coordinateMode],
|
||||
getCubicCoef);
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void resize_kernel_nearest(float *in, float *out, const MetaData &metaData,
|
||||
size_t num, int coordinateMode, int nearestMode) {
|
||||
|
@ -202,7 +223,17 @@ void resize_kernel_linear(float *in, float *out, const MetaData &metaData,
|
|||
auto gridsize = (num + blocksize - 1) / blocksize;
|
||||
IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
|
||||
sizeof(p_cooridnate_trans_mode_func[0]));
|
||||
_resize_kernel_linear<<<blocksize, gridsize>>>(in, out, metaData, num,
|
||||
coordinateMode);
|
||||
_resize_kernel_linear_coeff<<<blocksize, gridsize>>>(in, out, metaData, num,
|
||||
coordinateMode);
|
||||
}
|
||||
|
||||
void resize_kernel_cubic(float *in, float *out, const MetaData &metaData,
|
||||
size_t num, int coordinateMode) {
|
||||
int blocksize = 32 * 16;
|
||||
auto gridsize = (num + blocksize - 1) / blocksize;
|
||||
IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
|
||||
sizeof(p_cooridnate_trans_mode_func[0]));
|
||||
_resize_kernel_cubic_coeff<<<blocksize, gridsize>>>(in, out, metaData, num,
|
||||
coordinateMode);
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -3,55 +3,92 @@
|
|||
namespace infini {
|
||||
ResizeObj::ResizeObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor sizes,
|
||||
Tensor scales, Tensor roi,
|
||||
EKeepAspectRatioPolicy ratioPolicy,
|
||||
ENearestMode nearestMode,
|
||||
ECoordinateTransMode coordTransMode)
|
||||
: OperatorObj(OpType::Resize, {input, nullptr, nullptr, sizes}, {output}),
|
||||
coMode(coordTransMode), mode(ECoeffMode::nearest),
|
||||
nearestMode(nearestMode), ratioPolicy(ratioPolicy) {
|
||||
if (coordTransMode == ECoordinateTransMode::tfCropAndResize)
|
||||
IT_TODO_HALT();
|
||||
InitBySizes(input, sizes, axes);
|
||||
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
ResizeObj::ResizeObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor scales,
|
||||
ENearestMode nearestMode,
|
||||
ECoordinateTransMode coordTransMode)
|
||||
: OperatorObj(OpType::Resize, {input, nullptr, scales, nullptr}, {output}),
|
||||
coMode(coordTransMode), mode(ECoeffMode::nearest),
|
||||
nearestMode(nearestMode) {
|
||||
InitByScales(input, scales, axes);
|
||||
|
||||
: OperatorObj(OpType::Resize, {input}, {output}), coMode(coordTransMode),
|
||||
mode(ECoeffMode::nearest), nearestMode(nearestMode),
|
||||
ratioPolicy(ratioPolicy) {
|
||||
init(input, sizes, scales, roi, axes);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
ResizeObj::ResizeObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor sizes,
|
||||
EKeepAspectRatioPolicy ratioPolicy, ECoeffMode mode,
|
||||
Tensor scales, Tensor roi, ECoeffMode mode,
|
||||
EKeepAspectRatioPolicy ratioPolicy,
|
||||
ECoordinateTransMode coordTransMode)
|
||||
: OperatorObj(OpType::Resize, {input, nullptr, nullptr, sizes}, {output}),
|
||||
coMode(coordTransMode), mode(mode), ratioPolicy(ratioPolicy) {
|
||||
if (coordTransMode == ECoordinateTransMode::tfCropAndResize)
|
||||
IT_TODO_HALT();
|
||||
InitBySizes(input, sizes, axes);
|
||||
|
||||
: OperatorObj(OpType::Resize, {input}, {output}), coMode(coordTransMode),
|
||||
mode(mode), nearestMode(ENearestMode::none), ratioPolicy(ratioPolicy) {
|
||||
init(input, sizes, scales, roi, axes);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
ResizeObj::ResizeObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor scales,
|
||||
ECoeffMode mode, ECoordinateTransMode coordTransMode)
|
||||
: OperatorObj(OpType::Resize, {input, nullptr, scales, nullptr}, {output}),
|
||||
coMode(coordTransMode), mode(mode) {
|
||||
if (coordTransMode == ECoordinateTransMode::tfCropAndResize)
|
||||
IT_TODO_HALT();
|
||||
InitByScales(input, scales, axes);
|
||||
void ResizeObj::init(const Tensor &input, const Tensor &sizes,
|
||||
const Tensor &scales, const Tensor &roi,
|
||||
const std::optional<vector<int>> &axes) {
|
||||
IT_ASSERT(!(nullptr != sizes && nullptr != scales));
|
||||
|
||||
IT_ASSERT(checkValid(graph));
|
||||
// inputs of operator must not be nullptr, due to the check in
|
||||
// OperatorObj::OperatorObj
|
||||
if (nullptr != sizes) {
|
||||
IT_ASSERT(isResizeBySizes());
|
||||
inputs.push_back(sizes);
|
||||
InitBySizes(input, sizes, axes);
|
||||
} else if (nullptr != scales) {
|
||||
inputs.push_back(scales);
|
||||
InitByScales(input, scales, axes);
|
||||
}
|
||||
|
||||
// roi
|
||||
if (ECoordinateTransMode::tfCropAndResize == coMode) {
|
||||
IT_ASSERT(nullptr != roi);
|
||||
inputs.push_back(roi);
|
||||
IT_ASSERT(roi->getDims().size() == 1);
|
||||
IT_ASSERT((size_t)roi->getDims()[0] == this->axes.size() * 2);
|
||||
|
||||
// init roi_start = 0;roi_end =1
|
||||
size_t nDims = input->getDims().size();
|
||||
for (size_t i = 0; i < nDims; ++i) {
|
||||
this->roi.emplace_back(0);
|
||||
}
|
||||
for (size_t i = 0; i < nDims; ++i) {
|
||||
this->roi.emplace_back(1);
|
||||
}
|
||||
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
std::shared_ptr<float> dataObj((float *)runtime->alloc(roi->getBytes()),
|
||||
[&](float *p) { runtime->dealloc(p); });
|
||||
auto data = dataObj.get();
|
||||
roi->getRuntime()->copyBlobToCPU(
|
||||
(void *)data, roi->getRawDataPtr<void *>(), roi->getBytes());
|
||||
|
||||
for (size_t i = 0; i < this->axes.size(); ++i) {
|
||||
this->roi[this->axes[i]] = data[i];
|
||||
this->roi[this->axes[i] + nDims] = data[i + this->axes.size()];
|
||||
}
|
||||
}
|
||||
}
|
||||
/*
|
||||
Operator ResizeObj::clone(TensorVec inputs, TensorVec outputs) {
|
||||
Tensor roi{nullptr}, sizes{nullptr}, scales{nullptr};
|
||||
if (inputs.size() == 3)
|
||||
roi = inputs[2];
|
||||
if (isResizeBySizes())
|
||||
sizes = inputs[1];
|
||||
else
|
||||
scales = inputs[1];
|
||||
|
||||
if (mode == ECoeffMode::nearest)
|
||||
return make_ref<ResizeObj>(nullptr, inputs[0], outputs[0], axes,
|
||||
inputs[1], nullptr, roi, ratioPolicy,
|
||||
nearestMode, coMode);
|
||||
else
|
||||
return make_ref<ResizeObj>(nullptr, inputs[0], outputs[0], axes,
|
||||
inputs[1], nullptr, roi, mode, ratioPolicy,
|
||||
coMode);
|
||||
}*/
|
||||
|
||||
void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
|
||||
const std::optional<vector<int>> &axes) {
|
||||
|
@ -81,7 +118,9 @@ void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
|
|||
// copy sizes data to host.
|
||||
IT_ASSERT(sizes->getDataBlob() != nullptr);
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
int *data = (int *)runtime->alloc(sizes->getBytes());
|
||||
std::shared_ptr<int> dataObj((int *)runtime->alloc(sizes->getBytes()),
|
||||
[&](int *p) { runtime->dealloc(p); });
|
||||
auto data = dataObj.get();
|
||||
sizes->getRuntime()->copyBlobToCPU(
|
||||
(void *)data, sizes->getRawDataPtr<void *>(), sizes->getBytes());
|
||||
|
||||
|
@ -116,8 +155,6 @@ void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
|
|||
default:
|
||||
IT_ASSERT(0);
|
||||
}
|
||||
|
||||
runtime->dealloc(data);
|
||||
}
|
||||
|
||||
void ResizeObj::InitByScales(Tensor input, Tensor scales,
|
||||
|
@ -130,7 +167,9 @@ void ResizeObj::InitByScales(Tensor input, Tensor scales,
|
|||
// copy scales data to host.
|
||||
IT_ASSERT(scales->getDataBlob() != nullptr);
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
float *data = (float *)runtime->alloc(scales->getBytes());
|
||||
std::shared_ptr<float> dataObj((float *)runtime->alloc(scales->getBytes()),
|
||||
[&](float *p) { runtime->dealloc(p); });
|
||||
auto data = dataObj.get();
|
||||
scales->getRuntime()->copyBlobToCPU(
|
||||
(void *)data, scales->getRawDataPtr<void *>(), scales->getBytes());
|
||||
|
||||
|
@ -156,26 +195,22 @@ void ResizeObj::InitByScales(Tensor input, Tensor scales,
|
|||
IT_ASSERT(data[i] > 0);
|
||||
this->scales[val] = data[i];
|
||||
}
|
||||
|
||||
runtime->dealloc(data);
|
||||
}
|
||||
|
||||
vector<DataType> ResizeObj::inferDataType(const TensorVec &inputs) const {
|
||||
IT_ASSERT(inputs.size() == 4);
|
||||
auto roi = inputs[1];
|
||||
auto scales = inputs[2];
|
||||
auto sizes = inputs[3];
|
||||
IT_ASSERT(roi == nullptr || roi->getDType() == DataType::Float32);
|
||||
IT_ASSERT(scales == nullptr || scales->getDType() == DataType::Float32);
|
||||
IT_ASSERT(sizes == nullptr || sizes->getDType() == DataType::UInt32);
|
||||
return {inputs[0]->getDType()};
|
||||
}
|
||||
|
||||
bool ResizeObj::checkCoordinateTransValid(int resizedX, int origiX) const {
|
||||
if (ECoordinateTransMode::alignCorners == coMode) {
|
||||
return (!(resizedX <= 1 && origiX != resizedX));
|
||||
IT_ASSERT(inputs.size() == 2 || inputs.size() == 3);
|
||||
if (inputs.size() == 3) {
|
||||
auto roi = inputs[2];
|
||||
IT_ASSERT(roi && roi->getDType() == DataType::Float32);
|
||||
}
|
||||
return true;
|
||||
if (isResizeBySizes()) {
|
||||
auto sizes = inputs[1];
|
||||
IT_ASSERT(sizes && sizes->getDType() == DataType::UInt32);
|
||||
} else {
|
||||
auto scales = inputs[1];
|
||||
IT_ASSERT(scales && scales->getDType() == DataType::Float32);
|
||||
}
|
||||
return {inputs[0]->getDType()};
|
||||
}
|
||||
|
||||
float ResizeObj::round_int(float x) const {
|
||||
|
@ -189,7 +224,6 @@ optional<vector<Shape>> ResizeObj::inferShape(const TensorVec &inputs) const {
|
|||
int nDim = inDims.size();
|
||||
for (int i = 0; i < nDim; ++i) {
|
||||
int size = round_int(scales[i] * inDims[i]);
|
||||
IT_ASSERT(checkCoordinateTransValid(size, inDims[i]));
|
||||
ret[i] = size;
|
||||
}
|
||||
|
||||
|
@ -202,24 +236,21 @@ std::string ResizeObj::toString() const {
|
|||
<< "[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
if (inputs[1] != nullptr)
|
||||
os << "roi=" << vecToString(inputs[1]->getDims()) << ",";
|
||||
if (inputs[2] != nullptr)
|
||||
os << "scales=" << vecToString(inputs[2]->getDims()) << ",";
|
||||
if (inputs[3] != nullptr)
|
||||
os << "sizes=" << vecToString(inputs[3]->getDims()) << ",";
|
||||
if (inputs.size() == 3)
|
||||
os << "roi=" << vecToString(inputs[2]->getDims()) << ",";
|
||||
if (isResizeBySizes())
|
||||
os << "sizes=" << vecToString(inputs[1]->getDims()) << ",";
|
||||
else
|
||||
os << "scales=" << vecToString(inputs[1]->getDims()) << ",";
|
||||
os << "axes=" << vecToString(axes) << ",";
|
||||
os << "coMode=" << enum_to_underlying(coMode) << ",";
|
||||
os << "nearestMode=" << enum_to_underlying(nearestMode) << ",";
|
||||
os << "ratioPolicy=" << enum_to_underlying(ratioPolicy) << ",";
|
||||
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
if (inputs[1] != nullptr)
|
||||
os << inputs[1]->getGuid() << ",";
|
||||
if (inputs[2] != nullptr)
|
||||
os << inputs[1]->getGuid() << ",";
|
||||
if (inputs.size() == 3)
|
||||
os << inputs[2]->getGuid() << ",";
|
||||
if (inputs[3] != nullptr)
|
||||
os << inputs[3]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
|
|
@ -21,7 +21,8 @@ TEST(Resize, Cuda_downsample_sizes_nearest) {
|
|||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt,
|
||||
gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch);
|
||||
gCuda->cloneTensor(sizes), nullptr, nullptr,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
|
@ -45,7 +46,8 @@ TEST(Resize, Cuda_upsample_sizes_nearest_notlarger) {
|
|||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, vector<int>{2, 3},
|
||||
gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::notLarger,
|
||||
gCuda->cloneTensor(sizes), nullptr, nullptr,
|
||||
ResizeObj::EKeepAspectRatioPolicy::notLarger,
|
||||
ResizeObj::ENearestMode::roundPreferFloor,
|
||||
ResizeObj::ECoordinateTransMode::halfPixel);
|
||||
gCuda->dataMalloc();
|
||||
|
@ -72,12 +74,12 @@ TEST(Resize, Cuda_upsample_sizes_nearest_notsmaller) {
|
|||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op =
|
||||
gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
vector<int>{2, 3}, gCuda->cloneTensor(sizes),
|
||||
ResizeObj::EKeepAspectRatioPolicy::notSmaller,
|
||||
ResizeObj::ENearestMode::roundPreferFloor,
|
||||
ResizeObj::ECoordinateTransMode::halfPixel);
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, vector<int>{2, 3},
|
||||
gCuda->cloneTensor(sizes), nullptr, nullptr,
|
||||
ResizeObj::EKeepAspectRatioPolicy::notSmaller,
|
||||
ResizeObj::ENearestMode::roundPreferFloor,
|
||||
ResizeObj::ECoordinateTransMode::halfPixel);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
|
@ -105,7 +107,8 @@ TEST(Resize, Cuda_upsample_sizes_nearest_ceil_half_pixel) {
|
|||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt,
|
||||
gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
gCuda->cloneTensor(sizes), nullptr, nullptr,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
ResizeObj::ENearestMode::ceil,
|
||||
ResizeObj::ECoordinateTransMode::halfPixel);
|
||||
gCuda->dataMalloc();
|
||||
|
@ -113,7 +116,6 @@ TEST(Resize, Cuda_upsample_sizes_nearest_ceil_half_pixel) {
|
|||
|
||||
// copy output from CUDA to CPU
|
||||
auto o = op->getOutput(0);
|
||||
// //cudaPrintTensor(o);
|
||||
auto oCpu = gCpu->cloneTensor(o);
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
1, 2, 2, 3, 3, 4, 4, 4, 5, 6, 6, 7, 7, 8, 8, 8,
|
||||
|
@ -138,7 +140,8 @@ TEST(Resize, Cuda_upsample_sizes_nearest_floor_align_corners) {
|
|||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, vector<int>{3, 2},
|
||||
gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
gCuda->cloneTensor(sizes), nullptr, nullptr,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
ResizeObj::ENearestMode::floor,
|
||||
ResizeObj::ECoordinateTransMode::alignCorners);
|
||||
gCuda->dataMalloc();
|
||||
|
@ -146,7 +149,6 @@ TEST(Resize, Cuda_upsample_sizes_nearest_floor_align_corners) {
|
|||
|
||||
// copy output from CUDA to CPU
|
||||
auto o = op->getOutput(0);
|
||||
// cudaPrintTensor(o);
|
||||
auto oCpu = gCpu->cloneTensor(o);
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
1, 1, 1, 2, 2, 3, 3, 4, 1, 1, 1, 2, 2, 3, 3, 4,
|
||||
|
@ -171,7 +173,8 @@ TEST(Resize, Cuda_upsample_sizes_nearest_round_prefer_ceil_asymmetri) {
|
|||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt,
|
||||
gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
gCuda->cloneTensor(sizes), nullptr, nullptr,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
ResizeObj::ENearestMode::roundPreferCeil,
|
||||
ResizeObj::ECoordinateTransMode::asymmetric);
|
||||
gCuda->dataMalloc();
|
||||
|
@ -179,7 +182,6 @@ TEST(Resize, Cuda_upsample_sizes_nearest_round_prefer_ceil_asymmetri) {
|
|||
|
||||
// copy output from CUDA to CPU
|
||||
auto o = op->getOutput(0);
|
||||
// cudaPrintTensor(o);
|
||||
auto oCpu = gCpu->cloneTensor(o);
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
1, 2, 2, 3, 3, 4, 4, 4, 5, 6, 6, 7, 7, 8, 8, 8,
|
||||
|
@ -202,7 +204,8 @@ TEST(Resize, Cuda_downsample_scales_nearest) {
|
|||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
std::nullopt, gCuda->cloneTensor(scales));
|
||||
std::nullopt, nullptr,
|
||||
gCuda->cloneTensor(scales), nullptr);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
|
@ -225,7 +228,8 @@ TEST(Resize, Cuda_upsample_scales_nearest) {
|
|||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
std::nullopt, gCuda->cloneTensor(scales));
|
||||
std::nullopt, nullptr,
|
||||
gCuda->cloneTensor(scales), nullptr);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
|
@ -249,9 +253,9 @@ TEST(Resize, Cuda_upsample_scales_nearest_axes_3_2) {
|
|||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op =
|
||||
gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
vector<int>{3, 2}, gCuda->cloneTensor(scales));
|
||||
auto op = gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
vector<int>{3, 2}, nullptr,
|
||||
gCuda->cloneTensor(scales), nullptr);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
|
@ -275,9 +279,9 @@ TEST(Resize, Cuda_downsample_scales_linear) {
|
|||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
std::nullopt, gCuda->cloneTensor(scales),
|
||||
ResizeObj::ECoeffMode::linear);
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt, nullptr,
|
||||
gCuda->cloneTensor(scales), nullptr, ResizeObj::ECoeffMode::linear);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
|
@ -286,6 +290,32 @@ TEST(Resize, Cuda_downsample_scales_linear) {
|
|||
EXPECT_TRUE(oCpu->equalData(vector<float>{2.6666665, 4.3333331}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_downsample_scales_linear_aligncorners) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
|
||||
scales->copyData(vector<float>{1, 1, 0.6, 0.6});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt, nullptr,
|
||||
gCuda->cloneTensor(scales), nullptr, ResizeObj::ECoeffMode::linear,
|
||||
ResizeObj::EKeepAspectRatioPolicy::none,
|
||||
ResizeObj::ECoordinateTransMode::alignCorners);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{1, 3.142857}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_upsample_scales_linear) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
@ -299,9 +329,9 @@ TEST(Resize, Cuda_upsample_scales_linear) {
|
|||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
std::nullopt, gCuda->cloneTensor(scales),
|
||||
ResizeObj::ECoeffMode::linear);
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt, nullptr,
|
||||
gCuda->cloneTensor(scales), nullptr, ResizeObj::ECoeffMode::linear);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
|
@ -326,12 +356,13 @@ TEST(Resize, Cuda_upsample_scales_linear_align_corners) {
|
|||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt,
|
||||
gCuda->cloneTensor(scales), ResizeObj::ECoeffMode::linear,
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt, nullptr,
|
||||
gCuda->cloneTensor(scales), nullptr, ResizeObj::ECoeffMode::linear,
|
||||
ResizeObj::EKeepAspectRatioPolicy::none,
|
||||
ResizeObj::ECoordinateTransMode::alignCorners);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
cudaPrintTensor(op->getOutput(0));
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
|
@ -355,16 +386,341 @@ TEST(Resize, Cuda_downsample_sizes_linear_pytorchhalfpixel) {
|
|||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt,
|
||||
gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
gCuda->cloneTensor(sizes), nullptr, nullptr,
|
||||
ResizeObj::ECoeffMode::linear,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
ResizeObj::ECoordinateTransMode::pytorchHalfPixel);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
// cudaPrintTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{1.666667, 7, 12.33333}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_tf_crop_and_resize) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
|
||||
auto sizes = gCpu->addTensor({4}, DataType::UInt32);
|
||||
auto roi = gCpu->addTensor({8}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(
|
||||
vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
sizes->copyData(vector<uint32_t>{1, 1, 3, 3});
|
||||
roi->copyData(vector<float>{0, 0, 0.4, 0.6, 1, 1, 0.6, 0.8});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt,
|
||||
gCuda->cloneTensor(sizes), nullptr, gCuda->cloneTensor(roi),
|
||||
ResizeObj::ECoeffMode::linear,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
ResizeObj::ECoordinateTransMode::tfCropAndResize);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{7.6000004, 7.9, 8.2, 8.8, 9.1,
|
||||
9.400001, 10, 10.3, 10.6}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_tf_crop_and_resize_axes_3_2) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
|
||||
auto sizes = gCpu->addTensor({2}, DataType::UInt32);
|
||||
auto roi = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(
|
||||
vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
sizes->copyData(vector<uint32_t>{3, 3});
|
||||
roi->copyData(vector<float>{0.6, 0.4, 0.8, 0.6});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, vector<int>{3, 2},
|
||||
gCuda->cloneTensor(sizes), nullptr, gCuda->cloneTensor(roi),
|
||||
ResizeObj::ECoeffMode::linear,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
ResizeObj::ECoordinateTransMode::tfCropAndResize);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{7.6000004, 7.9, 8.2, 8.8, 9.1,
|
||||
9.400001, 10, 10.3, 10.6}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_downsample_scales_cubic) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(
|
||||
vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
scales->copyData(vector<float>{1.0, 1.0, 0.8, 0.8});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt, nullptr,
|
||||
gCuda->cloneTensor(scales), nullptr, ResizeObj::ECoeffMode::cubic);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(
|
||||
vector<float>{1.47119141, 2.78125, 4.08251953, 6.71142578, 8.02148438,
|
||||
9.32275391, 11.91650391, 13.2265625, 14.52783203}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_downsample_scales_cubic_align_corners) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(
|
||||
vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
scales->copyData(vector<float>{1.0, 1.0, 0.8, 0.8});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt, nullptr,
|
||||
gCuda->cloneTensor(scales), nullptr, ResizeObj::ECoeffMode::cubic,
|
||||
ResizeObj::EKeepAspectRatioPolicy::none,
|
||||
ResizeObj::ECoordinateTransMode::alignCorners);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(
|
||||
vector<float>{1, 2.39519159, 3.79038317, 6.58076634, 7.97595793,
|
||||
9.37114951, 12.16153268, 13.55672427, 14.95191585}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_upsample_scales_cubic) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(
|
||||
vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
scales->copyData(vector<float>{1.0, 1.0, 2, 2});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt, nullptr,
|
||||
gCuda->cloneTensor(scales), nullptr, ResizeObj::ECoeffMode::cubic);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
0.47265625, 0.76953125, 1.24609375, 1.875, 2.28125,
|
||||
2.91015625, 3.38671875, 3.68359375, 1.66015625, 1.95703125,
|
||||
2.43359375, 3.0625, 3.46875, 4.09765625, 4.57421875,
|
||||
4.87109375, 3.56640625, 3.86328125, 4.33984375, 4.96875,
|
||||
5.375, 6.00390625, 6.48046875, 6.77734375, 6.08203125,
|
||||
6.37890625, 6.85546875, 7.484375, 7.890625, 8.51953125,
|
||||
8.99609375, 9.29296875, 7.70703125, 8.00390625, 8.48046875,
|
||||
9.109375, 9.515625, 10.14453125, 10.62109375, 10.91796875,
|
||||
10.22265625, 10.51953125, 10.99609375, 11.625, 12.03125,
|
||||
12.66015625, 13.13671875, 13.43359375, 12.12890625, 12.42578125,
|
||||
12.90234375, 13.53125, 13.9375, 14.56640625, 15.04296875,
|
||||
15.33984375, 13.31640625, 13.61328125, 14.08984375, 14.71875,
|
||||
15.125, 15.75390625, 16.23046875, 16.52734375}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_upsample_scales_cubic_align_corners) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(
|
||||
vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
scales->copyData(vector<float>{1.0, 1.0, 2, 2});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt, nullptr,
|
||||
gCuda->cloneTensor(scales), nullptr, ResizeObj::ECoeffMode::cubic,
|
||||
ResizeObj::EKeepAspectRatioPolicy::none,
|
||||
ResizeObj::ECoordinateTransMode::alignCorners);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
1, 1.34110787, 1.80029155, 2.32944606, 2.67055394,
|
||||
3.19970845, 3.65889213, 4, 2.36443149, 2.70553936,
|
||||
3.16472303, 3.69387755, 4.03498542, 4.56413994, 5.02332362,
|
||||
5.36443149, 4.20116618, 4.54227405, 5.00145773, 5.53061224,
|
||||
5.87172012, 6.40087464, 6.86005831, 7.20116618, 6.31778426,
|
||||
6.65889213, 7.1180758, 7.64723032, 7.98833819, 8.51749271,
|
||||
8.97667638, 9.31778426, 7.68221574, 8.02332362, 8.48250729,
|
||||
9.01166181, 9.35276968, 9.8819242, 10.34110787, 10.68221574,
|
||||
9.79883382, 10.13994169, 10.59912536, 11.12827988, 11.46938776,
|
||||
11.99854227, 12.45772595, 12.79883382, 11.63556851, 11.97667638,
|
||||
12.43586006, 12.96501458, 13.30612245, 13.83527697, 14.29446064,
|
||||
14.63556851, 13, 13.34110787, 13.80029155, 14.32944606,
|
||||
14.67055394, 15.19970845, 15.65889213, 16.}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_upsample_scales_cubic_asymmetric) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(
|
||||
vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
scales->copyData(vector<float>{1.0, 1.0, 2, 2});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt, nullptr,
|
||||
gCuda->cloneTensor(scales), nullptr, ResizeObj::ECoeffMode::cubic,
|
||||
ResizeObj::EKeepAspectRatioPolicy::none,
|
||||
ResizeObj::ECoordinateTransMode::asymmetric);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
1., 1.40625, 2., 2.5, 3., 3.59375, 4., 4.09375,
|
||||
2.625, 3.03125, 3.625, 4.125, 4.625, 5.21875, 5.625, 5.71875,
|
||||
5., 5.40625, 6., 6.5, 7., 7.59375, 8., 8.09375,
|
||||
7., 7.40625, 8., 8.5, 9., 9.59375, 10., 10.09375,
|
||||
9., 9.40625, 10., 10.5, 11., 11.59375, 12., 12.09375,
|
||||
11.375, 11.78125, 12.375, 12.875, 13.375, 13.96875, 14.375, 14.46875,
|
||||
13., 13.40625, 14., 14.5, 15., 15.59375, 16., 16.09375,
|
||||
13.375, 13.78125, 14.375, 14.875, 15.375, 15.96875, 16.375, 16.46875}));
|
||||
}
|
||||
|
||||
//
|
||||
TEST(Resize, Cuda_downsample_sizes_cubic) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
|
||||
auto sizes = gCpu->addTensor({4}, DataType::UInt32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(
|
||||
vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
sizes->copyData(vector<uint32_t>{1, 1, 3, 3});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op =
|
||||
gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
std::nullopt, gCuda->cloneTensor(sizes),
|
||||
nullptr, nullptr, ResizeObj::ECoeffMode::cubic,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
|
||||
/*The corresponding test's output of ONNX has some bias, which is:
|
||||
{1.63078704, 3.00462963, 4.37847222, 7.12615741, 8.5,
|
||||
9.87384259, 12.62152778, 13.99537037, 15.36921296}
|
||||
(https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize)*/
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
1.63078511, 3.00462794, 4.37846994, 7.12615490, 8.50000000, 9.87384224,
|
||||
12.62152576, 13.99537086, 15.36921501}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_upsample_sizes_cubic) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
|
||||
auto sizes = gCpu->addTensor({4}, DataType::UInt32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(
|
||||
vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
sizes->copyData(vector<uint32_t>{1, 1, 9, 10});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op =
|
||||
gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
std::nullopt, gCuda->cloneTensor(sizes),
|
||||
nullptr, nullptr, ResizeObj::ECoeffMode::cubic,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
0.45508048, 0.64058018, 0.97158027, 1.42258000, 1.90733004,
|
||||
2.22333097, 2.70807934, 3.15908003, 3.49008012, 3.67558002,
|
||||
1.39437866, 1.57987845, 1.91087842, 2.36187792, 2.84662747,
|
||||
3.16262865, 3.64737630, 4.09837723, 4.42937851, 4.61487770,
|
||||
2.95131063, 3.13681102, 3.46781015, 3.91881013, 4.40356016,
|
||||
4.71956062, 5.20430803, 5.65531015, 5.98631001, 6.17181063,
|
||||
5.20525312, 5.39075279, 5.72175217, 6.17275286, 6.65750170,
|
||||
6.97350359, 7.45825005, 7.90925360, 8.24025249, 8.42575359,
|
||||
6.88975096, 7.07525015, 7.40625000, 7.85725021, 8.34200001,
|
||||
8.65800095, 9.14274597, 9.59375000, 9.92474842, 10.11025047,
|
||||
8.57425022, 8.75974846, 9.09074879, 9.54174805, 10.02649689,
|
||||
10.34249973, 10.82724571, 11.27824974, 11.60924721, 11.79474831,
|
||||
10.82819176, 11.01369190, 11.34469223, 11.79569244, 12.28044128,
|
||||
12.59644127, 13.08118820, 13.53219128, 13.86318874, 14.04869366,
|
||||
12.38512325, 12.57062244, 12.90162182, 13.35262108, 13.83737183,
|
||||
14.15337372, 14.63811684, 15.08912182, 15.42011929, 15.60562229,
|
||||
13.32442474, 13.50992107, 13.84092331, 14.29192352, 14.77667332,
|
||||
15.09267426, 15.57741737, 16.02842331, 16.35941887, 16.54491997}));
|
||||
/* The corresponding test's output of ONNX has some bias, which is:
|
||||
0.45507922, 0.64057922, 0.97157922, 1.42257922, 1.90732922,
|
||||
2.22332922, 2.70807922, 3.15907922, 3.49007922, 3.67557922,
|
||||
1.39437963, 1.57987963, 1.91087963, 2.36187963, 2.84662963,
|
||||
3.16262963, 3.64737963, 4.09837963, 4.42937963, 4.61487963,
|
||||
2.95130693, 3.13680693, 3.46780693, 3.91880693, 4.40355693,
|
||||
4.71955693, 5.20430693, 5.65530693, 5.98630693, 6.17180693,
|
||||
5.20525069, 5.39075069, 5.72175069, 6.17275069, 6.65750069,
|
||||
6.97350069, 7.45825069, 7.90925069, 8.24025069, 8.42575069,
|
||||
6.88975, 7.07525, 7.40625, 7.85725, 8.342,
|
||||
8.658, 9.14275, 9.59375, 9.92475, 10.11025,
|
||||
8.57424931, 8.75974931, 9.09074931, 9.54174931, 10.02649931,
|
||||
10.34249931, 10.82724931, 11.27824931, 11.60924931, 11.79474931,
|
||||
10.82819307, 11.01369307, 11.34469307, 11.79569307, 12.28044307,
|
||||
12.59644307, 13.08119307, 13.53219307, 13.86319307, 14.04869307,
|
||||
12.38512037, 12.57062037, 12.90162037, 13.35262037, 13.83737037,
|
||||
14.15337037, 14.63812037, 15.08912037, 15.42012037, 15.60562037,
|
||||
13.32442078, 13.50992078, 13.84092078, 14.29192078, 14.77667078,
|
||||
15.09267078, 15.57742078, 16.02842078, 16.35942078, 16.54492078}*/
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -13,9 +13,9 @@ TEST(Resize, ShapeInference) {
|
|||
Tensor sizes = g->addTensor({4}, DataType::UInt32);
|
||||
sizes->dataMalloc();
|
||||
sizes->copyData(vector<uint32_t>{1, 1, 1, 3});
|
||||
auto op =
|
||||
g->addOp<ResizeObj>(i, nullptr, std::nullopt, sizes,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch);
|
||||
auto op = g->addOp<ResizeObj>(
|
||||
i, nullptr, std::nullopt, sizes, nullptr, nullptr,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 1, 3}));
|
||||
}
|
||||
// upsample_sizes_nearest with axes
|
||||
|
@ -25,9 +25,9 @@ TEST(Resize, ShapeInference) {
|
|||
Tensor sizes = g->addTensor({2}, DataType::UInt32);
|
||||
sizes->dataMalloc();
|
||||
sizes->copyData(vector<uint32_t>{1, 3});
|
||||
auto op =
|
||||
g->addOp<ResizeObj>(i, nullptr, vector<int>{2, 3}, sizes,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch);
|
||||
auto op = g->addOp<ResizeObj>(
|
||||
i, nullptr, vector<int>{2, 3}, sizes, nullptr, nullptr,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 1, 3}));
|
||||
}
|
||||
// upsample_sizes_nearest_notlarger
|
||||
|
@ -37,9 +37,9 @@ TEST(Resize, ShapeInference) {
|
|||
Tensor sizes = g->addTensor({2}, DataType::UInt32);
|
||||
sizes->dataMalloc();
|
||||
sizes->copyData(vector<uint32_t>{7, 8});
|
||||
auto op =
|
||||
g->addOp<ResizeObj>(i, nullptr, vector<int>{2, 3}, sizes,
|
||||
ResizeObj::EKeepAspectRatioPolicy::notLarger);
|
||||
auto op = g->addOp<ResizeObj>(
|
||||
i, nullptr, vector<int>{2, 3}, sizes, nullptr, nullptr,
|
||||
ResizeObj::EKeepAspectRatioPolicy::notLarger);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 4, 8}));
|
||||
}
|
||||
// upsample_sizes_nearest_notsmaller
|
||||
|
@ -49,9 +49,9 @@ TEST(Resize, ShapeInference) {
|
|||
Tensor sizes = g->addTensor({3}, DataType::UInt32);
|
||||
sizes->dataMalloc();
|
||||
sizes->copyData(vector<uint32_t>{2, 6, 8});
|
||||
auto op =
|
||||
g->addOp<ResizeObj>(i, nullptr, vector<int>{1, 2, 3}, sizes,
|
||||
ResizeObj::EKeepAspectRatioPolicy::notSmaller);
|
||||
auto op = g->addOp<ResizeObj>(
|
||||
i, nullptr, vector<int>{1, 2, 3}, sizes, nullptr, nullptr,
|
||||
ResizeObj::EKeepAspectRatioPolicy::notSmaller);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 9, 6, 12}));
|
||||
}
|
||||
// downsample_scales
|
||||
|
@ -61,7 +61,8 @@ TEST(Resize, ShapeInference) {
|
|||
Tensor scales = g->addTensor({3}, DataType::Float32);
|
||||
scales->dataMalloc();
|
||||
scales->copyData(vector<float>{1, 0.8, 0.8});
|
||||
auto op = g->addOp<ResizeObj>(i, nullptr, vector<int>{1, 2, 3}, scales);
|
||||
auto op = g->addOp<ResizeObj>(i, nullptr, vector<int>{1, 2, 3}, nullptr,
|
||||
scales, nullptr);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 3, 3}));
|
||||
}
|
||||
// upsample_scales
|
||||
|
@ -71,7 +72,8 @@ TEST(Resize, ShapeInference) {
|
|||
Tensor scales = g->addTensor({4}, DataType::Float32);
|
||||
scales->dataMalloc();
|
||||
scales->copyData(vector<float>{1, 1, 2, 2});
|
||||
auto op = g->addOp<ResizeObj>(i, nullptr, std::nullopt, scales);
|
||||
auto op = g->addOp<ResizeObj>(i, nullptr, std::nullopt, nullptr, scales,
|
||||
nullptr);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 4, 4}));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue