Forked from jiuyuan/InfiniTensor

convNHWC+pReLU+biasPReLU

This commit is contained in:
parent 74e998e262
commit 36755c3160
@@ -42,7 +42,8 @@ enum class OpType {
     MemBound = 300,
     //
     Conv2dReduce = 400,
-    ConvTranspose2dReduce
+    ConvTranspose2dReduce,
+    BiasPReLU
 };

 using KernelAttrs = std::tuple<Device, OpType, DataType>;
@@ -91,6 +92,7 @@ class OpRegistry {
         FOP(MemBound);
         FOP(Conv2dReduce);
         FOP(ConvTranspose2dReduce);
+        FOP(BiasPReLU);
         default:
             IT_ASSERT(false);
             break;
@@ -0,0 +1,9 @@
+#pragma once
+#include "operators/bias2prelu.h"
+
+namespace infini {
+
+void bias2prelu_kernel(float *input, float *bias, float *output, int n, int h,
+                       int w, int c, bool PReLU, float paramPReLU);
+
+}
@@ -4,12 +4,13 @@
 namespace infini {

 void conv2dreduce_kernel(float *input, float *bias, float *output, bool PReLU,
-                         int n, int h, int w, int f, int r, int s, int oh,
-                         int ow, int ph, int pw, int sh, int sw, int dh,
-                         int dw);
+                         float paramReLU, int n, int h, int w, int f, int r,
+                         int s, int oh, int ow, int ph, int pw, int sh, int sw,
+                         int dh, int dw);

 void convTranspose2dreduce_kernel(float *input, float *bias, float *output,
-                                  bool PReLU, int n, int h, int w, int f, int r,
-                                  int s, int oh, int ow, int ph, int pw, int sh,
-                                  int sw, int dh, int dw);
+                                  bool PReLU, float paramReLU, int n, int h,
+                                  int w, int f, int r, int s, int oh, int ow,
+                                  int ph, int pw, int sh, int sw, int dh,
+                                  int dw);
 } // namespace infini
@@ -0,0 +1,32 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+
+class BiasPReLU : public OperatorObj {
+  protected:
+    bool PReLU;
+    float paramPReLU;
+
+    int n, h, w, c;
+
+  public:
+    BiasPReLU(GraphObj *graph, Tensor input, Tensor bias, Tensor output,
+              bool PReLU_, float paramPReLU_);
+
+    std::string toString() const override { return "Bias2PReLU"; }
+    int numInputs() const override { return 2; }
+    int numOutputs() const override { return 1; }
+
+    bool getPReLU() const { return PReLU; }
+    float getParamReLU() const { return paramPReLU; }
+    Tensor getBias() const { return inputs[1]; }
+
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+
+} // namespace infini
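For orientation, a minimal usage sketch of the new operator under the GraphObj API that appears elsewhere in this diff; the tensor shapes and the 0.1 slope mirror the FSRCNN test further below and are assumptions for illustration, not part of the header itself.

    // Hypothetical usage sketch (shapes assumed; API names taken from this diff)
    Graph g = make_ref<GraphObj>(cuda);
    auto x = g->addTensor({1, 32, 32, 56}); // NHWC activation
    auto b = g->addTensor({56});            // one bias value per channel
    // PReLU enabled with slope 0.1, the same constant the test graph passes
    x = g->addOp<BiasPReLU>(x, b, nullptr, /*PReLU_=*/true, /*paramPReLU_=*/0.1)
            ->getOutput();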
@@ -67,19 +67,23 @@ class ConvBaseObj : public OperatorObj {
 class ConvObj : public ConvBaseObj {
   private:
     ActType act;
+    bool NHWC_layout;

   public:
     ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph,
             int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1,
-            Tensor bias = nullptr, ActType act = ActType::None);
+            Tensor bias = nullptr, ActType act = ActType::None,
+            bool nhwc = false);
     // Constructors for setting padding mode
     ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
             PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1,
             int dh = 1, int dw = 1, Tensor bias = nullptr,
-            ActType act = ActType::None);
+            ActType act = ActType::None, bool nhwc = false);

     std::string toString() const override;
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
     ActType getAct() const { return act; }
+    bool getNHWCLayout() const { return NHWC_layout; }
     int getNumGroups() const override { return c / getChannelPerGroup(); }

   private:
@@ -10,11 +10,12 @@ class Conv2dReduceBase : public OperatorObj {
     int dh, dw;
     int n, h, w, f, r, s; // c has been reduced
     bool PReLU;
+    float paramReLU;

   public:
     Conv2dReduceBase(OpType opType, Tensor input, Tensor bias, Tensor output,
-                     bool PReLU_, int ph_, int pw_, int sh_ = 1, int sw_ = 1,
-                     int dh_ = 1, int dw_ = 1);
+                     bool PReLU_, float paramReLU_, int ph_, int pw_,
+                     int sh_ = 1, int sw_ = 1, int dh_ = 1, int dw_ = 1);

     std::string toString() const override;
     int numInputs() const override { return 2; }
@@ -27,6 +28,7 @@ class Conv2dReduceBase : public OperatorObj {
     int getSh() const { return sh; }
     int getSw() const { return sw; }
     bool getPReLU() const { return PReLU; }
+    float getParamReLU() const { return paramReLU; }

     Tensor getBias() const { return inputs[1]; }

@@ -41,16 +43,17 @@ class Conv2dReduceBase : public OperatorObj {
 class Conv2dReduce : public Conv2dReduceBase {
   public:
     Conv2dReduce(GraphObj *graph, Tensor input, Tensor bias, Tensor output,
-                 bool PReLU_, int ph_, int pw_, int sh_ = 1, int sw_ = 1,
-                 int dh_ = 1, int dw_ = 1);
+                 bool PReLU_, float paramReLU_, int ph_, int pw_, int sh_ = 1,
+                 int sw_ = 1, int dh_ = 1, int dw_ = 1);
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
 };

 class Conv2dReduceTranspose : public Conv2dReduceBase {
   public:
     Conv2dReduceTranspose(GraphObj *graph, Tensor input, Tensor bias,
-                          Tensor output, bool PReLU_, int ph_, int pw_,
-                          int sh_ = 1, int sw_ = 1, int dh_ = 1, int dw_ = 1);
+                          Tensor output, bool PReLU_, float paramReLU_, int ph_,
+                          int pw_, int sh_ = 1, int sw_ = 1, int dh_ = 1,
+                          int dw_ = 1);
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
 };
 } // namespace infini
@@ -0,0 +1,24 @@
+#include "operators/bias2prelu.h"
+#include "cuda/cuda_bias2prelu.h"
+#include "cuda/cuda_kernel_wihtout_config.h"
+#include "cuda/cuda_runtime.h"
+
+namespace infini {
+
+class Bias2PReluCuda : public CudaKernelWithoutConfig {
+    void compute(const Operator &_op, const RuntimeObj *_context) const {
+        auto op = as<BiasPReLU>(_op);
+        float *const input = (op->getInputs(0)->getRawDataPtr<float *>());
+        float *const bias = (op->getInputs(1)->getRawDataPtr<float *>());
+        float *const output = (op->getOutput()->getRawDataPtr<float *>());
+        auto dim = op->getInputs(0)->getDims();
+        int n = dim[0], h = dim[1], w = dim[2], c = dim[3];
+
+        bias2prelu_kernel(input, output, bias, n, h, w, c, op->getPReLU(),
+                          op->getParamReLU());
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::BiasPReLU, DataType::Float32,
+                Bias2PReluCuda, "Bias2PReLU_CUDA_Float32");
+} // namespace infini
@@ -0,0 +1,34 @@
+#include "cuda/cuda_common.h"
+
+__global__ void bias2prelu_kernel_(
+    float *__restrict__ input,
+    float *__restrict__ bias,
+    float *__restrict__ output,
+    const bool PReLU, const float paramReLU,
+    const int n, const int h, const int w, const int c)
+{
+    int nid = blockIdx.x, hid = blockIdx.y;
+    int wid = threadIdx.x, cid = threadIdx.y;
+
+    int offset = nid * h * w * c + hid * w * c + wid * c + cid;
+    float imm = bias[cid] + input[offset];
+    if (PReLU) {
+        imm = (imm > 0.0) ? imm : paramReLU * paramReLU;
+    }
+    output[offset] = imm;
+}
+
+namespace infini {
+
+void bias2prelu_kernel(float *input, float *bias, float *output,
+                       int n, int h, int w, int c, bool PReLU, float paramPReLU)
+{
+    dim3 grid(n, h);
+    dim3 block(w, c);
+    cudaStream_t stream(cudaStreamPerThread);
+    bias2prelu_kernel_<<<grid, block, 0, stream>>>(input, bias, output,
+                                                   PReLU, paramPReLU, n, h, w, c);
+
+}
+
+}
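For reference, the per-element work this kernel distributes over an (n, h) grid and a (w, c) thread block is equivalent to the CPU sketch below. The negative branch here uses the usual PReLU form paramPReLU * v, which is how the conv2dreduce kernels in this same commit handle it; treat this as a sketch of the intended semantics under that assumption, not a copy of the kernel above.

    // CPU reference sketch for an NHWC tensor (assumed semantics)
    void bias2prelu_reference(const float *input, const float *bias,
                              float *output, int n, int h, int w, int c,
                              bool PReLU, float paramPReLU) {
        for (int i = 0; i < n * h * w; ++i) {   // every spatial position
            for (int j = 0; j < c; ++j) {       // every channel
                float v = input[i * c + j] + bias[j];
                if (PReLU && v < 0.0f)
                    v = paramPReLU * v;         // scale negatives by the slope
                output[i * c + j] = v;
            }
        }
    }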
@@ -65,6 +65,9 @@ class convCudnn : public Kernel {
         const int cpg = op->getChannelPerGroup();
         const int g = c / cpg;
         const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
+        const auto nhwcLayout = op->getNHWCLayout();
+        const auto inoutLayout =
+            nhwcLayout ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW;

         int channelsPerGrp = cpg, channels = c;

@@ -72,7 +75,7 @@ class convCudnn : public Kernel {
         cudnnTensorDescriptor_t inDesc;
         checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
         checkCudnnError(cudnnSetTensor4dDescriptor(
-            inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, channels, h, w));
+            inDesc, inoutLayout, CUDNN_DATA_FLOAT, n, channels, h, w));

         // get kernels
         cudnnFilterDescriptor_t knDesc;
@@ -125,12 +128,17 @@ class convCudnn : public Kernel {
             convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
         cudnnTensorDescriptor_t outDesc;
         checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
-        checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
-                                                   CUDNN_DATA_FLOAT, outn, outc,
-                                                   outh, outw));
-        IT_ASSERT((vector{outn, outc, outh, outw}) ==
-                      op->getOutput()->getDims(),
-                  "cuDNN output shape mismatches with OP output shape");
+        checkCudnnError(cudnnSetTensor4dDescriptor(
+            outDesc, inoutLayout, CUDNN_DATA_FLOAT, outn, outc, outh, outw));
+        if (nhwcLayout) {
+            IT_ASSERT((vector{outn, outh, outw, outc}) ==
+                          op->getOutput()->getDims(),
+                      "cuDNN output shape mismatches with OP output shape");
+        } else {
+            IT_ASSERT((vector{outn, outc, outh, outw}) ==
+                          op->getOutput()->getDims(),
+                      "cuDNN output shape mismatches with OP output shape");
+        }

         return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
                      convDesc, actDesc, outDesc);
@@ -23,15 +23,17 @@ class Conv2dReduceCuda : public CudaKernelWithoutConfig {
         auto odim = op->getOutput()->getDims();
         int oh = odim[1], ow = odim[2];
         bool PReLU = op->getPReLU();
+        float paramReLU = op->getParamReLU();

         auto opType = op->getOpType();

         if (opType == OpType::Conv2dReduce) {
-            conv2dreduce_kernel(input, bias, output, PReLU, n, h, w, f, r, s,
-                                oh, ow, ph, pw, sh, sw, dh, dw);
+            conv2dreduce_kernel(input, bias, output, PReLU, paramReLU, n, h, w,
+                                f, r, s, oh, ow, ph, pw, sh, sw, dh, dw);
         } else {
-            convTranspose2dreduce_kernel(input, bias, output, PReLU, n, h, w, f,
-                                         r, s, oh, ow, ph, pw, sh, sw, dh, dw);
+            convTranspose2dreduce_kernel(input, bias, output, PReLU, paramReLU,
+                                         n, h, w, f, r, s, oh, ow, ph, pw, sh,
+                                         sw, dh, dw);
         }
     }
 };
@@ -4,7 +4,7 @@ __global__ void conv2dreduce_kernel_(
     float *__restrict__ input,
     float *__restrict__ bias,
     float *__restrict__ output,
-    const bool PReLU, const int n, const int f, const int h, const int w,
+    const bool PReLU, const float paramReLU, const int n, const int f, const int h, const int w,
     const int oh, const int ow, const int r, const int s, const int ph,
     const int pw, const int dh, const int dw, const int sh, const int sw)
 {
@@ -31,7 +31,7 @@ __global__ void conv2dreduce_kernel_(
             imm += bias[fid];
         }
         if (PReLU) {
-            imm = imm > 0.0 ? imm : 0.0;
+            imm = imm > 0.0 ? imm : paramReLU * imm;
         }
         output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm;
     }
@@ -41,7 +41,7 @@ __global__ void convTranspose2dreduce_kernel_(
     float *__restrict__ input,
     float *__restrict__ bias,
     float *__restrict__ output,
-    const bool PReLU, const int n, const int f, const int h, const int w,
+    const bool PReLU, const float paramReLU, const int n, const int f, const int h, const int w,
     const int oh, const int ow, const int r, const int s, const int ph,
     const int pw, const int dh, const int dw, const int sh, const int sw)
 {
@@ -73,7 +73,7 @@ __global__ void convTranspose2dreduce_kernel_(
             imm += bias[fid];
         }
         if (PReLU) {
-            imm = imm > 0.0 ? imm : 0.0;
+            imm = imm > 0.0 ? imm : paramReLU * imm;
         }
         output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm;
     }
@@ -82,26 +82,26 @@ __global__ void convTranspose2dreduce_kernel_(
 namespace infini {

 void conv2dreduce_kernel(float *input, float *bias, float *output,
-                         bool PReLU, int n, int h, int w, int f,
+                         bool PReLU, float paramReLU, int n, int h, int w, int f,
                          int r, int s, int oh, int ow, int ph, int pw, int sh,
                          int sw, int dh, int dw)
 {
     dim3 grid(n, f);
     dim3 block(oh, ow);
     cudaStream_t stream(cudaStreamPerThread);
-    conv2dreduce_kernel_<<<grid, block, 0, stream>>>(input, bias, output, PReLU, n, f, h, w,
+    conv2dreduce_kernel_<<<grid, block, 0, stream>>>(input, bias, output, PReLU, paramReLU, n, f, h, w,
                                                      oh, ow, r, s, ph, pw, dh, dw, sh, sw);
 }

 void convTranspose2dreduce_kernel(float *input, float *bias, float *output,
-                                  bool PReLU, int n, int h, int w, int f,
+                                  bool PReLU, float paramReLU, int n, int h, int w, int f,
                                   int r, int s, int oh, int ow, int ph, int pw, int sh,
                                   int sw, int dh, int dw)
 {
     dim3 grid(n, f);
     dim3 block(oh, ow);
     cudaStream_t stream(cudaStreamPerThread);
-    convTranspose2dreduce_kernel_<<<grid, block, 0, stream>>>(input, bias, output, PReLU, n, f, h, w,
+    convTranspose2dreduce_kernel_<<<grid, block, 0, stream>>>(input, bias, output, PReLU, paramReLU, n, f, h, w,
                                                               oh, ow, r, s, ph, pw, dh, dw, sh, sw);
 }
 }
@@ -0,0 +1,37 @@
+#include "operators/bias2prelu.h"
+
+namespace infini {
+
+BiasPReLU::BiasPReLU(GraphObj *graph, Tensor input, Tensor bias, Tensor output,
+                     bool PReLU_, float paramPReLU_)
+    : OperatorObj(OpType::BiasPReLU, {input, bias}, {output}), PReLU(PReLU_),
+      paramPReLU(paramPReLU_) {
+    auto dims = input->getDims();
+    n = dims[0], h = dims[1], w = dims[2], c = dims[3];
+
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>> BiasPReLU::inferShape(const TensorVec &inputs) const {
+    const Tensor &input = inputs[0];
+    const Tensor &bias = inputs[1];
+
+    auto dims = input->getDims();
+    int n = dims[0], h = dims[1], w = dims[2], c = dims[3];
+    int bc = bias->getDims()[0];
+
+    if (bc != c)
+        return {};
+
+    return {{{n, h, w, c}}};
+}
+
+vector<int> BiasPReLU::getWorkloadVector() const {
+    return {enum_to_underlying(type), n, h, w, c};
+}
+
+vector<int> BiasPReLU::getOpAttrVector() const {
+    return {enum_to_underlying(type)};
+}
+
+} // namespace infini
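A small illustration of the shape rule implemented above (the example shapes are assumptions): the bias length must equal the channel dimension of the NHWC input, otherwise inferShape returns an empty optional and checkValid fails.

    // input {n, h, w, c} = {1, 32, 32, 56}, bias {56} -> infers {{1, 32, 32, 56}}
    // input {1, 32, 32, 56},               bias {12} -> bc != c, returns {}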
@@ -45,11 +45,25 @@ vector<int> ConvBaseObj::getOpAttrVector() const {
     return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw};
 }

+std::string ConvObj::toString() const {
+    std::string origin = ConvBaseObj::toString();
+    std::ostringstream os;
+    os << (NHWC_layout ? "NHWC" : "NCHW") << origin;
+    return os.str();
+}
+
 void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
     const Tensor &input = inputs[0];
     const Tensor &weight = inputs[1];
-    n = input->getDims()[0], c = input->getDims()[1], h = input->getDims()[2],
-    w = input->getDims()[3], f = weight->getDims()[0], r = weight->getDims()[2],
+
+    n = input->getDims()[0];
+    if (NHWC_layout)
+        c = input->getDims()[3], h = input->getDims()[1],
+        w = input->getDims()[2];
+    else
+        c = input->getDims()[1], h = input->getDims()[2],
+        w = input->getDims()[3];
+    f = weight->getDims()[0], r = weight->getDims()[2],
     s = weight->getDims()[3];
     if (mode == PaddingMode::Same) {
         int oh = h / sh;
@@ -63,10 +77,10 @@ void ConvObj::setAuxilaryAttributes(PaddingMode mode) {

 ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
                  int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias,
-                 ActType act)
+                 ActType act, bool nhwc)
     : ConvBaseObj(OpType::Conv, {input, weight}, output, ph, pw, sh, sw, dh, dw,
                   input, weight),
-      act(act) {
+      act(act), NHWC_layout(nhwc) {
     if (bias)
         IT_TODO_HALT();
     setAuxilaryAttributes(PaddingMode::Other);
@@ -75,10 +89,10 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,

 ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
                  PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias,
-                 ActType act)
+                 ActType act, bool nhwc)
     : ConvBaseObj(OpType::Conv, {input, weight}, output, mode, sh, sw, dh, dw,
                   input, weight),
-      act(act) {
+      act(act), NHWC_layout(nhwc) {
     if (bias)
         IT_TODO_HALT();
     setAuxilaryAttributes(mode);
@@ -87,16 +101,25 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,

 optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
     const auto &input = inputs[0], &weight = inputs[1];
-    auto n = input->getDims()[0];
-    auto h = input->getDims()[2];
-    auto w = input->getDims()[3];
+    int n, h, w, ic, wc;
+    n = input->getDims()[0];
+    if (NHWC_layout) {
+        h = input->getDims()[1];
+        w = input->getDims()[2];
+        ic = input->getDims()[3];
+    } else {
+        h = input->getDims()[2];
+        w = input->getDims()[3];
+        ic = input->getDims()[1];
+    }
+    wc = weight->getDims()[1];
     auto f = weight->getDims()[0];
     auto r = weight->getDims()[2];
     auto s = weight->getDims()[3];
     int on = n, oc = f;
     int oh = 0, ow = 0;
     // For NCHW+FCRS layout, C of input is divisable by C of weight
-    if (input->getDims()[1] % weight->getDims()[1] != 0)
+    if (ic % wc != 0)
         return {};
     // Set padding size
     if (padding == PaddingMode::Other) {
@@ -116,6 +139,9 @@ optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
         oh = (h - (r - sh) * dh + ph * 2) / sh;
         ow = (w - (s - sw) * dw + pw * 2) / sw;
     }
+    if (NHWC_layout)
+        return {{{on, oh, ow, oc}}};
+
     return {{{on, oc, oh, ow}}};
 }

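A quick check of the PaddingMode::Other formula above, plugging in the values the FSRCNN test later in this diff uses (32x32 input, 5x5 kernel, pad r/2, stride 1, dilation 1); the concrete numbers are an assumption chosen for illustration.

    int h = 32, r = 5, sh = 1, dh = 1, ph = 2;
    int oh = (h - (r - sh) * dh + ph * 2) / sh; // (32 - 4 + 4) / 1 = 32
    // With NHWC_layout the inferred shape is returned as {n, oh, ow, f}
    // instead of {n, f, oh, ow}.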
@@ -3,10 +3,11 @@
 namespace infini {

 Conv2dReduceBase::Conv2dReduceBase(OpType opType, Tensor input, Tensor bias,
-                                   Tensor output, bool PReLU_, int ph_, int pw_,
-                                   int sh_, int sw_, int dh_, int dw_)
+                                   Tensor output, bool PReLU_, float paramReLU_,
+                                   int ph_, int pw_, int sh_, int sw_, int dh_,
+                                   int dw_)
     : OperatorObj(opType, {input, bias}, {output}), ph(ph_), pw(pw_), sh(sh_),
-      sw(sw_), dh(dh_), dw(dw_), PReLU(PReLU_) {
+      sw(sw_), dh(dh_), dw(dw_), PReLU(PReLU_), paramReLU(paramReLU_) {
     // expect input shape is (n, h, w, f, r, s)
     auto inputShape = input->getDims();
     IT_ASSERT(inputShape.size() == 6);
@@ -54,10 +55,10 @@ std::vector<int> Conv2dReduceBase::getOpAttrVector() const {
 }

 Conv2dReduce::Conv2dReduce(GraphObj *graph, Tensor input, Tensor bias,
-                           Tensor output, bool PReLU_, int ph_, int pw_,
-                           int sh_, int sw_, int dh_, int dw_)
-    : Conv2dReduceBase(OpType::Conv2dReduce, input, bias, output, PReLU_, ph_,
-                       pw_, sh_, sw_, dh_, dw_) {
+                           Tensor output, bool PReLU_, float paramReLU_,
+                           int ph_, int pw_, int sh_, int sw_, int dh_, int dw_)
+    : Conv2dReduceBase(OpType::Conv2dReduce, input, bias, output, PReLU_,
+                       paramReLU_, ph_, pw_, sh_, sw_, dh_, dw_) {
     IT_ASSERT(checkValid(graph));
 }

@@ -73,10 +74,11 @@ Conv2dReduce::inferShape(const TensorVec &inputs) const {

 Conv2dReduceTranspose::Conv2dReduceTranspose(GraphObj *graph, Tensor input,
                                              Tensor bias, Tensor output,
-                                             bool PReLU_, int ph_, int pw_,
-                                             int sh_, int sw_, int dh_, int dw_)
+                                             bool PReLU_, float paramReLU_,
+                                             int ph_, int pw_, int sh_, int sw_,
+                                             int dh_, int dw_)
     : Conv2dReduceBase(OpType::ConvTranspose2dReduce, input, bias, output,
-                       PReLU_, ph_, pw_, sh_, sw_, dh_, dw_) {
+                       PReLU_, paramReLU_, ph_, pw_, sh_, sw_, dh_, dw_) {
     IT_ASSERT(dh_ == 1);
     IT_ASSERT(dw_ == 1);
     IT_ASSERT(checkValid(graph));
@@ -3,6 +3,7 @@
 #include "core/runtime.h"
 #include "cuda/cuda_runtime.h"
 #include "cuda/cuda_utility.h"
+#include "operators/bias2prelu.h"
 #include "operators/conv.h"
 #include "operators/conv2dreduce.h"
 #include "operators/element_wise.h"
|
|||
|
||||
Graph createGraph(Ref<CudaRuntimeObj> cuda, int batchSize) {
|
||||
Graph g = make_ref<GraphObj>(cuda);
|
||||
auto input = g->addTensor({batchSize, 1, 32, 32}); // NCHW
|
||||
auto input = g->addTensor({batchSize, 32, 32, 1}); // NCHW
|
||||
// auto input = g->addTensor({16, 32, 32, 1}); // change to NHWC format
|
||||
vector<tuple<string, int, int, bool>> configs = {
|
||||
{"Conv", 56, 5, true}, {"Conv", 12, 1, true},
|
||||
|
@@ -26,13 +27,17 @@ Graph createGraph(Ref<CudaRuntimeObj> cuda, int batchSize) {
     auto x = input;
     for (auto &[op, f, r, pRelu] : configs) {
         if (r == 5 && op == "Conv") { // for the first conv
-            auto w = g->addTensor({f, x->getDims()[1], r, r});
-            x = g->addOp<ConvObj>(x, w, nullptr, r / 2, r / 2)->getOutput();
-            if (pRelu) {
-                // TODO: Conv_nhwc + Bias+PRelu
-                // Alternative: Conv_nchw + Transpose(NCHW->NHWC)+Bias+PRelu
-                x = g->addOp<ReluObj>(x, nullptr)->getOutput();
-            }
+            auto w = g->addTensor({f, x->getDims()[3], r, r});
+            x = g->addOp<ConvObj>(x, w, nullptr, r / 2, r / 2, 1, 1, 1, 1,
+                                  nullptr, ActType::None, true)
+                    ->getOutput();
+            // if (pRelu) {
+            //     // TODO: Conv_nhwc + Bias+PRelu
+            //     // Alternative: Conv_nchw + Transpose(NCHW->NHWC)+Bias+PRelu
+            //     x = g->addOp<ReluObj>(x, nullptr)->getOutput();
+            // }
+            auto bias = g->addTensor({x->getDims()[3]});
+            x = g->addOp<BiasPReLU>(x, bias, nullptr, pRelu, 0.1)->getOutput();
             continue;
         }

@@ -46,7 +51,8 @@ Graph createGraph(Ref<CudaRuntimeObj> cuda, int batchSize) {
                 ->getOutput();
         auto bias = g->addTensor({f});
         if (op == "Conv") {
-            x = g->addOp<Conv2dReduce>(x, bias, nullptr, pRelu, r / 2, r / 2)
+            x = g->addOp<Conv2dReduce>(x, bias, nullptr, pRelu, 0.1, r / 2,
+                                       r / 2)
                     ->getOutput();
         } else if (op == "ConvTranposed") {
             IT_ASSERT(r == 9);
@@ -54,8 +60,8 @@ Graph createGraph(Ref<CudaRuntimeObj> cuda, int batchSize) {
             // 1,
             // 1, 1)
             //         ->getOutput();
-            x = g->addOp<Conv2dReduceTranspose>(x, bias, nullptr, pRelu, 3, 3,
-                                                4, 4)
+            x = g->addOp<Conv2dReduceTranspose>(x, bias, nullptr, pRelu, 0.1, 3,
+                                                3, 4, 4)
                     ->getOutput();
         } else
             IT_ASSERT(false);
@@ -69,7 +75,7 @@ Graph createGraph(Ref<CudaRuntimeObj> cuda, int batchSize) {

 TEST(Case, fsrcnn_direct_run) {
     auto cuda = make_ref<CudaRuntimeObj>();
-    auto g = createGraph(cuda, 16);
+    auto g = createGraph(cuda, 1);
     cudaProfilerStart();
     printf("E2E time %.3lf\n",
            timeit([&]() { cuda->runWithoutSync(g); }, [&]() { cuda->sync(); }));
@@ -78,7 +84,7 @@ TEST(Case, fsrcnn_direct_run) {

 TEST(Case, fsrcnn_cuda_graph) {
     auto cuda = make_ref<CudaRuntimeObj>();
-    auto g = createGraph(cuda, 16);
+    auto g = createGraph(cuda, 11);

     cudaGraph_t graph;
     cudaGraphExec_t instance;
@@ -40,12 +40,14 @@ void testConv2dReduce(
         x = gCuda->addOp<ReshapeObj>(x, nullptr, Shape{n, h, w, f, r, r})
                 ->getOutput();
         if (mode == "conv") {
-            x = gCuda->addOp<Conv2dReduce>(x, b0Cuda, nullptr, false, r / 2, r / 2)
+            x = gCuda
+                    ->addOp<Conv2dReduce>(x, b0Cuda, nullptr, false, 0.1, r / 2,
+                                          r / 2)
                     ->getOutput();
         } else {
             x = gCuda
-                    ->addOp<Conv2dReduceTranspose>(x, b0Cuda, nullptr, false, r / 2,
-                                                   r / 2, 2, 2)
+                    ->addOp<Conv2dReduceTranspose>(x, b0Cuda, nullptr, false, 0.1,
+                                                   r / 2, r / 2, 2, 2)
                     ->getOutput();
         }
