forked from jiuyuan/InfiniTensor
Compare commits: master...case-fsrcn (8 commits)
Author | SHA1 | Date
---|---|---
Liyan Zheng | bef4c422a0 |
Liyan Zheng | 67c06733e6 |
Liyan Zheng | aa552b5bd2 |
huangshuhong | 36755c3160 |
Liyan Zheng | 74e998e262 |
Liyan Zheng | 7abe7da0e4 |
huangshuhong | 133513be34 |
Liyan Zheng | 78425c3209 |
@@ -40,6 +40,10 @@ enum class OpType {
    Resize,
    //
    MemBound = 300,
    //
    Conv2dReduce = 400,
    ConvTranspose2dReduce,
    BiasPReLU
};

using KernelAttrs = std::tuple<Device, OpType, DataType>;
@@ -86,6 +90,9 @@ class OpRegistry {
        FOP(Abs);
        //
        FOP(MemBound);
        FOP(Conv2dReduce);
        FOP(ConvTranspose2dReduce);
        FOP(BiasPReLU);
    default:
        IT_ASSERT(false);
        break;
@@ -0,0 +1,9 @@
#pragma once
#include "operators/bias2prelu.h"

namespace infini {

void bias2prelu_kernel(float *input, float *bias, float *output, int n, int h,
                       int w, int c, bool PReLU, float paramPReLU);

}
@@ -0,0 +1,16 @@
#pragma once
#include "operators/conv2dreduce.h"

namespace infini {

void conv2dreduce_kernel(float *input, float *bias, float *output, bool PReLU,
                         float paramReLU, int n, int h, int w, int f, int r,
                         int s, int oh, int ow, int ph, int pw, int sh, int sw,
                         int dh, int dw);

void convTranspose2dreduce_kernel(float *input, float *bias, float *output,
                                  bool PReLU, float paramReLU, int n, int h,
                                  int w, int f, int r, int s, int oh, int ow,
                                  int ph, int pw, int sh, int sw, int dh,
                                  int dw);
} // namespace infini
@@ -10,13 +10,14 @@ class CudaRuntimeObj : public RuntimeObj {
    cublasHandle_t cublas;
    CudaPtr workspace;
    size_t workspaceSize;
    cudaStream_t stream;

  public:
    CUdevice cuDevice;
    CUcontext newContext;

  public:
-   CudaRuntimeObj() : RuntimeObj(Device::CUDA) {
+   CudaRuntimeObj() : RuntimeObj(Device::CUDA), stream(cudaStreamPerThread) {
        // Prepare for nvrtc. cuCtxCreate should be called before others.
        // Otherwise it will result in strange failures, such as cuBLAS
        // failing on certain inputs.
@@ -26,6 +27,8 @@ class CudaRuntimeObj : public RuntimeObj {

        checkCudnnError(cudnnCreate(&cudnn));
        checkCublasError(cublasCreate(&cublas));
        checkCublasError(cublasSetStream(cublas, stream));
        checkCudnnError(cudnnSetStream(cudnn, stream));
        // 10GB for Longformer
        // size_t longformerNum = 3lu * (1 << 30);
        workspaceSize = 7ll << 30; // 7 GB
@@ -53,6 +56,7 @@ class CudaRuntimeObj : public RuntimeObj {
    cudnnHandle_t cudnnHandle() const { return cudnn; }
    cublasHandle_t cublasHandle() const { return cublas; }
    size_t getWorkspaceSize() const { return workspaceSize; }
    cudaStream_t getStream() const { return stream; }
    CudaPtr getWorkspace(size_t size) const {
        IT_ASSERT(size <= workspaceSize);
        return workspace;
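Binding cuBLAS and cuDNN to the runtime's single stream (cudaStreamPerThread above) is what makes every launch capturable: the fsrcnn_cuda_graph test at the end of this diff records one full forward pass into a CUDA graph and replays it. A minimal sketch of that capture/replay pattern, assuming the getStream()/runWithoutSync() API introduced in this branch and omitting error checks:

```cpp
// Sketch only: record all work issued on the runtime's stream into a CUDA
// graph, then replay it with a single launch. Assumes `cuda` is a
// CudaRuntimeObj and `g` an already-built Graph, as in the test below.
cudaStream_t stream = cuda->getStream(); // cudaStreamPerThread
cudaGraph_t graph;
cudaGraphExec_t instance;
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
cuda->runWithoutSync(g); // every kernel / cuBLAS / cuDNN call lands on `stream`
cudaStreamEndCapture(stream, &graph);
cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);
cudaGraphLaunch(instance, stream); // replay with minimal launch overhead
cudaStreamSynchronize(stream);
```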
@@ -0,0 +1,32 @@
#pragma once
#include "core/operator.h"

namespace infini {

class BiasPReLU : public OperatorObj {
  protected:
    bool PReLU;
    float paramPReLU;

    int n, h, w, c;

  public:
    BiasPReLU(GraphObj *graph, Tensor input, Tensor bias, Tensor output,
              bool PReLU_, float paramPReLU_);

    std::string toString() const override { return "Bias2PReLU"; }
    int numInputs() const override { return 2; }
    int numOutputs() const override { return 1; }

    bool getPReLU() const { return PReLU; }
    float getParamReLU() const { return paramPReLU; }
    Tensor getBias() const { return inputs[1]; }

    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};

} // namespace infini
@@ -67,19 +67,23 @@ class ConvBaseObj : public OperatorObj {
class ConvObj : public ConvBaseObj {
  private:
    ActType act;
    bool NHWC_layout;

  public:
    ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph,
            int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1,
-           Tensor bias = nullptr, ActType act = ActType::None);
+           Tensor bias = nullptr, ActType act = ActType::None,
+           bool nhwc = false);
    // Constructors for setting padding mode
    ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
            PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1,
            int dh = 1, int dw = 1, Tensor bias = nullptr,
-           ActType act = ActType::None);
+           ActType act = ActType::None, bool nhwc = false);

    std::string toString() const override;
    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
    ActType getAct() const { return act; }
    bool getNHWCLayout() const { return NHWC_layout; }
    int getNumGroups() const override { return c / getChannelPerGroup(); }

  private:
@@ -0,0 +1,59 @@
#pragma once
#include "core/operator.h"

namespace infini {

class Conv2dReduceBase : public OperatorObj {
  protected:
    int ph, pw;
    int sh, sw;
    int dh, dw;
    int n, h, w, f, r, s; // c has been reduced
    bool PReLU;
    float paramReLU;

  public:
    Conv2dReduceBase(OpType opType, Tensor input, Tensor bias, Tensor output,
                     bool PReLU_, float paramReLU_, int ph_, int pw_,
                     int sh_ = 1, int sw_ = 1, int dh_ = 1, int dw_ = 1);

    std::string toString() const override;
    int numInputs() const override { return 2; }
    int numOutputs() const override { return 1; }

    int getDh() const { return dh; }
    int getDw() const { return dw; }
    int getPh() const { return ph; }
    int getPw() const { return pw; }
    int getSh() const { return sh; }
    int getSw() const { return sw; }
    bool getPReLU() const { return PReLU; }
    float getParamReLU() const { return paramReLU; }

    Tensor getBias() const { return inputs[1]; }

    // optional<vector<Shape>> inferShape(const TensorVec &inputs) const
    // override;

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};

class Conv2dReduce : public Conv2dReduceBase {
  public:
    Conv2dReduce(GraphObj *graph, Tensor input, Tensor bias, Tensor output,
                 bool PReLU_, float paramReLU_, int ph_, int pw_, int sh_ = 1,
                 int sw_ = 1, int dh_ = 1, int dw_ = 1);
    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
};

class Conv2dReduceTranspose : public Conv2dReduceBase {
  public:
    Conv2dReduceTranspose(GraphObj *graph, Tensor input, Tensor bias,
                          Tensor output, bool PReLU_, float paramReLU_, int ph_,
                          int pw_, int sh_ = 1, int sw_ = 1, int dh_ = 1,
                          int dw_ = 1);
    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
};
} // namespace infini
@@ -101,6 +101,7 @@ void RuntimeObj::printProfilingData(double totalTime,
               OpRegistry::getOpName(type).data(), opCnt.at(type), t,
               t / totalTime * 100, t / opCnt.at(type));
    }
    printf("Total_perf_time: %.3lf ms\n", totalTime);
}

Blob RuntimeObj::allocBlob(size_t size) {
@@ -0,0 +1,24 @@
#include "operators/bias2prelu.h"
#include "cuda/cuda_bias2prelu.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"

namespace infini {

class Bias2PReluCuda : public CudaKernelWithoutConfig {
    void compute(const Operator &_op, const RuntimeObj *_context) const {
        auto op = as<BiasPReLU>(_op);
        float *const input = (op->getInputs(0)->getRawDataPtr<float *>());
        float *const bias = (op->getInputs(1)->getRawDataPtr<float *>());
        float *const output = (op->getOutput()->getRawDataPtr<float *>());
        auto dim = op->getInputs(0)->getDims();
        int n = dim[0], h = dim[1], w = dim[2], c = dim[3];

        bias2prelu_kernel(input, bias, output, n, h, w, c, op->getPReLU(),
                          op->getParamReLU());
    }
};

REGISTER_KERNEL(Device::CUDA, OpType::BiasPReLU, DataType::Float32,
                Bias2PReluCuda, "Bias2PReLU_CUDA_Float32");
} // namespace infini
@@ -0,0 +1,34 @@
#include "cuda/cuda_common.h"

__global__ void bias2prelu_kernel_(
    float *__restrict__ input,
    float *__restrict__ bias,
    float *__restrict__ output,
    const bool PReLU, const float paramReLU,
    const int n, const int h, const int w, const int c)
{
    int nid = blockIdx.x, hid = blockIdx.y;
    int wid = threadIdx.x, cid = threadIdx.y;

    int offset = nid * h * w * c + hid * w * c + wid * c + cid;
    float imm = bias[cid] + input[offset];
    if (PReLU) {
        imm = (imm > 0.0) ? imm : paramReLU * imm;
    }
    output[offset] = imm;
}

namespace infini {

void bias2prelu_kernel(float *input, float *bias, float *output,
                       int n, int h, int w, int c, bool PReLU, float paramPReLU)
{
    dim3 grid(n, h);
    dim3 block(w, c);
    cudaStream_t stream(cudaStreamPerThread);
    bias2prelu_kernel_<<<grid, block, 0, stream>>>(input, bias, output,
                                                   PReLU, paramPReLU, n, h, w, c);

}

}
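The kernel above assigns one thread per element with grid = (n, h) and block = (w, c), which implicitly assumes w * c stays within the 1024-threads-per-block limit. As a plain-CPU sketch of what it computes on an NHWC tensor (an illustrative helper, not code from the diff):

```cpp
#include <cstddef>

// CPU reference for bias + optional PReLU on an NHWC tensor (sketch only).
void bias2prelu_reference(const float *input, const float *bias, float *output,
                          int n, int h, int w, int c, bool prelu, float slope) {
    const size_t total = static_cast<size_t>(n) * h * w * c;
    for (size_t i = 0; i < total; ++i) {
        float v = input[i] + bias[i % c]; // innermost dimension is the channel
        if (prelu)
            v = v > 0.0f ? v : slope * v; // leaky/PReLU with a single slope
        output[i] = v;
    }
}
```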
@@ -65,6 +65,9 @@ class convCudnn : public Kernel {
        const int cpg = op->getChannelPerGroup();
        const int g = c / cpg;
        const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
        const auto nhwcLayout = op->getNHWCLayout();
        const auto inoutLayout =
            nhwcLayout ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW;

        int channelsPerGrp = cpg, channels = c;
@@ -72,7 +75,7 @@ class convCudnn : public Kernel {
        cudnnTensorDescriptor_t inDesc;
        checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
        checkCudnnError(cudnnSetTensor4dDescriptor(
-           inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, channels, h, w));
+           inDesc, inoutLayout, CUDNN_DATA_FLOAT, n, channels, h, w));

        // get kernels
        cudnnFilterDescriptor_t knDesc;
@@ -125,12 +128,17 @@ class convCudnn : public Kernel {
            convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
        cudnnTensorDescriptor_t outDesc;
        checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
-       checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
-                                                  CUDNN_DATA_FLOAT, outn, outc,
-                                                  outh, outw));
-       IT_ASSERT((vector{outn, outc, outh, outw}) ==
-                     op->getOutput()->getDims(),
-                 "cuDNN output shape mismatches with OP output shape");
+       checkCudnnError(cudnnSetTensor4dDescriptor(
+           outDesc, inoutLayout, CUDNN_DATA_FLOAT, outn, outc, outh, outw));
+       if (nhwcLayout) {
+           IT_ASSERT((vector{outn, outh, outw, outc}) ==
+                         op->getOutput()->getDims(),
+                     "cuDNN output shape mismatches with OP output shape");
+       } else {
+           IT_ASSERT((vector{outn, outc, outh, outw}) ==
+                         op->getOutput()->getDims(),
+                     "cuDNN output shape mismatches with OP output shape");
+       }

        return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
                     convDesc, actDesc, outDesc);
@@ -0,0 +1,46 @@
#include "operators/conv2dreduce.h"
#include "cuda/cuda_conv2dreduce.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"

namespace infini {

class Conv2dReduceCuda : public CudaKernelWithoutConfig {
    void compute(const Operator &_op, const RuntimeObj *_context) const {
        auto op = as<Conv2dReduceBase>(_op);
        float *const input = (op->getInputs(0)->getRawDataPtr<float *>());
        float *const bias = op->getInputs(1)
                                ? (op->getInputs(1)->getRawDataPtr<float *>())
                                : nullptr;
        float *const output = (op->getOutput()->getRawDataPtr<float *>());

        auto dim = op->getInputs(0)->getDims();
        int n = dim[0], h = dim[1], w = dim[2], f = dim[3], r = dim[4],
            s = dim[5];
        int dh = op->getDh(), dw = op->getDw();
        int sh = op->getSh(), sw = op->getSw();
        int ph = op->getPh(), pw = op->getPw();
        auto odim = op->getOutput()->getDims();
        int oh = odim[1], ow = odim[2];
        bool PReLU = op->getPReLU();
        float paramReLU = op->getParamReLU();

        auto opType = op->getOpType();

        if (opType == OpType::Conv2dReduce) {
            conv2dreduce_kernel(input, bias, output, PReLU, paramReLU, n, h, w,
                                f, r, s, oh, ow, ph, pw, sh, sw, dh, dw);
        } else {
            convTranspose2dreduce_kernel(input, bias, output, PReLU, paramReLU,
                                         n, h, w, f, r, s, oh, ow, ph, pw, sh,
                                         sw, dh, dw);
        }
    }
};

REGISTER_KERNEL(Device::CUDA, OpType::Conv2dReduce, DataType::Float32,
                Conv2dReduceCuda, "Conv2dReduce_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::ConvTranspose2dReduce, DataType::Float32,
                Conv2dReduceCuda, "ConvTranspose2dReduce_CUDA_Float32");

} // namespace infini
@@ -0,0 +1,121 @@
#include "cuda/cuda_common.h"

__global__ void
conv2dreduce_kernel_(float *__restrict__ input, float *__restrict__ bias,
                     float *__restrict__ output, const bool PReLU,
                     const float paramReLU, const int n, const int f,
                     const int h, const int w, const int oh, const int ow,
                     const int r, const int s, const int ph, const int pw,
                     const int dh, const int dw, const int sh, const int sw) {
    // output shape: (n, oh, ow, f)
    // input shape: (n, h, w, f, r, s)
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    const int out_N_offset = h * w * f, out_H_offset = w * f, out_W_offset = f,
              out_F_offset = 1;
    const int num = out_N_offset * n;
    if (tid < num) {
        // output index
        int tmptid = tid;
        const int nid = tmptid / out_N_offset;
        tmptid -= nid * out_N_offset;
        const int hid = tmptid / out_H_offset;
        tmptid -= hid * out_H_offset;
        const int wid = tmptid / out_W_offset;
        tmptid -= wid * out_W_offset;
        const int fid = tmptid / out_F_offset;

        // Input index
        const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk,
                  nchunck = n * hchunk;
        float *__restrict__ nfinput = input + nid * nchunck + fid * fchunck;
        float imm = 0.0;
        const int ihst = hid * sh, iwst = wid * sw;
        for (int ri = 0; ri < r; ++ri) {
            for (int si = 0; si < s; ++si) {
                int ihid = ihst + (ri - r / 2) * dh;
                int iwid = iwst + (si - s / 2) * dw;
                if (ihid >= 0 && ihid < h && iwid >= 0 && iwid < w) {
                    imm += *(nfinput + ihid * hchunk + iwid * wchunk + ri * s +
                             si);
                }
            }
        }
        if (bias) {
            imm += bias[fid];
        }
        if (PReLU) {
            imm = imm > 0.0 ? imm : paramReLU * imm;
        }
        output[tid] = imm;
    }
}

__global__ void convTranspose2dreduce_kernel_(
    float *__restrict__ input, float *__restrict__ bias,
    float *__restrict__ output, const bool PReLU, const float paramReLU,
    const int n, const int f, const int h, const int w, const int oh,
    const int ow, const int r, const int s, const int ph, const int pw,
    const int dh, const int dw, const int sh, const int sw) {
    // assert dh = dw = 1
    int nid = blockIdx.x, fid = blockIdx.y;
    int hid = threadIdx.x, wid = threadIdx.y;
    const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk,
              nchunck = n * hchunk;
    float *nfinput = input + nid * nchunck + fid * fchunck;
    // view as conv, the true ph and pw
    int tph = r - ph - 1, tpw = s - pw - 1;
    int th = (h - 1) * sh + 1, tw = (w - 1) * sw + 1;
    if (nid < n && fid < f && hid < oh && wid < ow) {
        float imm = 0.0;
        int ihst = hid - tph;
        int iwst = wid - tpw;
        for (int ri = 0; ri < r; ++ri) {
            for (int si = 0; si < s; ++si) {
                int ihid = ihst + r - ri - 1;
                int iwid = iwst + s - si - 1;
                if (ihid >= 0 && ihid < th && iwid >= 0 && iwid < tw &&
                    (ihid % sh == 0) && (iwid % sw == 0)) {
                    imm += *(nfinput + (ihid / sh) * hchunk +
                             (iwid / sw) * wchunk + ri * s + si);
                }
            }
        }
        if (bias) {
            imm += bias[fid];
        }
        if (PReLU) {
            imm = imm > 0.0 ? imm : paramReLU * imm;
        }
        output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm;
    }
}

namespace infini {

void conv2dreduce_kernel(float *input, float *bias, float *output, bool PReLU,
                         float paramReLU, int n, int h, int w, int f, int r,
                         int s, int oh, int ow, int ph, int pw, int sh, int sw,
                         int dh, int dw) {
    IT_ASSERT(sh == 1 && sw == 1, "conv2dreduce_kernel only support sh=sw=1");
    const int blocksize = 512;
    const int gridsize = (n * f * oh * ow + blocksize - 1) / blocksize;

    cudaStream_t stream(cudaStreamPerThread);
    conv2dreduce_kernel_<<<gridsize, blocksize, 0, stream>>>(
        input, bias, output, PReLU, paramReLU, n, f, h, w, oh, ow, r, s, ph, pw,
        dh, dw, sh, sw);
}

void convTranspose2dreduce_kernel(float *input, float *bias, float *output,
                                  bool PReLU, float paramReLU, int n, int h,
                                  int w, int f, int r, int s, int oh, int ow,
                                  int ph, int pw, int sh, int sw, int dh,
                                  int dw) {
    dim3 grid(n, f);
    dim3 block(oh, ow);
    cudaStream_t stream(cudaStreamPerThread);
    convTranspose2dreduce_kernel_<<<grid, block, 0, stream>>>(
        input, bias, output, PReLU, paramReLU, n, f, h, w, oh, ow, r, s, ph, pw,
        dh, dw, sh, sw);
}
} // namespace infini
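The forward kernel above reduces the (n, h, w, f, r, s) tensor produced by the preceding matmul: each output pixel gathers the r·s partial products whose sliding window covers it, adds the bias and applies the leaky PReLU. A CPU sketch of the same computation, restricted (like the launcher's IT_ASSERT) to sh = sw = 1 and using the r/2, s/2 window centring the kernel hard-codes; this is an illustrative helper, not code from the diff:

```cpp
#include <cstddef>
#include <vector>

// CPU reference for conv2dreduce (sh = sw = 1). Input is laid out as
// (n, h, w, f, r, s): every input pixel already holds its f*r*s partial
// products; output is (n, h, w, f).
std::vector<float> conv2dreduce_reference(const std::vector<float> &in,
                                          const std::vector<float> &bias,
                                          int n, int h, int w, int f, int r,
                                          int s, int dh, int dw, bool prelu,
                                          float slope) {
    const int oh = h, ow = w; // stride 1 with "same"-style padding
    std::vector<float> out(static_cast<size_t>(n) * oh * ow * f, 0.0f);
    for (int ni = 0; ni < n; ++ni)
        for (int ho = 0; ho < oh; ++ho)
            for (int wo = 0; wo < ow; ++wo)
                for (int fi = 0; fi < f; ++fi) {
                    float acc = bias.empty() ? 0.0f : bias[fi];
                    for (int ri = 0; ri < r; ++ri)
                        for (int si = 0; si < s; ++si) {
                            int hi = ho + (ri - r / 2) * dh;
                            int wi = wo + (si - s / 2) * dw;
                            if (hi < 0 || hi >= h || wi < 0 || wi >= w)
                                continue;
                            size_t idx =
                                ((((static_cast<size_t>(ni) * h + hi) * w + wi) *
                                      f + fi) * r + ri) * s + si;
                            acc += in[idx];
                        }
                    if (prelu)
                        acc = acc > 0.0f ? acc : slope * acc;
                    out[((static_cast<size_t>(ni) * oh + ho) * ow + wo) * f +
                        fi] = acc;
                }
    return out;
}
```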
@@ -6,8 +6,9 @@ class CopyCuda : public CudaKernelWithoutConfig {
                 const RuntimeObj *_context) const override {
        auto inData = op->getInputs(0)->getRawDataPtr<void *>();
        auto outData = op->getOutputs()[0]->getRawDataPtr<void *>();
        auto cuda = dynamic_cast<const CudaRuntimeObj *>(_context);
        cudaMemcpyAsync(outData, inData, op->getInputs(0)->getBytes(),
-                       cudaMemcpyDeviceToDevice);
+                       cudaMemcpyDeviceToDevice, cuda->getStream());
    }
};
// reshape/flatten/identity all act as copying from input to output.
@@ -0,0 +1,37 @@
#include "operators/bias2prelu.h"

namespace infini {

BiasPReLU::BiasPReLU(GraphObj *graph, Tensor input, Tensor bias, Tensor output,
                     bool PReLU_, float paramPReLU_)
    : OperatorObj(OpType::BiasPReLU, {input, bias}, {output}), PReLU(PReLU_),
      paramPReLU(paramPReLU_) {
    auto dims = input->getDims();
    n = dims[0], h = dims[1], w = dims[2], c = dims[3];

    IT_ASSERT(checkValid(graph));
}

optional<vector<Shape>> BiasPReLU::inferShape(const TensorVec &inputs) const {
    const Tensor &input = inputs[0];
    const Tensor &bias = inputs[1];

    auto dims = input->getDims();
    int n = dims[0], h = dims[1], w = dims[2], c = dims[3];
    int bc = bias->getDims()[0];

    if (bc != c)
        return {};

    return {{{n, h, w, c}}};
}

vector<int> BiasPReLU::getWorkloadVector() const {
    return {enum_to_underlying(type), n, h, w, c};
}

vector<int> BiasPReLU::getOpAttrVector() const {
    return {enum_to_underlying(type)};
}

} // namespace infini
@@ -45,11 +45,25 @@ vector<int> ConvBaseObj::getOpAttrVector() const {
    return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw};
}

std::string ConvObj::toString() const {
    std::string origin = ConvBaseObj::toString();
    std::ostringstream os;
    os << (NHWC_layout ? "NHWC" : "NCHW") << origin;
    return os.str();
}

void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
    const Tensor &input = inputs[0];
    const Tensor &weight = inputs[1];
-   n = input->getDims()[0], c = input->getDims()[1], h = input->getDims()[2],
-   w = input->getDims()[3], f = weight->getDims()[0], r = weight->getDims()[2],
+   n = input->getDims()[0];
+   if (NHWC_layout)
+       c = input->getDims()[3], h = input->getDims()[1],
+       w = input->getDims()[2];
+   else
+       c = input->getDims()[1], h = input->getDims()[2],
+       w = input->getDims()[3];
+   f = weight->getDims()[0], r = weight->getDims()[2],
    s = weight->getDims()[3];
    if (mode == PaddingMode::Same) {
        int oh = h / sh;
@@ -63,10 +77,10 @@ void ConvObj::setAuxilaryAttributes(PaddingMode mode) {

ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
                 int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias,
-                ActType act)
+                ActType act, bool nhwc)
    : ConvBaseObj(OpType::Conv, {input, weight}, output, ph, pw, sh, sw, dh, dw,
                  input, weight),
-     act(act) {
+     act(act), NHWC_layout(nhwc) {
    if (bias)
        IT_TODO_HALT();
    setAuxilaryAttributes(PaddingMode::Other);
@@ -75,10 +89,10 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,

ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
                 PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias,
-                ActType act)
+                ActType act, bool nhwc)
    : ConvBaseObj(OpType::Conv, {input, weight}, output, mode, sh, sw, dh, dw,
                  input, weight),
-     act(act) {
+     act(act), NHWC_layout(nhwc) {
    if (bias)
        IT_TODO_HALT();
    setAuxilaryAttributes(mode);
@@ -87,21 +101,33 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,

optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
    const auto &input = inputs[0], &weight = inputs[1];
-   auto n = input->getDims()[0];
-   auto h = input->getDims()[2];
-   auto w = input->getDims()[3];
+   int n, h, w, ic, wc;
+   n = input->getDims()[0];
+   if (NHWC_layout) {
+       h = input->getDims()[1];
+       w = input->getDims()[2];
+       ic = input->getDims()[3];
+   } else {
+       h = input->getDims()[2];
+       w = input->getDims()[3];
+       ic = input->getDims()[1];
+   }
+   wc = weight->getDims()[1];
    auto f = weight->getDims()[0];
    auto r = weight->getDims()[2];
    auto s = weight->getDims()[3];
    int on = n, oc = f;
    int oh = 0, ow = 0;
    // For NCHW+FCRS layout, C of input is divisible by C of weight
-   if (input->getDims()[1] % weight->getDims()[1] != 0)
+   if (ic % wc != 0)
        return {};
    // Set padding size
    if (padding == PaddingMode::Other) {
-       oh = (h - (r - sh) * dh + ph * 2) / sh;
-       ow = (w - (s - sw) * dw + pw * 2) / sw;
+       //! fix this
+       // oh = (h - (r - sh) * dh + ph * 2) / sh;
+       // ow = (w - (s - sw) * dw + pw * 2) / sw;
+       oh = (h + ph * 2 - dh * (r - 1) - 1) / sh + 1;
+       ow = (w + pw * 2 - dw * (s - 1) - 1) / sw + 1;
    } else if (padding == PaddingMode::Same) {
        oh = h / sh;
        ow = w / sw;
@@ -113,6 +139,9 @@ optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
        oh = (h - (r - sh) * dh + ph * 2) / sh;
        ow = (w - (s - sw) * dw + pw * 2) / sw;
    }
    if (NHWC_layout)
        return {{{on, oh, ow, oc}}};

    return {{{on, oc, oh, ow}}};
}
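As a quick check of the corrected PaddingMode::Other formula, using the first FSRCNN convolution in the test below (h = w = 32, r = s = 5, ph = pw = r / 2 = 2, sh = sw = dh = dw = 1): oh = (32 + 2·2 − 1·(5 − 1) − 1) / 1 + 1 = 32, so the r/2 padding preserves the spatial size and the NHWC branch returns {n, 32, 32, f}.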
@@ -0,0 +1,96 @@
#include "operators/conv2dreduce.h"

namespace infini {

Conv2dReduceBase::Conv2dReduceBase(OpType opType, Tensor input, Tensor bias,
                                   Tensor output, bool PReLU_, float paramReLU_,
                                   int ph_, int pw_, int sh_, int sw_, int dh_,
                                   int dw_)
    : OperatorObj(opType, {input, bias}, {output}), ph(ph_), pw(pw_), sh(sh_),
      sw(sw_), dh(dh_), dw(dw_), PReLU(PReLU_), paramReLU(paramReLU_) {
    // expect input shape is (n, h, w, f, r, s)
    auto inputShape = input->getDims();
    IT_ASSERT(inputShape.size() == 6);
    n = inputShape[0];
    h = inputShape[1];
    w = inputShape[2];
    f = inputShape[3];
    r = inputShape[4];
    s = inputShape[5];

    if (bias) {
        auto biasShape = bias->getDims();
        IT_ASSERT(biasShape.size() == 1);
        IT_ASSERT(biasShape[0] == f);
    }
}

std::string Conv2dReduceBase::toString() const {
    std::ostringstream os;
    os << OpRegistry::getOpName(getOpType()) << "[" << getGuid() << "]";
    os << "(";
    if (inputs.size() == 2) {
        os << vecToString(inputs[0]->getDims()) << ",";
        os << vecToString(inputs[1]->getDims()) << ",";
    }
    os << "p=[" << ph << "," << pw << "],";
    os << "s=[" << sh << "," << sw << "],";
    os << "d=[" << dh << "," << dw << "],";
    os << "PReLU=" << (PReLU ? "true" : "false") << ",";
    // os << "act=" << enum_to_underlying(act) << ",";
    os << "input=" << inputs[0]->getGuid() << ",";
    if (inputs[1]) {
        os << "bias=" << inputs[1]->getGuid() << ",";
    }
    os << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
}

std::vector<int> Conv2dReduceBase::getWorkloadVector() const {
    return {enum_to_underlying(type), n, h, w, f, r, s, ph, pw, sh, sw, dh, dw};
}

std::vector<int> Conv2dReduceBase::getOpAttrVector() const {
    return {enum_to_underlying(type), ph, pw, sh, sw, dh, dw};
}

Conv2dReduce::Conv2dReduce(GraphObj *graph, Tensor input, Tensor bias,
                           Tensor output, bool PReLU_, float paramReLU_,
                           int ph_, int pw_, int sh_, int sw_, int dh_, int dw_)
    : Conv2dReduceBase(OpType::Conv2dReduce, input, bias, output, PReLU_,
                       paramReLU_, ph_, pw_, sh_, sw_, dh_, dw_) {
    IT_ASSERT(checkValid(graph));
}

optional<vector<Shape>>
Conv2dReduce::inferShape(const TensorVec &inputs) const {
    // const auto &input = inputs[0], &bias = inputs[1];
    int on = n, of = f;
    int oh = (h + ph * 2 - dh * (r - 1) - 1) / sh + 1;
    int ow = (w + pw * 2 - dw * (s - 1) - 1) / sw + 1;

    return {{{on, oh, ow, of}}};
}

Conv2dReduceTranspose::Conv2dReduceTranspose(GraphObj *graph, Tensor input,
                                             Tensor bias, Tensor output,
                                             bool PReLU_, float paramReLU_,
                                             int ph_, int pw_, int sh_, int sw_,
                                             int dh_, int dw_)
    : Conv2dReduceBase(OpType::ConvTranspose2dReduce, input, bias, output,
                       PReLU_, paramReLU_, ph_, pw_, sh_, sw_, dh_, dw_) {
    IT_ASSERT(dh_ == 1);
    IT_ASSERT(dw_ == 1);
    IT_ASSERT(checkValid(graph));
}

optional<vector<Shape>>
Conv2dReduceTranspose::inferShape(const TensorVec &inputs) const {
    // const auto &input = inputs[0], &bias = inputs[1];
    int on = n, of = f;
    int oh = (h - 1) * sh - 2 * ph + dh * (r - 1) + 1;
    int ow = (w - 1) * sw - 2 * pw + dw * (s - 1) + 1;

    return {{{on, oh, ow, of}}};
}
} // namespace infini
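These two inferShape formulas can be checked against the conv2dreduce unit test at the end of this diff (h = w = 4, r = s = 3, ph = pw = 1): the forward op with sh = sw = dh = dw = 1 gives oh = (4 + 2 − 2 − 1) / 1 + 1 = 4, i.e. 4×4 = 16 outputs, matching the 16-element expected vectors, while the transposed op with sh = sw = 2 gives oh = (4 − 1)·2 − 2 + 2 + 1 = 7, i.e. 7×7 = 49 outputs, matching the 49-element expected vectors.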
@@ -0,0 +1,115 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/bias2prelu.h"
#include "operators/conv.h"
#include "operators/conv2dreduce.h"
#include "operators/element_wise.h"
#include "operators/matmul.h"
#include "operators/reshape.h"
#include "operators/unary.h"

#include "test.h"

namespace infini {

Graph createGraph(Ref<CudaRuntimeObj> cuda, int batchSize) {
    Graph g = make_ref<GraphObj>(cuda);
    auto input = g->addTensor({batchSize, 32, 32, 1}); // NCHW
    // auto input = g->addTensor({16, 32, 32, 1}); // change to NHWC format
    vector<tuple<string, int, int, bool>> configs = {
        {"Conv", 56, 5, true},  {"Conv", 12, 1, true},
        {"Conv", 12, 3, false}, {"Conv", 12, 3, false},
        {"Conv", 12, 3, false}, {"Conv", 12, 3, true},
        {"Conv", 56, 1, true},  {"ConvTranposed", 1, 9, false}};
    auto x = input;
    for (auto &[op, f, r, pRelu] : configs) {
        if (r == 5 && op == "Conv") { // for the first conv
            auto w = g->addTensor({f, x->getDims()[3], r, r});
            x = g->addOp<ConvObj>(x, w, nullptr, r / 2, r / 2, 1, 1, 1, 1,
                                  nullptr, ActType::None, true)
                    ->getOutput();
            // if (pRelu) {
            //     // TODO: Conv_nhwc + Bias+PRelu
            //     // Alternative: Conv_nchw + Transpose(NCHW->NHWC)+Bias+PRelu
            //     x = g->addOp<ReluObj>(x, nullptr)->getOutput();
            // }
            auto bias = g->addTensor({x->getDims()[3]});
            x = g->addOp<BiasPReLU>(x, bias, nullptr, pRelu, 0.1)->getOutput();
            continue;
        }

        auto idim = x->getDims();
        int n = idim[0], h = idim[1], w = idim[2], c = idim[3];
        x = g->addOp<ReshapeObj>(x, nullptr, Shape{1, n * h * w, c})
                ->getOutput();
        auto weight = g->addTensor({1, x->getDims()[2], f * r * r});
        x = g->addOp<MatmulObj>(x, weight, nullptr)->getOutput();
        x = g->addOp<ReshapeObj>(x, nullptr, Shape{n, h, w, f, r, r})
                ->getOutput();
        auto bias = g->addTensor({f});
        if (op == "Conv") {
            x = g->addOp<Conv2dReduce>(x, bias, nullptr, pRelu, 0.1, r / 2,
                                       r / 2)
                    ->getOutput();
        } else if (op == "ConvTranposed") {
            IT_ASSERT(r == 9);
            // x = g->addOp<ConvTransposed2dObj>(x, w, nullptr, 3, 3, 4, 4, 1,
            //                                   1, 1, 1)
            //         ->getOutput();
            x = g->addOp<Conv2dReduceTranspose>(x, bias, nullptr, pRelu, 0.1, 3,
                                                3, 4, 4)
                    ->getOutput();
        } else
            IT_ASSERT(false);
    }
    g->print();
    g->dataMalloc();
    cuda->run(g, true);
    cuda->getPerfTime(g, true);
    return g;
};

TEST(Case, fsrcnn_direct_run) {
    auto cuda = make_ref<CudaRuntimeObj>();
    for (int batch : {1, 16}) {
        auto g = createGraph(cuda, batch);
        cudaProfilerStart();
        printf("E2E time (batch size %d) %.3lf\n", batch,
               timeit([&]() { cuda->runWithoutSync(g); },
                      [&]() { cuda->sync(); }));
        cudaProfilerStop();
    }
};

TEST(Case, fsrcnn_cuda_graph) {
    auto cuda = make_ref<CudaRuntimeObj>();
    for (int batch : {1, 16}) {
        auto g = createGraph(cuda, batch);

        cudaGraph_t graph;
        cudaGraphExec_t instance;

        checkCudaError(cudaDeviceSynchronize());
        cudaStream_t stream = cuda->getStream();
        // cudaStreamCaptureStatus log;
        // checkCudaError(cudaStreamIsCapturing(stream, &log));
        // printf("cudaStreamCaptureStatus %d\n", log);
        checkCudaError(
            cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
        cuda->runWithoutSync(g);
        checkCudaError(cudaStreamEndCapture(stream, &graph));
        checkCudaError(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));

        cudaProfilerStart();
        printf(
            "CUDA graph time (batch size %d): %.3lf ms\n", batch,
            timeit([&]() { checkCudaError(cudaGraphLaunch(instance, stream)); },
                   [&]() { cudaStreamSynchronize(stream); }));
        cudaProfilerStop();
    }
};
} // namespace infini
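To make the Reshape–Matmul–Reshape–Conv2dReduce decomposition in createGraph concrete, here is the shape trace for one of the 3×3, 12-channel layers at batch size 1 (an illustration derived from the code above, not part of the diff):

```
(1, 32, 32, 12)  --Reshape-->                          (1, 1024, 12)
                 --Matmul with weight (1, 12, 108)-->  (1, 1024, 108)        // 108 = 12 * 3 * 3
                 --Reshape-->                          (1, 32, 32, 12, 3, 3)
                 --Conv2dReduce(ph = pw = 1)-->        (1, 32, 32, 12)
```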
@@ -0,0 +1,90 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/conv2dreduce.h"
#include "operators/matmul.h"
#include "operators/reshape.h"

#include "test.h"

namespace infini {

void testConv2dReduce(
    const std::function<void(void *, size_t, DataType)> &generator,
    vector<float> ansVec, string mode) {
    auto cuda = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cuda);
    Runtime cpu = CpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(cpu);

    const int n = 1, h = 4, w = 4, c = 3, f = 1, r = 3;

    Tensor i0Cpu = gCpu->addTensor({n, h, w, c});
    Tensor w0Cpu = gCpu->addTensor({1, c, f * r * r});
    Tensor b0Cpu = gCpu->addTensor({f});

    gCpu->dataMalloc();
    i0Cpu->setData(generator);
    w0Cpu->setData(generator);
    b0Cpu->setData(generator);

    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
    Tensor b0Cuda = gCuda->cloneTensor(b0Cpu);

    auto x = gCuda->addOp<ReshapeObj>(i0Cuda, nullptr, Shape{1, n * h * w, c})
                 ->getOutput();
    x = gCuda->addOp<MatmulObj>(x, w0Cuda, nullptr)->getOutput();
    x = gCuda->addOp<ReshapeObj>(x, nullptr, Shape{n, h, w, f, r, r})
            ->getOutput();
    if (mode == "conv") {
        x = gCuda
                ->addOp<Conv2dReduce>(x, b0Cuda, nullptr, false, 0.1, r / 2,
                                      r / 2)
                ->getOutput();
    } else {
        x = gCuda
                ->addOp<Conv2dReduceTranspose>(x, b0Cuda, nullptr, false, 0.1,
                                               r / 2, r / 2, 2, 2)
                ->getOutput();
    }

    gCuda->dataMalloc();
    cuda->run(gCuda, false);

    auto o0Cpu = gCpu->cloneTensor(x);
    // o0Cpu->printData();
    EXPECT_TRUE(o0Cpu->equalData(ansVec));
}

TEST(Case, conv2dreduce) {
    testConv2dReduce(OneGenerator(),
                     vector<float>{13, 19, 19, 13, 19, 28, 28, 19, 19, 28, 28,
                                   19, 13, 19, 19, 13},
                     "conv");
    testConv2dReduce(IncrementalGenerator(),
                     vector<float>{1719, 2916, 3699, 2625, 4077, 6480, 7533,
                                   5166, 6993, 10692, 11745, 7866, 4869, 7344,
                                   7965, 5271},
                     "conv");
    testConv2dReduce(OneGenerator(),
                     vector<float>{4., 7., 4., 7., 4., 7., 4., 7., 13., 7.,
                                   13., 7., 13., 7., 4., 7., 4., 7., 4., 7.,
                                   4., 7., 13., 7., 13., 7., 13., 7., 4., 7.,
                                   4., 7., 4., 7., 4., 7., 13., 7., 13., 7.,
                                   13., 7., 4., 7., 4., 7., 4., 7., 4.},
                     "convt");
    testConv2dReduce(IncrementalGenerator(),
                     vector<float>{57, 222, 174, 456, 291, 690, 408,
                                   474, 1164, 708, 1632, 942, 2100, 1176,
                                   525, 1158, 642, 1392, 759, 1626, 876,
                                   1410, 3036, 1644, 3504, 1878, 3972, 2112,
                                   993, 2094, 1110, 2328, 1227, 2562, 1344,
                                   2346, 4908, 2580, 5376, 2814, 5844, 3048,
                                   1461, 3030, 1578, 3264, 1695, 3498, 1812},
                     "convt");
}

} // namespace infini