Compare commits

...

8 Commits

Author SHA1 Message Date
Liyan Zheng bef4c422a0 Add: improve conv2dreduce kernel 2022-10-22 13:26:41 +08:00
Liyan Zheng 67c06733e6 Chore: format 2022-10-22 13:24:42 +08:00
Liyan Zheng aa552b5bd2 Add: batch size in test 2022-10-19 15:50:39 +08:00
huangshuhong 36755c3160 convNHWC+pReLU+biasPReLU 2022-10-19 15:50:39 +08:00
Liyan Zheng 74e998e262 Add: use Runtime stream in Copy 2022-10-19 15:50:39 +08:00
Liyan Zheng 7abe7da0e4 Add: CUDA graph for fsrcnn 2022-10-19 15:50:35 +08:00
huangshuhong 133513be34 fsrcnn test 2022-10-19 15:49:22 +08:00
Liyan Zheng 78425c3209 Add: fsrcnn 2022-10-19 15:49:22 +08:00
19 changed files with 756 additions and 23 deletions

View File

@ -40,6 +40,10 @@ enum class OpType {
Resize,
//
MemBound = 300,
//
Conv2dReduce = 400,
ConvTranspose2dReduce,
BiasPReLU
};
using KernelAttrs = std::tuple<Device, OpType, DataType>;
@ -86,6 +90,9 @@ class OpRegistry {
FOP(Abs);
//
FOP(MemBound);
FOP(Conv2dReduce);
FOP(ConvTranspose2dReduce);
FOP(BiasPReLU);
default:
IT_ASSERT(false);
break;

View File

@ -0,0 +1,9 @@
#pragma once
#include "operators/bias2prelu.h"
namespace infini {
// Host-side launcher for the fused "add bias, then optional PReLU" CUDA
// kernel.
//   input / output : device buffers indexed as (n, h, w, c) with c innermost.
//   bias           : device buffer indexed by the channel id (c entries read).
//   PReLU          : when true the activation is applied after the bias add.
//   paramPReLU     : slope forwarded to the kernel for the non-positive side.
// The kernel is enqueued on cudaStreamPerThread; the launcher itself does not
// synchronize.
void bias2prelu_kernel(float *input, float *bias, float *output, int n, int h,
int w, int c, bool PReLU, float paramPReLU);
}

View File

@ -0,0 +1,16 @@
#pragma once
#include "operators/conv2dreduce.h"
namespace infini {
// Host-side launcher for the Conv2dReduce CUDA kernel.
//   input  : device buffer of shape (n, h, w, f, r, s).
//   bias   : per-filter bias indexed by f; may be nullptr (the kernel checks).
//   output : device buffer of shape (n, oh, ow, f).
//   PReLU / paramReLU : optional activation, x > 0 ? x : paramReLU * x.
//   ph,pw / sh,sw / dh,dw : padding, stride and dilation.
// The launcher asserts sh == sw == 1 and enqueues on cudaStreamPerThread.
void conv2dreduce_kernel(float *input, float *bias, float *output, bool PReLU,
float paramReLU, int n, int h, int w, int f, int r,
int s, int oh, int ow, int ph, int pw, int sh, int sw,
int dh, int dw);
// Host-side launcher for the transposed variant; same buffer conventions as
// conv2dreduce_kernel. The kernel is written under the assumption
// dh == dw == 1 (see the comment in its body).
void convTranspose2dreduce_kernel(float *input, float *bias, float *output,
bool PReLU, float paramReLU, int n, int h,
int w, int f, int r, int s, int oh, int ow,
int ph, int pw, int sh, int sw, int dh,
int dw);
} // namespace infini

View File

@ -10,13 +10,14 @@ class CudaRuntimeObj : public RuntimeObj {
cublasHandle_t cublas;
CudaPtr workspace;
size_t workspaceSize;
cudaStream_t stream;
public:
CUdevice cuDevice;
CUcontext newContext;
public:
CudaRuntimeObj() : RuntimeObj(Device::CUDA) {
CudaRuntimeObj() : RuntimeObj(Device::CUDA), stream(cudaStreamPerThread) {
// Prepare for nvrtc. cuCtxCreate should be called before others.
// Otherwise it will result in strange failure, such as cuBLAS failed on
// certain inputs.
@ -26,6 +27,8 @@ class CudaRuntimeObj : public RuntimeObj {
checkCudnnError(cudnnCreate(&cudnn));
checkCublasError(cublasCreate(&cublas));
checkCublasError(cublasSetStream(cublas, stream));
checkCudnnError(cudnnSetStream(cudnn, stream));
// 10GB for Longformer
// size_t longformerNum = 3lu * (1 << 30);
workspaceSize = 7ll << 30; // 7 GB
@ -53,6 +56,7 @@ class CudaRuntimeObj : public RuntimeObj {
cudnnHandle_t cudnnHandle() const { return cudnn; }
cublasHandle_t cublasHandle() const { return cublas; }
size_t getWorkspaceSize() const { return workspaceSize; }
cudaStream_t getStream() const { return stream; }
CudaPtr getWorkspace(size_t size) const {
IT_ASSERT(size <= workspaceSize);
return workspace;

View File

@ -0,0 +1,32 @@
#pragma once
#include "core/operator.h"
namespace infini {
// Operator with two inputs (data tensor, bias tensor) and one output that
// adds the bias and optionally applies PReLU.
class BiasPReLU : public OperatorObj {
protected:
// Whether the PReLU activation is applied after the bias add.
bool PReLU;
// PReLU slope; exposed through getParamReLU().
float paramPReLU;
// Cached dims [0..3] of the data input, read in that order by the
// constructor.
int n, h, w, c;
public:
// Builds the operator; `output` may be passed as nullptr by callers that let
// the graph create it (see the fsrcnn test) -- NOTE(review): confirm against
// GraphObj::addOp.
BiasPReLU(GraphObj *graph, Tensor input, Tensor bias, Tensor output,
bool PReLU_, float paramPReLU_);
// Fixed display name; note it reads "Bias2PReLU", not the class name.
std::string toString() const override { return "Bias2PReLU"; }
int numInputs() const override { return 2; }
int numOutputs() const override { return 1; }
bool getPReLU() const { return PReLU; }
// Returns paramPReLU (the getter name omits the leading 'P').
float getParamReLU() const { return paramPReLU; }
// The bias tensor is the second input.
Tensor getBias() const { return inputs[1]; }
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini

View File

@ -67,19 +67,23 @@ class ConvBaseObj : public OperatorObj {
class ConvObj : public ConvBaseObj {
private:
ActType act;
bool NHWC_layout;
public:
ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph,
int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1,
Tensor bias = nullptr, ActType act = ActType::None);
Tensor bias = nullptr, ActType act = ActType::None,
bool nhwc = false);
// Constructors for setting padding mode
ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1,
int dh = 1, int dw = 1, Tensor bias = nullptr,
ActType act = ActType::None);
ActType act = ActType::None, bool nhwc = false);
std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
ActType getAct() const { return act; }
bool getNHWCLayout() const { return NHWC_layout; }
int getNumGroups() const override { return c / getChannelPerGroup(); }
private:

View File

@ -0,0 +1,59 @@
#pragma once
#include "core/operator.h"
namespace infini {
// Common state and accessors for the Conv2dReduce / Conv2dReduceTranspose
// operators. It takes two inputs (data, bias) and produces one output.
// inferShape is not overridden here; the derived classes provide it.
class Conv2dReduceBase : public OperatorObj {
protected:
// Padding, stride and dilation along height / width.
int ph, pw;
int sh, sw;
int dh, dw;
int n, h, w, f, r, s; // c has been reduced
// Optional PReLU activation and its slope.
bool PReLU;
float paramReLU;
public:
// Stride and dilation default to 1. The constructor expects a rank-6 input
// (n, h, w, f, r, s) and, when bias is non-null, a rank-1 bias of length f.
Conv2dReduceBase(OpType opType, Tensor input, Tensor bias, Tensor output,
bool PReLU_, float paramReLU_, int ph_, int pw_,
int sh_ = 1, int sw_ = 1, int dh_ = 1, int dw_ = 1);
std::string toString() const override;
int numInputs() const override { return 2; }
int numOutputs() const override { return 1; }
int getDh() const { return dh; }
int getDw() const { return dw; }
int getPh() const { return ph; }
int getPw() const { return pw; }
int getSh() const { return sh; }
int getSw() const { return sw; }
bool getPReLU() const { return PReLU; }
float getParamReLU() const { return paramReLU; }
// The bias tensor is the second input.
Tensor getBias() const { return inputs[1]; }
// optional<vector<Shape>> inferShape(const TensorVec &inputs) const
// override;
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
// Forward variant: registered under OpType::Conv2dReduce and reporting an
// output shape of (n, oh, ow, f).
class Conv2dReduce : public Conv2dReduceBase {
public:
Conv2dReduce(GraphObj *graph, Tensor input, Tensor bias, Tensor output,
bool PReLU_, float paramReLU_, int ph_, int pw_, int sh_ = 1,
int sw_ = 1, int dh_ = 1, int dw_ = 1);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
};
// Transposed variant: registered under OpType::ConvTranspose2dReduce (note
// the class and enum names order the words differently). The constructor
// requires dh == dw == 1.
class Conv2dReduceTranspose : public Conv2dReduceBase {
public:
Conv2dReduceTranspose(GraphObj *graph, Tensor input, Tensor bias,
Tensor output, bool PReLU_, float paramReLU_, int ph_,
int pw_, int sh_ = 1, int sw_ = 1, int dh_ = 1,
int dw_ = 1);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
};
} // namespace infini

View File

@ -101,6 +101,7 @@ void RuntimeObj::printProfilingData(double totalTime,
OpRegistry::getOpName(type).data(), opCnt.at(type), t,
t / totalTime * 100, t / opCnt.at(type));
}
printf("Total_perf_time: %.3lf ms\n", totalTime);
}
Blob RuntimeObj::allocBlob(size_t size) {

View File

@ -0,0 +1,24 @@
#include "operators/bias2prelu.h"
#include "cuda/cuda_bias2prelu.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class Bias2PReluCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op, const RuntimeObj *_context) const {
auto op = as<BiasPReLU>(_op);
float *const input = (op->getInputs(0)->getRawDataPtr<float *>());
float *const bias = (op->getInputs(1)->getRawDataPtr<float *>());
float *const output = (op->getOutput()->getRawDataPtr<float *>());
auto dim = op->getInputs(0)->getDims();
int n = dim[0], h = dim[1], w = dim[2], c = dim[3];
bias2prelu_kernel(input, output, bias, n, h, w, c, op->getPReLU(),
op->getParamReLU());
}
};
REGISTER_KERNEL(Device::CUDA, OpType::BiasPReLU, DataType::Float32,
Bias2PReluCuda, "Bias2PReLU_CUDA_Float32");
} // namespace infini

View File

@ -0,0 +1,34 @@
#include "cuda/cuda_common.h"
// Element-wise kernel: output = PReLU(input + bias[channel]).
// input / output are indexed as (n, h, w, c) with c innermost, so the channel
// of flat element `tid` is `tid % c`. One thread handles one element; the
// launch may contain surplus threads, which exit through the bounds check.
__global__ void bias2prelu_kernel_(float *__restrict__ input,
                                   float *__restrict__ bias,
                                   float *__restrict__ output,
                                   const bool PReLU, const float paramReLU,
                                   const int n, const int h, const int w,
                                   const int c) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    const int num = n * h * w * c;
    if (tid >= num)
        return;

    const int cid = tid % c;
    float imm = bias[cid] + input[tid];
    if (PReLU) {
        // PReLU scales the (non-positive) value itself by the slope.
        imm = (imm > 0.0) ? imm : paramReLU * imm;
    }
    output[tid] = imm;
}

namespace infini {
// Launches bias2prelu_kernel_ on cudaStreamPerThread with a flat 1-D grid.
// A (w, c) thread block would exceed the 1024 threads-per-block limit as soon
// as w * c > 1024, so the block size is fixed and the grid is sized from the
// element count instead.
void bias2prelu_kernel(float *input, float *bias, float *output, int n, int h,
                       int w, int c, bool PReLU, float paramPReLU) {
    const int blocksize = 512;
    const int gridsize = (n * h * w * c + blocksize - 1) / blocksize;
    cudaStream_t stream(cudaStreamPerThread);
    bias2prelu_kernel_<<<gridsize, blocksize, 0, stream>>>(
        input, bias, output, PReLU, paramPReLU, n, h, w, c);
}
}

View File

@ -65,6 +65,9 @@ class convCudnn : public Kernel {
const int cpg = op->getChannelPerGroup();
const int g = c / cpg;
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
const auto nhwcLayout = op->getNHWCLayout();
const auto inoutLayout =
nhwcLayout ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW;
int channelsPerGrp = cpg, channels = c;
@ -72,7 +75,7 @@ class convCudnn : public Kernel {
cudnnTensorDescriptor_t inDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, channels, h, w));
inDesc, inoutLayout, CUDNN_DATA_FLOAT, n, channels, h, w));
// get kernels
cudnnFilterDescriptor_t knDesc;
@ -125,12 +128,17 @@ class convCudnn : public Kernel {
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
cudnnTensorDescriptor_t outDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT, outn, outc,
outh, outw));
IT_ASSERT((vector{outn, outc, outh, outw}) ==
op->getOutput()->getDims(),
"cuDNN output shape mismatches with OP output shape");
checkCudnnError(cudnnSetTensor4dDescriptor(
outDesc, inoutLayout, CUDNN_DATA_FLOAT, outn, outc, outh, outw));
if (nhwcLayout) {
IT_ASSERT((vector{outn, outh, outw, outc}) ==
op->getOutput()->getDims(),
"cuDNN output shape mismatches with OP output shape");
} else {
IT_ASSERT((vector{outn, outc, outh, outw}) ==
op->getOutput()->getDims(),
"cuDNN output shape mismatches with OP output shape");
}
return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc);

View File

@ -0,0 +1,46 @@
#include "operators/conv2dreduce.h"
#include "cuda/cuda_conv2dreduce.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class Conv2dReduceCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op, const RuntimeObj *_context) const {
auto op = as<Conv2dReduceBase>(_op);
float *const input = (op->getInputs(0)->getRawDataPtr<float *>());
float *const bias = op->getInputs(1)
? (op->getInputs(1)->getRawDataPtr<float *>())
: nullptr;
float *const output = (op->getOutput()->getRawDataPtr<float *>());
auto dim = op->getInputs(0)->getDims();
int n = dim[0], h = dim[1], w = dim[2], f = dim[3], r = dim[4],
s = dim[5];
int dh = op->getDh(), dw = op->getDw();
int sh = op->getSh(), sw = op->getSw();
int ph = op->getPh(), pw = op->getPw();
auto odim = op->getOutput()->getDims();
int oh = odim[1], ow = odim[2];
bool PReLU = op->getPReLU();
float paramReLU = op->getParamReLU();
auto opType = op->getOpType();
if (opType == OpType::Conv2dReduce) {
conv2dreduce_kernel(input, bias, output, PReLU, paramReLU, n, h, w,
f, r, s, oh, ow, ph, pw, sh, sw, dh, dw);
} else {
convTranspose2dreduce_kernel(input, bias, output, PReLU, paramReLU,
n, h, w, f, r, s, oh, ow, ph, pw, sh,
sw, dh, dw);
}
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Conv2dReduce, DataType::Float32,
Conv2dReduceCuda, "Conv2dReduce_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::ConvTranspose2dReduce, DataType::Float32,
Conv2dReduceCuda, "ConvTranspose2dReduce_CUDA_Float32");
} // namespace infini

View File

@ -0,0 +1,121 @@
#include "cuda/cuda_common.h"
// Reduction step of a convolution: for every output element (n, oh, ow, f)
// sums the r x s partial products gathered from the neighbouring input
// positions, then applies the optional bias and PReLU.
// One thread produces one output element; surplus threads exit early.
__global__ void
conv2dreduce_kernel_(float *__restrict__ input, float *__restrict__ bias,
                     float *__restrict__ output, const bool PReLU,
                     const float paramReLU, const int n, const int f,
                     const int h, const int w, const int oh, const int ow,
                     const int r, const int s, const int ph, const int pw,
                     const int dh, const int dw, const int sh, const int sw) {
    // output shape: (n, oh, ow, f)
    // input shape: (n, h, w, f, r, s)
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    // Output strides are derived from the OUTPUT extent (oh, ow), which
    // differs from (h, w) whenever padding does not preserve the size.
    const int out_N_offset = oh * ow * f, out_H_offset = ow * f,
              out_W_offset = f;
    const int num = out_N_offset * n;
    if (tid >= num)
        return;

    // Decompose the flat output index into (nid, hid, wid, fid).
    int rest = tid;
    const int nid = rest / out_N_offset;
    rest -= nid * out_N_offset;
    const int hid = rest / out_H_offset;
    rest -= hid * out_H_offset;
    const int wid = rest / out_W_offset;
    const int fid = rest - wid * out_W_offset;

    // Input strides. The stride of one batch entry is h * (w * f * r * s);
    // it must not depend on the batch size n.
    const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk,
              nchunck = h * hchunk;
    const float *nfinput = input + nid * nchunck + fid * fchunck;

    float imm = 0.0;
    // Top-left input coordinate of the window, shifted by the padding.
    const int ihst = hid * sh - ph, iwst = wid * sw - pw;
    for (int ri = 0; ri < r; ++ri) {
        const int ihid = ihst + ri * dh;
        if (ihid < 0 || ihid >= h)
            continue; // row falls into the zero padding
        for (int si = 0; si < s; ++si) {
            const int iwid = iwst + si * dw;
            if (iwid >= 0 && iwid < w) {
                imm += nfinput[ihid * hchunk + iwid * wchunk + ri * s + si];
            }
        }
    }
    if (bias) {
        imm += bias[fid];
    }
    if (PReLU) {
        imm = imm > 0.0 ? imm : paramReLU * imm;
    }
    output[tid] = imm;
}
// Reduction step of a transposed convolution: output element
// (nid, hid, wid, fid) accumulates input[(nid, ih, iw, fid, ri, si)] for every
// tap (ri, si) with ih * sh == hid + ph - ri and iw * sw == wid + pw - si.
// One thread produces one output element; surplus threads exit early.
__global__ void convTranspose2dreduce_kernel_(
    float *__restrict__ input, float *__restrict__ bias,
    float *__restrict__ output, const bool PReLU, const float paramReLU,
    const int n, const int f, const int h, const int w, const int oh,
    const int ow, const int r, const int s, const int ph, const int pw,
    const int dh, const int dw, const int sh, const int sw) {
    // assert dh = dw = 1
    // output shape: (n, oh, ow, f)
    // input shape: (n, h, w, f, r, s)
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    const int out_N_offset = oh * ow * f, out_H_offset = ow * f,
              out_W_offset = f;
    if (tid >= n * out_N_offset)
        return;

    // Decompose the flat output index into (nid, hid, wid, fid).
    int rest = tid;
    const int nid = rest / out_N_offset;
    rest -= nid * out_N_offset;
    const int hid = rest / out_H_offset;
    rest -= hid * out_H_offset;
    const int wid = rest / out_W_offset;
    const int fid = rest - wid * out_W_offset;

    // Input strides. The stride of one batch entry is h * (w * f * r * s);
    // it must not depend on the batch size n.
    const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk,
              nchunck = h * hchunk;
    const float *nfinput = input + nid * nchunck + fid * fchunck;

    // view as conv, the true ph and pw
    const int tph = r - ph - 1, tpw = s - pw - 1;
    // Extent of the stride-dilated ("zero-inserted") input.
    const int th = (h - 1) * sh + 1, tw = (w - 1) * sw + 1;

    float imm = 0.0;
    const int ihst = hid - tph;
    const int iwst = wid - tpw;
    for (int ri = 0; ri < r; ++ri) {
        for (int si = 0; si < s; ++si) {
            const int ihid = ihst + r - ri - 1;
            const int iwid = iwst + s - si - 1;
            // Only positions that coincide with a real (non-inserted) input
            // sample contribute.
            if (ihid >= 0 && ihid < th && iwid >= 0 && iwid < tw &&
                (ihid % sh == 0) && (iwid % sw == 0)) {
                imm += nfinput[(ihid / sh) * hchunk + (iwid / sw) * wchunk +
                               ri * s + si];
            }
        }
    }
    if (bias) {
        imm += bias[fid];
    }
    if (PReLU) {
        imm = imm > 0.0 ? imm : paramReLU * imm;
    }
    output[tid] = imm;
}

namespace infini {
// Launches conv2dreduce_kernel_ on cudaStreamPerThread, one thread per output
// element. Only unit strides are accepted.
void conv2dreduce_kernel(float *input, float *bias, float *output, bool PReLU,
                         float paramReLU, int n, int h, int w, int f, int r,
                         int s, int oh, int ow, int ph, int pw, int sh, int sw,
                         int dh, int dw) {
    IT_ASSERT(sh == 1 && sw == 1, "conv2dreduce_kernel only support sh=sw=1");
    const int blocksize = 512;
    const int gridsize = (n * f * oh * ow + blocksize - 1) / blocksize;

    cudaStream_t stream(cudaStreamPerThread);
    conv2dreduce_kernel_<<<gridsize, blocksize, 0, stream>>>(
        input, bias, output, PReLU, paramReLU, n, f, h, w, oh, ow, r, s, ph, pw,
        dh, dw, sh, sw);
}

// Launches convTranspose2dreduce_kernel_ on cudaStreamPerThread, one thread
// per output element. A flat 1-D launch is used because an (oh, ow) thread
// block exceeds the 1024 threads-per-block limit for all but tiny outputs.
void convTranspose2dreduce_kernel(float *input, float *bias, float *output,
                                  bool PReLU, float paramReLU, int n, int h,
                                  int w, int f, int r, int s, int oh, int ow,
                                  int ph, int pw, int sh, int sw, int dh,
                                  int dw) {
    const int blocksize = 512;
    const int gridsize = (n * f * oh * ow + blocksize - 1) / blocksize;

    cudaStream_t stream(cudaStreamPerThread);
    convTranspose2dreduce_kernel_<<<gridsize, blocksize, 0, stream>>>(
        input, bias, output, PReLU, paramReLU, n, f, h, w, oh, ow, r, s, ph, pw,
        dh, dw, sh, sw);
}
} // namespace infini

View File

@ -6,8 +6,9 @@ class CopyCuda : public CudaKernelWithoutConfig {
const RuntimeObj *_context) const override {
auto inData = op->getInputs(0)->getRawDataPtr<void *>();
auto outData = op->getOutputs()[0]->getRawDataPtr<void *>();
auto cuda = dynamic_cast<const CudaRuntimeObj *>(_context);
cudaMemcpyAsync(outData, inData, op->getInputs(0)->getBytes(),
cudaMemcpyDeviceToDevice);
cudaMemcpyDeviceToDevice, cuda->getStream());
}
};
// reshape/flatten/identity all act as copying from input to output.

View File

@ -0,0 +1,37 @@
#include "operators/bias2prelu.h"
namespace infini {
// Builds a BiasPReLU operator over `input` and `bias`.
// The input must be rank 4; its dims are cached as (n, h, w, c) in that
// order. The operator is then validated against `graph` via checkValid.
BiasPReLU::BiasPReLU(GraphObj *graph, Tensor input, Tensor bias, Tensor output,
                     bool PReLU_, float paramPReLU_)
    : OperatorObj(OpType::BiasPReLU, {input, bias}, {output}), PReLU(PReLU_),
      paramPReLU(paramPReLU_) {
    auto dims = input->getDims();
    // Fail loudly instead of reading past the end of the shape vector.
    IT_ASSERT(dims.size() == 4);
    n = dims[0], h = dims[1], w = dims[2], c = dims[3];

    IT_ASSERT(checkValid(graph));
}
// Output shape equals the input shape (n, h, w, c).
// Returns an empty optional when the input is not rank 4, the bias has no
// dimensions, or the leading bias dimension does not match the channel count.
optional<vector<Shape>> BiasPReLU::inferShape(const TensorVec &inputs) const {
    const Tensor &input = inputs[0];
    const Tensor &bias = inputs[1];

    const auto dims = input->getDims();
    const auto biasDims = bias->getDims();
    // Reject malformed shapes before indexing into them.
    if (dims.size() != 4 || biasDims.empty())
        return {};

    int n = dims[0], h = dims[1], w = dims[2], c = dims[3];
    int bc = biasDims[0];

    if (bc != c)
        return {};

    return {{{n, h, w, c}}};
}
// Workload key: operator type id followed by the cached input dims.
vector<int> BiasPReLU::getWorkloadVector() const {
    const auto opTag = enum_to_underlying(type);
    return {opTag, n, h, w, c};
}
// Attribute key: only the operator type id.
vector<int> BiasPReLU::getOpAttrVector() const {
    const auto opTag = enum_to_underlying(type);
    return {opTag};
}
} // namespace infini

View File

@ -45,11 +45,25 @@ vector<int> ConvBaseObj::getOpAttrVector() const {
return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw};
}
std::string ConvObj::toString() const {
std::string origin = ConvBaseObj::toString();
std::ostringstream os;
os << (NHWC_layout ? "NHWC" : "NCHW") << origin;
return os.str();
}
void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
const Tensor &input = inputs[0];
const Tensor &weight = inputs[1];
n = input->getDims()[0], c = input->getDims()[1], h = input->getDims()[2],
w = input->getDims()[3], f = weight->getDims()[0], r = weight->getDims()[2],
n = input->getDims()[0];
if (NHWC_layout)
c = input->getDims()[3], h = input->getDims()[1],
w = input->getDims()[2];
else
c = input->getDims()[1], h = input->getDims()[2],
w = input->getDims()[3];
f = weight->getDims()[0], r = weight->getDims()[2],
s = weight->getDims()[3];
if (mode == PaddingMode::Same) {
int oh = h / sh;
@ -63,10 +77,10 @@ void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias,
ActType act)
ActType act, bool nhwc)
: ConvBaseObj(OpType::Conv, {input, weight}, output, ph, pw, sh, sw, dh, dw,
input, weight),
act(act) {
act(act), NHWC_layout(nhwc) {
if (bias)
IT_TODO_HALT();
setAuxilaryAttributes(PaddingMode::Other);
@ -75,10 +89,10 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias,
ActType act)
ActType act, bool nhwc)
: ConvBaseObj(OpType::Conv, {input, weight}, output, mode, sh, sw, dh, dw,
input, weight),
act(act) {
act(act), NHWC_layout(nhwc) {
if (bias)
IT_TODO_HALT();
setAuxilaryAttributes(mode);
@ -87,21 +101,33 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
const auto &input = inputs[0], &weight = inputs[1];
auto n = input->getDims()[0];
auto h = input->getDims()[2];
auto w = input->getDims()[3];
int n, h, w, ic, wc;
n = input->getDims()[0];
if (NHWC_layout) {
h = input->getDims()[1];
w = input->getDims()[2];
ic = input->getDims()[3];
} else {
h = input->getDims()[2];
w = input->getDims()[3];
ic = input->getDims()[1];
}
wc = weight->getDims()[1];
auto f = weight->getDims()[0];
auto r = weight->getDims()[2];
auto s = weight->getDims()[3];
int on = n, oc = f;
int oh = 0, ow = 0;
// For NCHW+FCRS layout, C of input is divisible by C of weight
if (input->getDims()[1] % weight->getDims()[1] != 0)
if (ic % wc != 0)
return {};
// Set padding size
if (padding == PaddingMode::Other) {
oh = (h - (r - sh) * dh + ph * 2) / sh;
ow = (w - (s - sw) * dw + pw * 2) / sw;
//! fix this
// oh = (h - (r - sh) * dh + ph * 2) / sh;
// ow = (w - (s - sw) * dw + pw * 2) / sw;
oh = (h + ph * 2 - dh * (r - 1) - 1) / sh + 1;
ow = (w + pw * 2 - dw * (s - 1) - 1) / sw + 1;
} else if (padding == PaddingMode::Same) {
oh = h / sh;
ow = w / sw;
@ -113,6 +139,9 @@ optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
oh = (h - (r - sh) * dh + ph * 2) / sh;
ow = (w - (s - sw) * dw + pw * 2) / sw;
}
if (NHWC_layout)
return {{{on, oh, ow, oc}}};
return {{{on, oc, oh, ow}}};
}

View File

@ -0,0 +1,96 @@
#include "operators/conv2dreduce.h"
namespace infini {
// Stores padding / stride / dilation and the activation settings, then caches
// the six input dims. The bias is optional; when supplied it must be a rank-1
// tensor whose length equals f.
Conv2dReduceBase::Conv2dReduceBase(OpType opType, Tensor input, Tensor bias,
                                   Tensor output, bool PReLU_, float paramReLU_,
                                   int ph_, int pw_, int sh_, int sw_, int dh_,
                                   int dw_)
    : OperatorObj(opType, {input, bias}, {output}), ph(ph_), pw(pw_), sh(sh_),
      sw(sw_), dh(dh_), dw(dw_), PReLU(PReLU_), paramReLU(paramReLU_) {
    // expect input shape is (n, h, w, f, r, s)
    auto inputShape = input->getDims();
    IT_ASSERT(inputShape.size() == 6);
    n = inputShape[0], h = inputShape[1], w = inputShape[2];
    f = inputShape[3], r = inputShape[4], s = inputShape[5];

    // Nothing more to validate without a bias tensor.
    if (!bias)
        return;

    auto biasShape = bias->getDims();
    IT_ASSERT(biasShape.size() == 1);
    IT_ASSERT(biasShape[0] == f);
}
// Renders the operator as
//   Name[guid](inputDims,biasDims,p=[..],s=[..],d=[..],PReLU=..,input=..,
//              bias=..,output=..)
// The bias parts are emitted only when a bias tensor is present, since the
// second input is allowed to be null.
std::string Conv2dReduceBase::toString() const {
    std::ostringstream os;
    os << OpRegistry::getOpName(getOpType()) << "[" << getGuid() << "]";
    os << "(";
    if (inputs.size() == 2) {
        os << vecToString(inputs[0]->getDims()) << ",";
        // Guard the optional bias before dereferencing it.
        if (inputs[1]) {
            os << vecToString(inputs[1]->getDims()) << ",";
        }
    }
    os << "p=[" << ph << "," << pw << "],";
    os << "s=[" << sh << "," << sw << "],";
    os << "d=[" << dh << "," << dw << "],";
    os << "PReLU=" << (PReLU ? "true" : "false") << ",";
    // os << "act=" << enum_to_underlying(act) << ",";
    os << "input=" << inputs[0]->getGuid() << ",";
    if (inputs[1]) {
        os << "bias=" << inputs[1]->getGuid() << ",";
    }
    os << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
}
std::vector<int> Conv2dReduceBase::getWorkloadVector() const {
return {enum_to_underlying(type), n, h, w, f, r, s, ph, pw, sh, sw, dh, dw};
}
std::vector<int> Conv2dReduceBase::getOpAttrVector() const {
return {enum_to_underlying(type), ph, pw, sh, sw, dh, dw};
}
// Forwards every argument to Conv2dReduceBase with OpType::Conv2dReduce and
// then validates the operator against `graph` via checkValid.
Conv2dReduce::Conv2dReduce(GraphObj *graph, Tensor input, Tensor bias,
Tensor output, bool PReLU_, float paramReLU_,
int ph_, int pw_, int sh_, int sw_, int dh_, int dw_)
: Conv2dReduceBase(OpType::Conv2dReduce, input, bias, output, PReLU_,
paramReLU_, ph_, pw_, sh_, sw_, dh_, dw_) {
IT_ASSERT(checkValid(graph));
}
// Output shape (n, oh, ow, f) computed from the dims cached at construction;
// the `inputs` argument is not consulted.
optional<vector<Shape>>
Conv2dReduce::inferShape(const TensorVec &inputs) const {
    // Span covered by the dilated kernel along each axis, minus one.
    const int spanH = dh * (r - 1);
    const int spanW = dw * (s - 1);

    const int outH = (h + ph * 2 - spanH - 1) / sh + 1;
    const int outW = (w + pw * 2 - spanW - 1) / sw + 1;

    vector<Shape> shapes{Shape{n, outH, outW, f}};
    return shapes;
}
// Forwards every argument to Conv2dReduceBase with
// OpType::ConvTranspose2dReduce, rejects any dilation other than 1, and then
// validates the operator against `graph` via checkValid.
Conv2dReduceTranspose::Conv2dReduceTranspose(GraphObj *graph, Tensor input,
Tensor bias, Tensor output,
bool PReLU_, float paramReLU_,
int ph_, int pw_, int sh_, int sw_,
int dh_, int dw_)
: Conv2dReduceBase(OpType::ConvTranspose2dReduce, input, bias, output,
PReLU_, paramReLU_, ph_, pw_, sh_, sw_, dh_, dw_) {
// The transposed kernel is written for unit dilation only.
IT_ASSERT(dh_ == 1);
IT_ASSERT(dw_ == 1);
IT_ASSERT(checkValid(graph));
}
// Output shape (n, oh, ow, f) of the transposed variant, computed from the
// dims cached at construction; the `inputs` argument is not consulted.
optional<vector<Shape>>
Conv2dReduceTranspose::inferShape(const TensorVec &inputs) const {
    // Span covered by the dilated kernel along each axis, minus one.
    const int spanH = dh * (r - 1);
    const int spanW = dw * (s - 1);

    const int outH = (h - 1) * sh - 2 * ph + spanH + 1;
    const int outW = (w - 1) * sw - 2 * pw + spanW + 1;

    vector<Shape> shapes{Shape{n, outH, outW, f}};
    return shapes;
}
} // namespace infini

View File

@ -0,0 +1,115 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/bias2prelu.h"
#include "operators/conv.h"
#include "operators/conv2dreduce.h"
#include "operators/element_wise.h"
#include "operators/matmul.h"
#include "operators/reshape.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
// Builds the FSRCNN-style test graph on the given CUDA runtime, allocates its
// tensors, runs it once, collects per-op timing, and returns the graph.
// Each config entry is (op kind, output filters f, kernel size r, use PReLU).
// The first (5x5) layer uses ConvObj in NHWC mode followed by BiasPReLU;
// every other layer is expressed as Reshape -> Matmul -> Reshape ->
// Conv2dReduce / Conv2dReduceTranspose.
Graph createGraph(Ref<CudaRuntimeObj> cuda, int batchSize) {
Graph g = make_ref<GraphObj>(cuda);
// Channels-last input: the first conv is created with nhwc = true and reads
// the channel count from dims[3].
auto input = g->addTensor({batchSize, 32, 32, 1}); // NHWC
// auto input = g->addTensor({16, 32, 32, 1}); // change to NHWC format
vector<tuple<string, int, int, bool>> configs = {
{"Conv", 56, 5, true}, {"Conv", 12, 1, true},
{"Conv", 12, 3, false}, {"Conv", 12, 3, false},
{"Conv", 12, 3, false}, {"Conv", 12, 3, true},
{"Conv", 56, 1, true}, {"ConvTranposed", 1, 9, false}};
auto x = input;
for (auto &[op, f, r, pRelu] : configs) {
if (r == 5 && op == "Conv") { // for the first conv
// Weight is (f, c, r, r); padding r / 2 with unit stride and dilation.
auto w = g->addTensor({f, x->getDims()[3], r, r});
x = g->addOp<ConvObj>(x, w, nullptr, r / 2, r / 2, 1, 1, 1, 1,
nullptr, ActType::None, true)
->getOutput();
// if (pRelu) {
// // TODO: Conv_nhwc + Bias+PRelu
// // Alternative: Conv_nchw + Transpose(NCHW->NHWC)+Bias+PRelu
// x = g->addOp<ReluObj>(x, nullptr)->getOutput();
// }
auto bias = g->addTensor({x->getDims()[3]});
x = g->addOp<BiasPReLU>(x, bias, nullptr, pRelu, 0.1)->getOutput();
continue;
}
auto idim = x->getDims();
int n = idim[0], h = idim[1], w = idim[2], c = idim[3];
// Flatten to (1, n*h*w, c) so the channel contraction becomes a matmul
// against a (1, c, f*r*r) weight.
x = g->addOp<ReshapeObj>(x, nullptr, Shape{1, n * h * w, c})
->getOutput();
auto weight = g->addTensor({1, x->getDims()[2], f * r * r});
x = g->addOp<MatmulObj>(x, weight, nullptr)->getOutput();
// Restore the (n, h, w, f, r, r) layout expected by the reduce operators.
x = g->addOp<ReshapeObj>(x, nullptr, Shape{n, h, w, f, r, r})
->getOutput();
auto bias = g->addTensor({f});
if (op == "Conv") {
x = g->addOp<Conv2dReduce>(x, bias, nullptr, pRelu, 0.1, r / 2,
r / 2)
->getOutput();
} else if (op == "ConvTranposed") {
IT_ASSERT(r == 9);
// x = g->addOp<ConvTransposed2dObj>(x, w, nullptr, 3, 3, 4, 4, 1,
// 1,
// 1, 1)
// ->getOutput();
// Padding 3 and stride 4 in both directions.
x = g->addOp<Conv2dReduceTranspose>(x, bias, nullptr, pRelu, 0.1, 3,
3, 4, 4)
->getOutput();
} else
IT_ASSERT(false);
}
g->print();
g->dataMalloc();
cuda->run(g, true);
cuda->getPerfTime(g, true);
return g;
};
// Times a direct (non-graph) execution of the FSRCNN graph for batch sizes 1
// and 16. The timed body enqueues the graph without synchronizing; the second
// callable passed to timeit performs the sync.
TEST(Case, fsrcnn_direct_run) {
auto cuda = make_ref<CudaRuntimeObj>();
for (int batch : {1, 16}) {
auto g = createGraph(cuda, batch);
// Bracket only the timed section for external profilers.
cudaProfilerStart();
printf("E2E time (batch size %d) %.3lf\n", batch,
timeit([&]() { cuda->runWithoutSync(g); },
[&]() { cuda->sync(); }));
cudaProfilerStop();
}
};
// Captures one execution of the FSRCNN graph into a CUDA graph on the
// runtime's stream and times replaying it, for batch sizes 1 and 16.
TEST(Case, fsrcnn_cuda_graph) {
    auto cuda = make_ref<CudaRuntimeObj>();
    for (int batch : {1, 16}) {
        auto g = createGraph(cuda, batch);

        cudaGraph_t graph;
        cudaGraphExec_t instance;

        // Make sure no earlier work is still pending before capture starts.
        checkCudaError(cudaDeviceSynchronize());
        cudaStream_t stream = cuda->getStream();
        // cudaStreamCaptureStatus log;
        // checkCudaError(cudaStreamIsCapturing(stream, &log));
        // printf("cudaStreamCaptureStatus %d\n", log);
        checkCudaError(
            cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
        cuda->runWithoutSync(g);
        checkCudaError(cudaStreamEndCapture(stream, &graph));
        checkCudaError(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));

        cudaProfilerStart();
        printf(
            "CUDA graph time (batch size %d): %.3lf ms\n", batch,
            timeit([&]() { checkCudaError(cudaGraphLaunch(instance, stream)); },
                   [&]() { checkCudaError(cudaStreamSynchronize(stream)); }));
        cudaProfilerStop();

        // Release the captured graph and its executable instance; otherwise
        // both are leaked on every loop iteration.
        checkCudaError(cudaGraphExecDestroy(instance));
        checkCudaError(cudaGraphDestroy(graph));
    }
}
} // namespace infini

View File

@ -0,0 +1,90 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/conv2dreduce.h"
#include "operators/matmul.h"
#include "operators/reshape.h"
#include "test.h"
namespace infini {
// Shared driver for the Conv2dReduce tests.
//   generator : fills the input, weight and bias tensors on the CPU side.
//   ansVec    : expected flattened output values.
//   mode      : "conv" selects Conv2dReduce; any other value selects
//               Conv2dReduceTranspose (stride 2).
// The CUDA graph is Reshape -> Matmul -> Reshape -> reduce op; its output is
// cloned back to the CPU graph and compared against ansVec.
void testConv2dReduce(
const std::function<void(void *, size_t, DataType)> &generator,
vector<float> ansVec, string mode) {
auto cuda = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cuda);
Runtime cpu = CpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(cpu);
// Fixed problem size: one 4x4 image, 3 input channels, 1 filter, 3x3 kernel.
const int n = 1, h = 4, w = 4, c = 3, f = 1, r = 3;
Tensor i0Cpu = gCpu->addTensor({n, h, w, c});
Tensor w0Cpu = gCpu->addTensor({1, c, f * r * r});
Tensor b0Cpu = gCpu->addTensor({f});
gCpu->dataMalloc();
i0Cpu->setData(generator);
w0Cpu->setData(generator);
b0Cpu->setData(generator);
// Mirror the initialized host tensors into the CUDA graph.
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
Tensor b0Cuda = gCuda->cloneTensor(b0Cpu);
auto x = gCuda->addOp<ReshapeObj>(i0Cuda, nullptr, Shape{1, n * h * w, c})
->getOutput();
x = gCuda->addOp<MatmulObj>(x, w0Cuda, nullptr)->getOutput();
x = gCuda->addOp<ReshapeObj>(x, nullptr, Shape{n, h, w, f, r, r})
->getOutput();
if (mode == "conv") {
// Padding r / 2 with default stride and dilation; PReLU disabled.
x = gCuda
->addOp<Conv2dReduce>(x, b0Cuda, nullptr, false, 0.1, r / 2,
r / 2)
->getOutput();
} else {
// Padding r / 2, stride 2; PReLU disabled.
x = gCuda
->addOp<Conv2dReduceTranspose>(x, b0Cuda, nullptr, false, 0.1,
r / 2, r / 2, 2, 2)
->getOutput();
}
gCuda->dataMalloc();
cuda->run(gCuda, false);
auto o0Cpu = gCpu->cloneTensor(x);
// o0Cpu->printData();
EXPECT_TRUE(o0Cpu->equalData(ansVec));
}
// Checks both reduce operators against precomputed answers, using an all-ones
// generator and an incremental generator. The "conv" cases list 16 expected
// values and the "convt" cases list 49.
TEST(Case, conv2dreduce) {
testConv2dReduce(OneGenerator(),
vector<float>{13, 19, 19, 13, 19, 28, 28, 19, 19, 28, 28,
19, 13, 19, 19, 13},
"conv");
testConv2dReduce(IncrementalGenerator(),
vector<float>{1719, 2916, 3699, 2625, 4077, 6480, 7533,
5166, 6993, 10692, 11745, 7866, 4869, 7344,
7965, 5271},
"conv");
testConv2dReduce(OneGenerator(),
vector<float>{4., 7., 4., 7., 4., 7., 4., 7., 13., 7.,
13., 7., 13., 7., 4., 7., 4., 7., 4., 7.,
4., 7., 13., 7., 13., 7., 13., 7., 4., 7.,
4., 7., 4., 7., 4., 7., 13., 7., 13., 7.,
13., 7., 4., 7., 4., 7., 4., 7., 4.},
"convt");
testConv2dReduce(IncrementalGenerator(),
vector<float>{57, 222, 174, 456, 291, 690, 408,
474, 1164, 708, 1632, 942, 2100, 1176,
525, 1158, 642, 1392, 759, 1626, 876,
1410, 3036, 1644, 3504, 1878, 3972, 2112,
993, 2094, 1110, 2328, 1227, 2562, 1344,
2346, 4908, 2580, 5376, 2814, 5844, 3048,
1461, 3030, 1578, 3264, 1695, 3498, 1812},
"convt");
}
} // namespace infini