forked from jiuyuan/InfiniTensor
Add: resize operator and cuda kernel,support nearest/linear coef. (#51)
ADD: resize operator and cuda kernel,support nearest/linear coef. fix some fix tests add more tests for linear mode. add linear coef mode. add scales add tests fix tests. add notLarger notSmaller fix add test ADD:resize operator and cuda kernel
This commit is contained in:
parent
63d8aff985
commit
c5966f8d81
|
@ -0,0 +1,17 @@
|
|||
#pragma once
|
||||
#include "cuda/cuda_common.h"
|
||||
|
||||
typedef struct {
|
||||
int nDims;
|
||||
int oDims[4];
|
||||
int inDims[4];
|
||||
int inStride[4];
|
||||
float scale[4];
|
||||
} MetaData;
|
||||
|
||||
namespace infini {
|
||||
void resize_kernel_nearest(float *in, float *out, const MetaData &metaData,
|
||||
size_t num, int coordinateMode, int nearestMode);
|
||||
void resize_kernel_linear(float *in, float *out, const MetaData &metaData,
|
||||
size_t num, int coordinateMode);
|
||||
} // namespace infini
|
|
@ -0,0 +1,83 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
class ResizeObj : public OperatorObj {
|
||||
public:
|
||||
enum class ECoordinateTransMode {
|
||||
halfPixel,
|
||||
pytorchHalfPixel,
|
||||
alignCorners,
|
||||
asymmetric,
|
||||
tfCropAndResize
|
||||
};
|
||||
enum class ENearestMode { roundPreferFloor, roundPreferCeil, floor, ceil };
|
||||
enum class EKeepAspectRatioPolicy { stretch, notLarger, notSmaller };
|
||||
enum class ECoeffMode { nearest, linear, cubic };
|
||||
|
||||
private:
|
||||
vector<int> axes;
|
||||
vector<float> scales;
|
||||
ECoordinateTransMode coMode; // compute src coordinate from dst coordinate
|
||||
ECoeffMode mode; // coeff mode,for computing dst value from coordinate src
|
||||
// neighborhood .
|
||||
ENearestMode nearestMode; // used in "nearest" mode, indicates how to get
|
||||
// "nearest" pixel
|
||||
EKeepAspectRatioPolicy
|
||||
ratioPolicy; // used for computing shape when using "sizes"
|
||||
|
||||
public:
|
||||
// nearest mode, not tf_crop_and_resize
|
||||
ResizeObj(
|
||||
GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor sizes,
|
||||
EKeepAspectRatioPolicy ratioPolicy,
|
||||
ENearestMode nearestMode = ENearestMode::roundPreferFloor,
|
||||
ECoordinateTransMode coordTransMode = ECoordinateTransMode::halfPixel);
|
||||
ResizeObj(
|
||||
GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor scales,
|
||||
ENearestMode nearestMode = ENearestMode::roundPreferFloor,
|
||||
ECoordinateTransMode coordTransMode = ECoordinateTransMode::halfPixel);
|
||||
|
||||
// linear mode
|
||||
ResizeObj(
|
||||
GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor sizes,
|
||||
EKeepAspectRatioPolicy ratioPolicy, ECoeffMode mode,
|
||||
ECoordinateTransMode coordTransMode = ECoordinateTransMode::halfPixel);
|
||||
ResizeObj(
|
||||
GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor scales, ECoeffMode mode,
|
||||
ECoordinateTransMode coordTransMode = ECoordinateTransMode::halfPixel);
|
||||
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 4; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
ECoeffMode getMode() const { return mode; }
|
||||
int getNearestMode() const { return enum_to_underlying(nearestMode); }
|
||||
int getKeepAxesRatioPolicy() const {
|
||||
return enum_to_underlying(ratioPolicy);
|
||||
}
|
||||
int getCoordinateTransMode() const { return enum_to_underlying(coMode); }
|
||||
float getScale(int i) const {
|
||||
IT_ASSERT((size_t)i < scales.size());
|
||||
return scales[i];
|
||||
}
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
|
||||
float round_int(float x) const;
|
||||
bool checkCoordinateTransValid(int resizedCo, int origiCo) const;
|
||||
void InitBySizes(Tensor input, Tensor sizes,
|
||||
const std::optional<vector<int>> &axes);
|
||||
void InitByScales(Tensor input, Tensor sizes,
|
||||
const std::optional<vector<int>> &axes);
|
||||
};
|
||||
} // namespace infini
|
|
@ -0,0 +1,47 @@
|
|||
#include "operators/resize.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/resize.cuh"
|
||||
namespace infini {
|
||||
class ResizeCuda : public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ResizeObj>(_op);
|
||||
auto in = op->getInputs(0);
|
||||
auto out = op->getOutputs()[0];
|
||||
|
||||
int nDims = in->getDims().size();
|
||||
if (nDims > 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
MetaData metaData;
|
||||
memset(&metaData, 0, sizeof(metaData));
|
||||
metaData.nDims = nDims;
|
||||
for (int i = 0; i < nDims; ++i) {
|
||||
metaData.inDims[i] = in->getDims()[i];
|
||||
metaData.oDims[i] = out->getDims()[i];
|
||||
metaData.inStride[i] = in->getStride()[i];
|
||||
metaData.scale[i] = op->getScale(i);
|
||||
}
|
||||
|
||||
switch (op->getMode()) {
|
||||
case ResizeObj::ECoeffMode::nearest:
|
||||
resize_kernel_nearest(in->getRawDataPtr<float *>(),
|
||||
out->getRawDataPtr<float *>(), metaData,
|
||||
out->size(), op->getCoordinateTransMode(),
|
||||
op->getNearestMode());
|
||||
break;
|
||||
case ResizeObj::ECoeffMode::linear:
|
||||
resize_kernel_linear(in->getRawDataPtr<float *>(),
|
||||
out->getRawDataPtr<float *>(), metaData,
|
||||
out->size(), op->getCoordinateTransMode());
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Resize, DataType::Float32, ResizeCuda,
|
||||
"Resize_CUDA_Float32");
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,208 @@
|
|||
#include "cmath"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/resize.cuh"
|
||||
#include <functional>
|
||||
|
||||
#ifndef GPU_LAMBDA
|
||||
#define GPU_LAMBDA __device__
|
||||
#endif
|
||||
|
||||
// nearest mode
|
||||
__device__ int round_prefer_ceil(float x) {
|
||||
return (x > 0.0) ? floor(x + 0.5) : ceil(x - 0.5);
|
||||
}
|
||||
|
||||
__device__ int round_prefer_floor(float x) {
|
||||
return (x > 0.0) ? floor(x + 0.4) : ceil(x - 0.4);
|
||||
}
|
||||
|
||||
__device__ int prefer_floor(float x) { return std::floor(x); }
|
||||
|
||||
__device__ int prefer_ceil(float x) { return std::ceil(x); }
|
||||
|
||||
// coordinate transform mode
|
||||
__device__ float half_pixel(int idx, float scale, int, int) {
|
||||
return (idx + 0.5) / scale - 0.5;
|
||||
}
|
||||
|
||||
__device__ float pytorch_half_pixel(int idx, float scale, int length_resized,
|
||||
int) {
|
||||
return length_resized > 1 ? (idx + 0.5) / scale - 0.5 : 0;
|
||||
}
|
||||
|
||||
__device__ float align_corners(int idx, float scale, int length_resized,
|
||||
int length_original) {
|
||||
if (length_resized == 1)
|
||||
return 0;
|
||||
return (float)idx * (float)(length_original - 1) /
|
||||
(float)(length_resized - 1);
|
||||
}
|
||||
|
||||
__device__ float asymmetric(int idx, float scale, int length_resized,
|
||||
int length_original) {
|
||||
return idx / scale;
|
||||
}
|
||||
/*
|
||||
__device__ float tf_crop_and_resize(int idx, float scale, int length_resized,
|
||||
int length_original) {
|
||||
|
||||
}*/
|
||||
|
||||
// ATTENTION:The order of device functions in array must be consistent with the
|
||||
// order in the enums of ResizeObj.
|
||||
using nearest_mod_func_t = int (*)(float);
|
||||
__device__ nearest_mod_func_t p_nearest_mode_fun[] = {
|
||||
round_prefer_floor, round_prefer_ceil, prefer_floor, prefer_ceil};
|
||||
|
||||
using coordinate_trans_mod_func_t = float (*)(int idxO, float scale, int lenO,
|
||||
int lenR);
|
||||
__device__ coordinate_trans_mod_func_t p_cooridnate_trans_mode_func[] = {
|
||||
half_pixel, pytorch_half_pixel, align_corners, asymmetric};
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__device__ int nearestCoordinateTrans(int dOffset, MetaData metaData,
|
||||
T1 transModeFun, T2 nearestModeFun) {
|
||||
int sOffset = 0;
|
||||
for (int i = metaData.nDims - 1; i >= 0; --i) {
|
||||
int dIdx = dOffset % metaData.oDims[i];
|
||||
dOffset = dOffset / metaData.oDims[i];
|
||||
|
||||
if (metaData.inDims[i] == metaData.oDims[i])
|
||||
sOffset += dIdx * metaData.inStride[i];
|
||||
else {
|
||||
float scale = (float)metaData.oDims[i] / (float)metaData.inDims[i];
|
||||
int sIdx = nearestModeFun(transModeFun(
|
||||
dIdx, scale, metaData.oDims[i], metaData.inDims[i]));
|
||||
if (sIdx > metaData.inDims[i] - 1)
|
||||
sIdx = metaData.inDims[i] - 1;
|
||||
else if (sIdx < 0)
|
||||
sIdx = 0;
|
||||
sOffset += sIdx * metaData.inStride[i];
|
||||
}
|
||||
}
|
||||
return sOffset;
|
||||
}
|
||||
|
||||
__global__ void _resize_kernel_nearest(float *in, float *out, MetaData metaData,
|
||||
size_t num, int coordinateMode,
|
||||
int nearestMode) {
|
||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
auto stride = blockDim.x * gridDim.x;
|
||||
|
||||
while (tid < num) {
|
||||
int offset = nearestCoordinateTrans(
|
||||
tid, metaData, p_cooridnate_trans_mode_func[coordinateMode],
|
||||
p_nearest_mode_fun[nearestMode]);
|
||||
out[tid] = in[offset];
|
||||
tid += stride;
|
||||
}
|
||||
}
|
||||
|
||||
// ATTENTION: Make sure dim <=4
|
||||
typedef struct {
|
||||
int offset[16];
|
||||
float power[16];
|
||||
} NeighborList;
|
||||
|
||||
int __device__ getLimitIdx(int idx, int limit) {
|
||||
if (idx < 0)
|
||||
return 0;
|
||||
if (idx > limit)
|
||||
return limit;
|
||||
return idx;
|
||||
}
|
||||
|
||||
__global__ void _resize_kernel_linear(float *in, float *out, MetaData metaData,
|
||||
size_t num, int coordinateMode) {
|
||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
auto stride = blockDim.x * gridDim.x;
|
||||
|
||||
while (tid < num) {
|
||||
auto dOffset = tid;
|
||||
auto neighborNum = 0;
|
||||
NeighborList neighborList;
|
||||
memset(&neighborList, 0, sizeof(neighborList));
|
||||
for (int i = metaData.nDims - 1; i >= 0; --i) {
|
||||
int dIdx = dOffset % metaData.oDims[i];
|
||||
float scale = metaData.scale[i];
|
||||
float sIdx = p_cooridnate_trans_mode_func[coordinateMode](
|
||||
dIdx, scale, scale * metaData.inDims[i], metaData.inDims[i]);
|
||||
|
||||
int idx = std::floor(sIdx);
|
||||
float power = 1 - (sIdx - idx);
|
||||
|
||||
// update neighborList
|
||||
if (metaData.inDims[i] == 1) {
|
||||
if (neighborNum == 0) {
|
||||
neighborList.offset[0] = 0;
|
||||
neighborList.power[0] = power;
|
||||
neighborNum = 1;
|
||||
} else {
|
||||
for (int j = 0; j < neighborNum; j++) {
|
||||
neighborList.power[j] *= power;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (neighborNum == 0) {
|
||||
neighborList.offset[0] =
|
||||
getLimitIdx(idx, metaData.inDims[i] - 1) *
|
||||
metaData.inStride[i];
|
||||
neighborList.power[0] = power;
|
||||
neighborList.offset[1] =
|
||||
getLimitIdx(idx + 1, metaData.inDims[i] - 1) *
|
||||
metaData.inStride[i];
|
||||
neighborList.power[1] = 1 - power;
|
||||
neighborNum = 2;
|
||||
} else {
|
||||
for (int j = 0; j < neighborNum; j++) {
|
||||
neighborList.offset[j + neighborNum] =
|
||||
neighborList.offset[j] +
|
||||
getLimitIdx(idx + 1, metaData.inDims[i] - 1) *
|
||||
metaData.inStride[i];
|
||||
neighborList.power[j + neighborNum] =
|
||||
(neighborList.power[j]) * (1 - power);
|
||||
|
||||
neighborList.offset[j] +=
|
||||
getLimitIdx(idx, metaData.inDims[i] - 1) *
|
||||
metaData.inStride[i];
|
||||
neighborList.power[j] *= power;
|
||||
}
|
||||
neighborNum *= 2;
|
||||
}
|
||||
}
|
||||
|
||||
dOffset = dOffset / metaData.oDims[i];
|
||||
}
|
||||
|
||||
float val = 0;
|
||||
for (int i = 0; i < neighborNum; ++i) {
|
||||
val += in[neighborList.offset[i]] * neighborList.power[i];
|
||||
}
|
||||
out[tid] = val;
|
||||
tid += stride;
|
||||
}
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void resize_kernel_nearest(float *in, float *out, const MetaData &metaData,
|
||||
size_t num, int coordinateMode, int nearestMode) {
|
||||
int blocksize = 32 * 16;
|
||||
auto gridsize = (num + blocksize - 1) / blocksize;
|
||||
IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
|
||||
sizeof(p_cooridnate_trans_mode_func[0]));
|
||||
IT_ASSERT(nearestMode <
|
||||
sizeof(p_nearest_mode_fun) / sizeof(p_nearest_mode_fun[0]));
|
||||
_resize_kernel_nearest<<<blocksize, gridsize>>>(
|
||||
in, out, metaData, num, coordinateMode, nearestMode);
|
||||
}
|
||||
|
||||
void resize_kernel_linear(float *in, float *out, const MetaData &metaData,
|
||||
size_t num, int coordinateMode) {
|
||||
int blocksize = 32 * 16;
|
||||
auto gridsize = (num + blocksize - 1) / blocksize;
|
||||
IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
|
||||
sizeof(p_cooridnate_trans_mode_func[0]));
|
||||
_resize_kernel_linear<<<blocksize, gridsize>>>(in, out, metaData, num,
|
||||
coordinateMode);
|
||||
}
|
||||
} // namespace infini
|
|
@ -69,6 +69,7 @@ std::string GatherObj::toString() const {
|
|||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> GatherObj::getWorkloadVector() const {
|
||||
vector<int> ret = inputs[0]->getDims();
|
||||
ret.emplace(ret.begin(), enum_to_underlying(type));
|
||||
|
|
|
@ -0,0 +1,248 @@
|
|||
#include "operators/resize.h"
|
||||
#include <cmath>
|
||||
namespace infini {
|
||||
ResizeObj::ResizeObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor sizes,
|
||||
EKeepAspectRatioPolicy ratioPolicy,
|
||||
ENearestMode nearestMode,
|
||||
ECoordinateTransMode coordTransMode)
|
||||
: OperatorObj(OpType::Resize, {input, nullptr, nullptr, sizes}, {output}),
|
||||
coMode(coordTransMode), mode(ECoeffMode::nearest),
|
||||
nearestMode(nearestMode), ratioPolicy(ratioPolicy) {
|
||||
if (coordTransMode == ECoordinateTransMode::tfCropAndResize)
|
||||
IT_TODO_HALT();
|
||||
InitBySizes(input, sizes, axes);
|
||||
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
ResizeObj::ResizeObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor scales,
|
||||
ENearestMode nearestMode,
|
||||
ECoordinateTransMode coordTransMode)
|
||||
: OperatorObj(OpType::Resize, {input, nullptr, scales, nullptr}, {output}),
|
||||
coMode(coordTransMode), mode(ECoeffMode::nearest),
|
||||
nearestMode(nearestMode) {
|
||||
InitByScales(input, scales, axes);
|
||||
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
ResizeObj::ResizeObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor sizes,
|
||||
EKeepAspectRatioPolicy ratioPolicy, ECoeffMode mode,
|
||||
ECoordinateTransMode coordTransMode)
|
||||
: OperatorObj(OpType::Resize, {input, nullptr, nullptr, sizes}, {output}),
|
||||
coMode(coordTransMode), mode(mode), ratioPolicy(ratioPolicy) {
|
||||
if (coordTransMode == ECoordinateTransMode::tfCropAndResize)
|
||||
IT_TODO_HALT();
|
||||
InitBySizes(input, sizes, axes);
|
||||
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
ResizeObj::ResizeObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const std::optional<vector<int>> &axes, Tensor scales,
|
||||
ECoeffMode mode, ECoordinateTransMode coordTransMode)
|
||||
: OperatorObj(OpType::Resize, {input, nullptr, scales, nullptr}, {output}),
|
||||
coMode(coordTransMode), mode(mode) {
|
||||
if (coordTransMode == ECoordinateTransMode::tfCropAndResize)
|
||||
IT_TODO_HALT();
|
||||
InitByScales(input, scales, axes);
|
||||
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
|
||||
const std::optional<vector<int>> &axes) {
|
||||
IT_ASSERT(sizes != nullptr);
|
||||
size_t size = sizes->getDims()[0];
|
||||
IT_ASSERT(size == input->getDims().size() ||
|
||||
(axes != std::nullopt && size == (*axes).size()));
|
||||
|
||||
if (axes == std::nullopt)
|
||||
for (size_t i = 0; i < input->getDims().size(); ++i)
|
||||
this->axes.emplace_back(i);
|
||||
else
|
||||
// check axes
|
||||
for (size_t i = 0; i < (*axes).size(); ++i) {
|
||||
auto val = (*axes)[i];
|
||||
if (val < 0)
|
||||
IT_TODO_HALT();
|
||||
IT_ASSERT((size_t)val < inputs[0]->getDims().size());
|
||||
this->axes.emplace_back(val);
|
||||
}
|
||||
|
||||
// init this->scales
|
||||
for (size_t i = 0; i < input->getDims().size(); ++i) {
|
||||
this->scales.emplace_back(1);
|
||||
}
|
||||
|
||||
// copy sizes data to host.
|
||||
IT_ASSERT(sizes->getDataBlob() != nullptr);
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
int *data = (int *)runtime->alloc(sizes->getBytes());
|
||||
sizes->getRuntime()->copyBlobToCPU(
|
||||
(void *)data, sizes->getRawDataPtr<void *>(), sizes->getBytes());
|
||||
|
||||
auto inDims = input->getDims();
|
||||
int n = this->axes.size();
|
||||
switch (ratioPolicy) {
|
||||
case EKeepAspectRatioPolicy::stretch:
|
||||
for (int i = 0; i < n; ++i)
|
||||
scales[this->axes[i]] =
|
||||
(float)data[i] / (float)inDims[this->axes[i]];
|
||||
break;
|
||||
case EKeepAspectRatioPolicy::notLarger: {
|
||||
float scale = (float)data[0] / (float)inDims[this->axes[0]];
|
||||
for (int i = 1; i < n; ++i) {
|
||||
auto tmp = (float)data[i] / (float)inDims[this->axes[i]];
|
||||
scale = scale < tmp ? scale : tmp;
|
||||
}
|
||||
for (int i = 0; i < n; ++i)
|
||||
scales[this->axes[i]] = scale;
|
||||
break;
|
||||
}
|
||||
case EKeepAspectRatioPolicy::notSmaller: {
|
||||
float scale = (float)data[0] / (float)inDims[this->axes[0]];
|
||||
for (int i = 1; i < n; ++i) {
|
||||
auto tmp = (float)data[i] / (float)inDims[this->axes[i]];
|
||||
scale = scale > tmp ? scale : tmp;
|
||||
}
|
||||
for (int i = 0; i < n; ++i)
|
||||
scales[this->axes[i]] = scale;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
IT_ASSERT(0);
|
||||
}
|
||||
|
||||
runtime->dealloc(data);
|
||||
}
|
||||
|
||||
void ResizeObj::InitByScales(Tensor input, Tensor scales,
|
||||
const std::optional<vector<int>> &axes) {
|
||||
IT_ASSERT(scales != nullptr);
|
||||
size_t size = scales->getDims()[0];
|
||||
IT_ASSERT(size == input->getDims().size() ||
|
||||
(axes != std::nullopt && size == (*axes).size()));
|
||||
|
||||
// copy scales data to host.
|
||||
IT_ASSERT(scales->getDataBlob() != nullptr);
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
float *data = (float *)runtime->alloc(scales->getBytes());
|
||||
scales->getRuntime()->copyBlobToCPU(
|
||||
(void *)data, scales->getRawDataPtr<void *>(), scales->getBytes());
|
||||
|
||||
// init this->scales
|
||||
for (size_t i = 0; i < input->getDims().size(); ++i) {
|
||||
this->scales.emplace_back(1);
|
||||
}
|
||||
|
||||
if (axes == std::nullopt)
|
||||
for (size_t i = 0; i < input->getDims().size(); ++i) {
|
||||
this->axes.emplace_back(i);
|
||||
IT_ASSERT(data[i] > 0);
|
||||
this->scales[i] = data[i];
|
||||
}
|
||||
else
|
||||
// check axes
|
||||
for (size_t i = 0; i < (*axes).size(); ++i) {
|
||||
auto val = (*axes)[i];
|
||||
if (val < 0)
|
||||
IT_TODO_HALT();
|
||||
IT_ASSERT((size_t)val < inputs[0]->getDims().size());
|
||||
this->axes.emplace_back(val);
|
||||
IT_ASSERT(data[i] > 0);
|
||||
this->scales[val] = data[i];
|
||||
}
|
||||
|
||||
runtime->dealloc(data);
|
||||
}
|
||||
|
||||
vector<DataType> ResizeObj::inferDataType(const TensorVec &inputs) const {
|
||||
IT_ASSERT(inputs.size() == 4);
|
||||
auto roi = inputs[1];
|
||||
auto scales = inputs[2];
|
||||
auto sizes = inputs[3];
|
||||
IT_ASSERT(roi == nullptr || roi->getDType() == DataType::Float32);
|
||||
IT_ASSERT(scales == nullptr || scales->getDType() == DataType::Float32);
|
||||
IT_ASSERT(sizes == nullptr || sizes->getDType() == DataType::UInt32);
|
||||
return {inputs[0]->getDType()};
|
||||
}
|
||||
|
||||
bool ResizeObj::checkCoordinateTransValid(int resizedX, int origiX) const {
|
||||
if (ECoordinateTransMode::alignCorners == coMode) {
|
||||
return (!(resizedX <= 1 && origiX != resizedX));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
float ResizeObj::round_int(float x) const {
|
||||
return (x > 0.0) ? floor(x + 0.5) : ceil(x - 0.5);
|
||||
}
|
||||
|
||||
// output shape is related to sizes/scales value.
|
||||
optional<vector<Shape>> ResizeObj::inferShape(const TensorVec &inputs) const {
|
||||
auto inDims = inputs[0]->getDims();
|
||||
Shape ret = inDims;
|
||||
int nDim = inDims.size();
|
||||
for (int i = 0; i < nDim; ++i) {
|
||||
int size = round_int(scales[i] * inDims[i]);
|
||||
IT_ASSERT(checkCoordinateTransValid(size, inDims[i]));
|
||||
ret[i] = size;
|
||||
}
|
||||
|
||||
return {{ret}};
|
||||
}
|
||||
|
||||
std::string ResizeObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << "Resize"
|
||||
<< "[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
if (inputs[1] != nullptr)
|
||||
os << "roi=" << vecToString(inputs[1]->getDims()) << ",";
|
||||
if (inputs[2] != nullptr)
|
||||
os << "scales=" << vecToString(inputs[2]->getDims()) << ",";
|
||||
if (inputs[3] != nullptr)
|
||||
os << "sizes=" << vecToString(inputs[3]->getDims()) << ",";
|
||||
os << "axes=" << vecToString(axes) << ",";
|
||||
os << "coMode=" << enum_to_underlying(coMode) << ",";
|
||||
os << "nearestMode=" << enum_to_underlying(nearestMode) << ",";
|
||||
os << "ratioPolicy=" << enum_to_underlying(ratioPolicy) << ",";
|
||||
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
if (inputs[1] != nullptr)
|
||||
os << inputs[1]->getGuid() << ",";
|
||||
if (inputs[2] != nullptr)
|
||||
os << inputs[2]->getGuid() << ",";
|
||||
if (inputs[3] != nullptr)
|
||||
os << inputs[3]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> ResizeObj::getWorkloadVector() const {
|
||||
vector<int> ret = inputs[0]->getDims();
|
||||
for (size_t i = 0; i < outputs[0]->getDims().size(); ++i)
|
||||
ret.emplace_back(outputs[0]->getDims()[i]);
|
||||
// ratioPolicy only effects output shape, so did not need
|
||||
// here.
|
||||
ret.emplace_back(enum_to_underlying(coMode));
|
||||
ret.emplace_back(enum_to_underlying(nearestMode));
|
||||
ret.emplace(ret.begin(), enum_to_underlying(type));
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<int> ResizeObj::getOpAttrVector() const {
|
||||
vector<int> ret = axes;
|
||||
ret.emplace_back(enum_to_underlying(coMode));
|
||||
ret.emplace_back(enum_to_underlying(nearestMode));
|
||||
ret.emplace_back(enum_to_underlying(ratioPolicy));
|
||||
ret.emplace(ret.begin(), enum_to_underlying(type));
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,370 @@
|
|||
#include "cmath"
|
||||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/resize.h"
|
||||
#include "test.h"
|
||||
namespace infini {
|
||||
TEST(Resize, Cuda_downsample_sizes_nearest) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);
|
||||
auto sizes = gCpu->addTensor({4}, DataType::UInt32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
|
||||
sizes->copyData(vector<uint32_t>{1, 1, 1, 3});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt,
|
||||
gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{1, 2, 4}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_upsample_sizes_nearest_notlarger) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
|
||||
auto sizes = gCpu->addTensor({2}, DataType::UInt32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(vector<float>{1, 2, 3, 4});
|
||||
sizes->copyData(vector<uint32_t>{7, 8});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, vector<int>{2, 3},
|
||||
gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::notLarger,
|
||||
ResizeObj::ENearestMode::roundPreferFloor,
|
||||
ResizeObj::ECoordinateTransMode::halfPixel);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(
|
||||
vector<float>{1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1,
|
||||
1, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4,
|
||||
4, 3, 3, 3, 3, 4, 4, 4, 3, 3, 3, 3, 4, 4, 4}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_upsample_sizes_nearest_notsmaller) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
|
||||
auto sizes = gCpu->addTensor({2}, DataType::UInt32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(vector<float>{1, 2, 3, 4});
|
||||
sizes->copyData(vector<uint32_t>{7, 8});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op =
|
||||
gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
vector<int>{2, 3}, gCuda->cloneTensor(sizes),
|
||||
ResizeObj::EKeepAspectRatioPolicy::notSmaller,
|
||||
ResizeObj::ENearestMode::roundPreferFloor,
|
||||
ResizeObj::ECoordinateTransMode::halfPixel);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2,
|
||||
2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 3, 3,
|
||||
4, 4, 4, 4, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 3, 3, 4, 4, 4, 4}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_upsample_sizes_nearest_ceil_half_pixel) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
|
||||
auto sizes = gCpu->addTensor({4}, DataType::UInt32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(
|
||||
vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
sizes->copyData(vector<uint32_t>{1, 1, 8, 8});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt,
|
||||
gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
ResizeObj::ENearestMode::ceil,
|
||||
ResizeObj::ECoordinateTransMode::halfPixel);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto o = op->getOutput(0);
|
||||
// //cudaPrintTensor(o);
|
||||
auto oCpu = gCpu->cloneTensor(o);
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
1, 2, 2, 3, 3, 4, 4, 4, 5, 6, 6, 7, 7, 8, 8, 8,
|
||||
5, 6, 6, 7, 7, 8, 8, 8, 9, 10, 10, 11, 11, 12, 12, 12,
|
||||
9, 10, 10, 11, 11, 12, 12, 12, 13, 14, 14, 15, 15, 16, 16, 16,
|
||||
13, 14, 14, 15, 15, 16, 16, 16, 13, 14, 14, 15, 15, 16, 16, 16}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_upsample_sizes_nearest_floor_align_corners) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
|
||||
auto sizes = gCpu->addTensor({2}, DataType::UInt32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(
|
||||
vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
sizes->copyData(vector<uint32_t>{8, 8});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, vector<int>{3, 2},
|
||||
gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
ResizeObj::ENearestMode::floor,
|
||||
ResizeObj::ECoordinateTransMode::alignCorners);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto o = op->getOutput(0);
|
||||
// cudaPrintTensor(o);
|
||||
auto oCpu = gCpu->cloneTensor(o);
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
1, 1, 1, 2, 2, 3, 3, 4, 1, 1, 1, 2, 2, 3, 3, 4,
|
||||
1, 1, 1, 2, 2, 3, 3, 4, 5, 5, 5, 6, 6, 7, 7, 8,
|
||||
5, 5, 5, 6, 6, 7, 7, 8, 9, 9, 9, 10, 10, 11, 11, 12,
|
||||
9, 9, 9, 10, 10, 11, 11, 12, 13, 13, 13, 14, 14, 15, 15, 16}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_upsample_sizes_nearest_round_prefer_ceil_asymmetri) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
|
||||
auto sizes = gCpu->addTensor({4}, DataType::UInt32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(
|
||||
vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
sizes->copyData(vector<uint32_t>{1, 1, 8, 8});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt,
|
||||
gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
ResizeObj::ENearestMode::roundPreferCeil,
|
||||
ResizeObj::ECoordinateTransMode::asymmetric);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto o = op->getOutput(0);
|
||||
// cudaPrintTensor(o);
|
||||
auto oCpu = gCpu->cloneTensor(o);
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
1, 2, 2, 3, 3, 4, 4, 4, 5, 6, 6, 7, 7, 8, 8, 8,
|
||||
5, 6, 6, 7, 7, 8, 8, 8, 9, 10, 10, 11, 11, 12, 12, 12,
|
||||
9, 10, 10, 11, 11, 12, 12, 12, 13, 14, 14, 15, 15, 16, 16, 16,
|
||||
13, 14, 14, 15, 15, 16, 16, 16, 13, 14, 14, 15, 15, 16, 16, 16}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_downsample_scales_nearest) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
|
||||
scales->copyData(vector<float>{1, 1, 0.6, 0.6});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
std::nullopt, gCuda->cloneTensor(scales));
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{1, 3}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_upsample_scales_nearest) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(vector<float>{1, 2, 3, 4});
|
||||
scales->copyData(vector<float>{1, 1, 2, 3});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
std::nullopt, gCuda->cloneTensor(scales));
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(
|
||||
oCpu->equalData(vector<float>{1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
|
||||
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_upsample_scales_nearest_axes_3_2) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({2}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(vector<float>{1, 2, 3, 4});
|
||||
scales->copyData(vector<float>{3, 2});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op =
|
||||
gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
vector<int>{3, 2}, gCuda->cloneTensor(scales));
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(
|
||||
oCpu->equalData(vector<float>{1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
|
||||
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_downsample_scales_linear) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
|
||||
scales->copyData(vector<float>{1, 1, 0.6, 0.6});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
std::nullopt, gCuda->cloneTensor(scales),
|
||||
ResizeObj::ECoeffMode::linear);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{2.6666665, 4.3333331}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_upsample_scales_linear) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(vector<float>{1, 2, 3, 4});
|
||||
scales->copyData(vector<float>{1, 1, 2, 2});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(gCuda->cloneTensor(input), nullptr,
|
||||
std::nullopt, gCuda->cloneTensor(scales),
|
||||
ResizeObj::ECoeffMode::linear);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(
|
||||
oCpu->equalData(vector<float>{1, 1.25, 1.75, 2, 1.5, 1.75, 2.25, 2.5,
|
||||
2.5, 2.75, 3.25, 3.5, 3, 3.25, 3.75, 4}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_upsample_scales_linear_align_corners) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
|
||||
auto scales = gCpu->addTensor({4}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(vector<float>{1, 2, 3, 4});
|
||||
scales->copyData(vector<float>{1, 1, 2, 2});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt,
|
||||
gCuda->cloneTensor(scales), ResizeObj::ECoeffMode::linear,
|
||||
ResizeObj::ECoordinateTransMode::alignCorners);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
cudaPrintTensor(op->getOutput(0));
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
1, 1.333333, 1.666667, 2, 1.666667, 2, 2.333333, 2.666667, 2.333333,
|
||||
2.6666667, 3, 3.333333, 3, 3.333333, 3.6666667, 4}));
|
||||
}
|
||||
|
||||
TEST(Resize, Cuda_downsample_sizes_linear_pytorchhalfpixel) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
|
||||
auto sizes = gCpu->addTensor({4}, DataType::UInt32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyData(
|
||||
vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
sizes->copyData(vector<uint32_t>{1, 1, 3, 1});
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto op = gCuda->addOp<ResizeObj>(
|
||||
gCuda->cloneTensor(input), nullptr, std::nullopt,
|
||||
gCuda->cloneTensor(sizes), ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
ResizeObj::ECoeffMode::linear,
|
||||
ResizeObj::ECoordinateTransMode::pytorchHalfPixel);
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput(0));
|
||||
// cudaPrintTensor(op->getOutput(0));
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{1.666667, 7, 12.33333}));
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,79 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/resize.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
TEST(Resize, ShapeInference) {
|
||||
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||
// downsample_sizes_nearest no axes
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||
Tensor i = g->addTensor({1, 1, 2, 4}, DataType::UInt32);
|
||||
Tensor sizes = g->addTensor({4}, DataType::UInt32);
|
||||
sizes->dataMalloc();
|
||||
sizes->copyData(vector<uint32_t>{1, 1, 1, 3});
|
||||
auto op =
|
||||
g->addOp<ResizeObj>(i, nullptr, std::nullopt, sizes,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 1, 3}));
|
||||
}
|
||||
// upsample_sizes_nearest with axes
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||
Tensor i = g->addTensor({1, 1, 2, 4}, DataType::UInt32);
|
||||
Tensor sizes = g->addTensor({2}, DataType::UInt32);
|
||||
sizes->dataMalloc();
|
||||
sizes->copyData(vector<uint32_t>{1, 3});
|
||||
auto op =
|
||||
g->addOp<ResizeObj>(i, nullptr, vector<int>{2, 3}, sizes,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 1, 3}));
|
||||
}
|
||||
// upsample_sizes_nearest_notlarger
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||
Tensor i = g->addTensor({1, 3, 2, 4}, DataType::UInt32);
|
||||
Tensor sizes = g->addTensor({2}, DataType::UInt32);
|
||||
sizes->dataMalloc();
|
||||
sizes->copyData(vector<uint32_t>{7, 8});
|
||||
auto op =
|
||||
g->addOp<ResizeObj>(i, nullptr, vector<int>{2, 3}, sizes,
|
||||
ResizeObj::EKeepAspectRatioPolicy::notLarger);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 4, 8}));
|
||||
}
|
||||
// upsample_sizes_nearest_notsmaller
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||
Tensor i = g->addTensor({1, 3, 2, 4}, DataType::UInt32);
|
||||
Tensor sizes = g->addTensor({3}, DataType::UInt32);
|
||||
sizes->dataMalloc();
|
||||
sizes->copyData(vector<uint32_t>{2, 6, 8});
|
||||
auto op =
|
||||
g->addOp<ResizeObj>(i, nullptr, vector<int>{1, 2, 3}, sizes,
|
||||
ResizeObj::EKeepAspectRatioPolicy::notSmaller);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 9, 6, 12}));
|
||||
}
|
||||
// downsample_scales
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||
Tensor i = g->addTensor({1, 1, 4, 4}, DataType::UInt32);
|
||||
Tensor scales = g->addTensor({3}, DataType::Float32);
|
||||
scales->dataMalloc();
|
||||
scales->copyData(vector<float>{1, 0.8, 0.8});
|
||||
auto op = g->addOp<ResizeObj>(i, nullptr, vector<int>{1, 2, 3}, scales);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 3, 3}));
|
||||
}
|
||||
// upsample_scales
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(cpuRuntime);
|
||||
Tensor i = g->addTensor({1, 1, 2, 2}, DataType::UInt32);
|
||||
Tensor scales = g->addTensor({4}, DataType::Float32);
|
||||
scales->dataMalloc();
|
||||
scales->copyData(vector<float>{1, 1, 2, 2});
|
||||
auto op = g->addOp<ResizeObj>(i, nullptr, std::nullopt, scales);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 4, 4}));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace infini
|
Loading…
Reference in New Issue