convNHWC+pReLU+biasPReLU

huangshuhong 2022-10-18 14:01:45 +08:00 committed by Liyan Zheng
parent 74e998e262
commit 36755c3160
16 changed files with 262 additions and 70 deletions

View File

@@ -42,7 +42,8 @@ enum class OpType {
MemBound = 300,
//
Conv2dReduce = 400,
ConvTranspose2dReduce
ConvTranspose2dReduce,
BiasPReLU
};
using KernelAttrs = std::tuple<Device, OpType, DataType>;
@@ -91,6 +92,7 @@ class OpRegistry {
FOP(MemBound);
FOP(Conv2dReduce);
FOP(ConvTranspose2dReduce);
FOP(BiasPReLU);
default:
IT_ASSERT(false);
break;

View File

@@ -0,0 +1,9 @@
#pragma once
#include "operators/bias2prelu.h"
namespace infini {
void bias2prelu_kernel(float *input, float *bias, float *output, int n, int h,
int w, int c, bool PReLU, float paramPReLU);
}

View File

@@ -4,12 +4,13 @@
namespace infini {
void conv2dreduce_kernel(float *input, float *bias, float *output, bool PReLU,
int n, int h, int w, int f, int r, int s, int oh,
int ow, int ph, int pw, int sh, int sw, int dh,
int dw);
float paramReLU, int n, int h, int w, int f, int r,
int s, int oh, int ow, int ph, int pw, int sh, int sw,
int dh, int dw);
void convTranspose2dreduce_kernel(float *input, float *bias, float *output,
bool PReLU, int n, int h, int w, int f, int r,
int s, int oh, int ow, int ph, int pw, int sh,
int sw, int dh, int dw);
bool PReLU, float paramReLU, int n, int h,
int w, int f, int r, int s, int oh, int ow,
int ph, int pw, int sh, int sw, int dh,
int dw);
} // namespace infini

View File

@@ -0,0 +1,32 @@
#pragma once
#include "core/operator.h"
namespace infini {
class BiasPReLU : public OperatorObj {
protected:
bool PReLU;
float paramPReLU;
int n, h, w, c;
public:
BiasPReLU(GraphObj *graph, Tensor input, Tensor bias, Tensor output,
bool PReLU_, float paramPReLU_);
std::string toString() const override { return "Bias2PReLU"; }
int numInputs() const override { return 2; }
int numOutputs() const override { return 1; }
bool getPReLU() const { return PReLU; }
float getParamReLU() const { return paramPReLU; }
Tensor getBias() const { return inputs[1]; }
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini
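For reference, BiasPReLU adds a per-channel bias to an NHWC tensor and optionally applies PReLU. A minimal CPU sketch of the intended semantics (an illustration, not part of this commit; the slope value and shapes mirror the FSRCNN test further down):
// Reference semantics: output[n][h][w][c] = PReLU(input[n][h][w][c] + bias[c])
void biasPReluReference(const float *input, const float *bias, float *output,
                        int n, int h, int w, int c, bool prelu, float slope) {
    for (int i = 0; i < n * h * w; ++i) {
        for (int j = 0; j < c; ++j) {
            float v = input[i * c + j] + bias[j];
            if (prelu && v < 0.0f)
                v *= slope; // PReLU keeps positives, scales negatives by the slope
            output[i * c + j] = v;
        }
    }
}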

View File

@@ -67,19 +67,23 @@ class ConvBaseObj : public OperatorObj {
class ConvObj : public ConvBaseObj {
private:
ActType act;
bool NHWC_layout;
public:
ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph,
int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1,
Tensor bias = nullptr, ActType act = ActType::None);
Tensor bias = nullptr, ActType act = ActType::None,
bool nhwc = false);
// Constructors for setting padding mode
ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1,
int dh = 1, int dw = 1, Tensor bias = nullptr,
ActType act = ActType::None);
ActType act = ActType::None, bool nhwc = false);
std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
ActType getAct() const { return act; }
bool getNHWCLayout() const { return NHWC_layout; }
int getNumGroups() const override { return c / getChannelPerGroup(); }
private:

View File

@@ -10,11 +10,12 @@ class Conv2dReduceBase : public OperatorObj {
int dh, dw;
int n, h, w, f, r, s; // c has been reduced
bool PReLU;
float paramReLU;
public:
Conv2dReduceBase(OpType opType, Tensor input, Tensor bias, Tensor output,
bool PReLU_, int ph_, int pw_, int sh_ = 1, int sw_ = 1,
int dh_ = 1, int dw_ = 1);
bool PReLU_, float paramReLU_, int ph_, int pw_,
int sh_ = 1, int sw_ = 1, int dh_ = 1, int dw_ = 1);
std::string toString() const override;
int numInputs() const override { return 2; }
@@ -27,6 +28,7 @@ class Conv2dReduceBase : public OperatorObj {
int getSh() const { return sh; }
int getSw() const { return sw; }
bool getPReLU() const { return PReLU; }
float getParamReLU() const { return paramReLU; }
Tensor getBias() const { return inputs[1]; }
@@ -41,16 +43,17 @@ class Conv2dReduceBase : public OperatorObj {
class Conv2dReduce : public Conv2dReduceBase {
public:
Conv2dReduce(GraphObj *graph, Tensor input, Tensor bias, Tensor output,
bool PReLU_, int ph_, int pw_, int sh_ = 1, int sw_ = 1,
int dh_ = 1, int dw_ = 1);
bool PReLU_, float paramReLU_, int ph_, int pw_, int sh_ = 1,
int sw_ = 1, int dh_ = 1, int dw_ = 1);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
};
class Conv2dReduceTranspose : public Conv2dReduceBase {
public:
Conv2dReduceTranspose(GraphObj *graph, Tensor input, Tensor bias,
Tensor output, bool PReLU_, int ph_, int pw_,
int sh_ = 1, int sw_ = 1, int dh_ = 1, int dw_ = 1);
Tensor output, bool PReLU_, float paramReLU_, int ph_,
int pw_, int sh_ = 1, int sw_ = 1, int dh_ = 1,
int dw_ = 1);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
};
} // namespace infini

View File

@@ -0,0 +1,24 @@
#include "operators/bias2prelu.h"
#include "cuda/cuda_bias2prelu.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class Bias2PReluCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op, const RuntimeObj *_context) const {
auto op = as<BiasPReLU>(_op);
float *const input = (op->getInputs(0)->getRawDataPtr<float *>());
float *const bias = (op->getInputs(1)->getRawDataPtr<float *>());
float *const output = (op->getOutput()->getRawDataPtr<float *>());
auto dim = op->getInputs(0)->getDims();
int n = dim[0], h = dim[1], w = dim[2], c = dim[3];
bias2prelu_kernel(input, bias, output, n, h, w, c, op->getPReLU(),
op->getParamReLU());
}
};
REGISTER_KERNEL(Device::CUDA, OpType::BiasPReLU, DataType::Float32,
Bias2PReluCuda, "Bias2PReLU_CUDA_Float32");
} // namespace infini

View File

@@ -0,0 +1,34 @@
#include "cuda/cuda_common.h"
__global__ void bias2prelu_kernel_(
float *__restrict__ input,
float *__restrict__ bias,
float *__restrict__ output,
const bool PReLU, const float paramReLU,
const int n, const int h, const int w, const int c)
{
int nid = blockIdx.x, hid = blockIdx.y;
int wid = threadIdx.x, cid = threadIdx.y;
int offset = nid * h * w * c + hid * w * c + wid * c + cid;
float imm = bias[cid] + input[offset];
if (PReLU) {
imm = (imm > 0.0) ? imm : paramReLU * imm;
}
output[offset] = imm;
}
namespace infini {
void bias2prelu_kernel(float *input, float *bias, float *output,
int n, int h, int w, int c, bool PReLU, float paramPReLU)
{
dim3 grid(n, h);
dim3 block(w, c);
cudaStream_t stream(cudaStreamPerThread);
bias2prelu_kernel_<<<grid, block, 0, stream>>>(input, bias, output,
PReLU, paramPReLU, n, h, w, c);
}
}
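The launch maps one thread to each NHWC element: blockIdx covers (n, h) and threadIdx covers (w, c), so the kernel implicitly assumes w * c stays within CUDA's 1024 threads-per-block limit. A hedged host-side guard might look like the following sketch (the helper name is illustrative, not part of this commit):
// Sketch: verify that dim3 block(w, c) fits in one CUDA block before launching.
inline bool bias2preluLaunchFits(int w, int c) {
    return static_cast<long long>(w) * c <= 1024;
}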

View File

@@ -65,6 +65,9 @@ class convCudnn : public Kernel {
const int cpg = op->getChannelPerGroup();
const int g = c / cpg;
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
const auto nhwcLayout = op->getNHWCLayout();
const auto inoutLayout =
nhwcLayout ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW;
int channelsPerGrp = cpg, channels = c;
@@ -72,7 +75,7 @@ class convCudnn : public Kernel {
cudnnTensorDescriptor_t inDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, channels, h, w));
inDesc, inoutLayout, CUDNN_DATA_FLOAT, n, channels, h, w));
// get kernels
cudnnFilterDescriptor_t knDesc;
@@ -125,12 +128,17 @@ class convCudnn : public Kernel {
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
cudnnTensorDescriptor_t outDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT, outn, outc,
outh, outw));
IT_ASSERT((vector{outn, outc, outh, outw}) ==
op->getOutput()->getDims(),
"cuDNN output shape mismatches with OP output shape");
checkCudnnError(cudnnSetTensor4dDescriptor(
outDesc, inoutLayout, CUDNN_DATA_FLOAT, outn, outc, outh, outw));
if (nhwcLayout) {
IT_ASSERT((vector{outn, outh, outw, outc}) ==
op->getOutput()->getDims(),
"cuDNN output shape mismatches with OP output shape");
} else {
IT_ASSERT((vector{outn, outc, outh, outw}) ==
op->getOutput()->getDims(),
"cuDNN output shape mismatches with OP output shape");
}
return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc);
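Note that cudnnSetTensor4dDescriptor always takes its dimensions in (n, c, h, w) order; the format argument only selects the memory layout, which is why the calls above still pass (n, channels, h, w) for NHWC while the shape assert has to reorder the dims. A worked sketch, with shapes assumed from the first conv of the FSRCNN test below:
// For input {n, 32, 32, 1} (NHWC) and weight {56, 1, 5, 5}, cuDNN reports
// outc=56, outh=32, outw=32; the descriptor is still set with
// (n, c, h, w) = (n, 56, 32, 32), while the NHWC assert checks {n, 32, 32, 56}.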

View File

@@ -23,15 +23,17 @@ class Conv2dReduceCuda : public CudaKernelWithoutConfig {
auto odim = op->getOutput()->getDims();
int oh = odim[1], ow = odim[2];
bool PReLU = op->getPReLU();
float paramReLU = op->getParamReLU();
auto opType = op->getOpType();
if (opType == OpType::Conv2dReduce) {
conv2dreduce_kernel(input, bias, output, PReLU, n, h, w, f, r, s,
oh, ow, ph, pw, sh, sw, dh, dw);
conv2dreduce_kernel(input, bias, output, PReLU, paramReLU, n, h, w,
f, r, s, oh, ow, ph, pw, sh, sw, dh, dw);
} else {
convTranspose2dreduce_kernel(input, bias, output, PReLU, n, h, w, f,
r, s, oh, ow, ph, pw, sh, sw, dh, dw);
convTranspose2dreduce_kernel(input, bias, output, PReLU, paramReLU,
n, h, w, f, r, s, oh, ow, ph, pw, sh,
sw, dh, dw);
}
}
};

View File

@@ -4,7 +4,7 @@ __global__ void conv2dreduce_kernel_(
float *__restrict__ input,
float *__restrict__ bias,
float *__restrict__ output,
const bool PReLU, const int n, const int f, const int h, const int w,
const bool PReLU, const float paramReLU, const int n, const int f, const int h, const int w,
const int oh, const int ow, const int r, const int s, const int ph,
const int pw, const int dh, const int dw, const int sh, const int sw)
{
@@ -31,7 +31,7 @@ __global__ void conv2dreduce_kernel_(
imm += bias[fid];
}
if (PReLU) {
imm = imm > 0.0 ? imm : 0.0;
imm = imm > 0.0 ? imm : paramReLU * imm;
}
output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm;
}
@@ -41,7 +41,7 @@ __global__ void convTranspose2dreduce_kernel_(
float *__restrict__ input,
float *__restrict__ bias,
float *__restrict__ output,
const bool PReLU, const int n, const int f, const int h, const int w,
const bool PReLU, const float paramReLU, const int n, const int f, const int h, const int w,
const int oh, const int ow, const int r, const int s, const int ph,
const int pw, const int dh, const int dw, const int sh, const int sw)
{
@@ -73,7 +73,7 @@ __global__ void convTranspose2dreduce_kernel_(
imm += bias[fid];
}
if (PReLU) {
imm = imm > 0.0 ? imm : 0.0;
imm = imm > 0.0 ? imm : paramReLU * imm;
}
output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm;
}
@@ -82,26 +82,26 @@ __global__ void convTranspose2dreduce_kernel_(
namespace infini {
void conv2dreduce_kernel(float *input, float *bias, float *output,
bool PReLU, int n, int h, int w, int f,
bool PReLU, float paramReLU, int n, int h, int w, int f,
int r, int s, int oh, int ow, int ph, int pw, int sh,
int sw, int dh, int dw)
{
dim3 grid(n, f);
dim3 block(oh, ow);
cudaStream_t stream(cudaStreamPerThread);
conv2dreduce_kernel_<<<grid, block, 0, stream>>>(input, bias, output, PReLU, n, f, h, w,
conv2dreduce_kernel_<<<grid, block, 0, stream>>>(input, bias, output, PReLU, paramReLU, n, f, h, w,
oh, ow, r, s, ph, pw, dh, dw, sh, sw);
}
void convTranspose2dreduce_kernel(float *input, float *bias, float *output,
bool PReLU, int n, int h, int w, int f,
bool PReLU, float paramReLU, int n, int h, int w, int f,
int r, int s, int oh, int ow, int ph, int pw, int sh,
int sw, int dh, int dw)
{
dim3 grid(n, f);
dim3 block(oh, ow);
cudaStream_t stream(cudaStreamPerThread);
convTranspose2dreduce_kernel_<<<grid, block, 0,stream>>>(input, bias, output, PReLU, n, f, h, w,
convTranspose2dreduce_kernel_<<<grid, block, 0,stream>>>(input, bias, output, PReLU, paramReLU, n, f, h, w,
oh, ow, r, s, ph, pw, dh, dw, sh, sw);
}
}

View File

@@ -0,0 +1,37 @@
#include "operators/bias2prelu.h"
namespace infini {
BiasPReLU::BiasPReLU(GraphObj *graph, Tensor input, Tensor bias, Tensor output,
bool PReLU_, float paramPReLU_)
: OperatorObj(OpType::BiasPReLU, {input, bias}, {output}), PReLU(PReLU_),
paramPReLU(paramPReLU_) {
auto dims = input->getDims();
n = dims[0], h = dims[1], w = dims[2], c = dims[3];
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> BiasPReLU::inferShape(const TensorVec &inputs) const {
const Tensor &input = inputs[0];
const Tensor &bias = inputs[1];
auto dims = input->getDims();
int n = dims[0], h = dims[1], w = dims[2], c = dims[3];
int bc = bias->getDims()[0];
if (bc != c)
return {};
return {{{n, h, w, c}}};
}
vector<int> BiasPReLU::getWorkloadVector() const {
return {enum_to_underlying(type), n, h, w, c};
}
vector<int> BiasPReLU::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
} // namespace infini
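As a concrete check of this shape inference (shapes chosen to mirror the FSRCNN test, not taken from this file):
// input {1, 32, 32, 56} (NHWC) + bias {56} -> output {1, 32, 32, 56}
// input {1, 32, 32, 56} (NHWC) + bias {32} -> {} (bias length != c), so the
//     constructor's IT_ASSERT(checkValid(graph)) fires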

View File

@@ -45,11 +45,25 @@ vector<int> ConvBaseObj::getOpAttrVector() const {
return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw};
}
std::string ConvObj::toString() const {
std::string origin = ConvBaseObj::toString();
std::ostringstream os;
os << (NHWC_layout ? "NHWC" : "NCHW") << origin;
return os.str();
}
void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
const Tensor &input = inputs[0];
const Tensor &weight = inputs[1];
n = input->getDims()[0], c = input->getDims()[1], h = input->getDims()[2],
w = input->getDims()[3], f = weight->getDims()[0], r = weight->getDims()[2],
n = input->getDims()[0];
if (NHWC_layout)
c = input->getDims()[3], h = input->getDims()[1],
w = input->getDims()[2];
else
c = input->getDims()[1], h = input->getDims()[2],
w = input->getDims()[3];
f = weight->getDims()[0], r = weight->getDims()[2],
s = weight->getDims()[3];
if (mode == PaddingMode::Same) {
int oh = h / sh;
@@ -63,10 +77,10 @@ void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias,
ActType act)
ActType act, bool nhwc)
: ConvBaseObj(OpType::Conv, {input, weight}, output, ph, pw, sh, sw, dh, dw,
input, weight),
act(act) {
act(act), NHWC_layout(nhwc) {
if (bias)
IT_TODO_HALT();
setAuxilaryAttributes(PaddingMode::Other);
@@ -75,10 +89,10 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias,
ActType act)
ActType act, bool nhwc)
: ConvBaseObj(OpType::Conv, {input, weight}, output, mode, sh, sw, dh, dw,
input, weight),
act(act) {
act(act), NHWC_layout(nhwc) {
if (bias)
IT_TODO_HALT();
setAuxilaryAttributes(mode);
@@ -87,16 +101,25 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
const auto &input = inputs[0], &weight = inputs[1];
auto n = input->getDims()[0];
auto h = input->getDims()[2];
auto w = input->getDims()[3];
int n, h, w, ic, wc;
n = input->getDims()[0];
if (NHWC_layout) {
h = input->getDims()[1];
w = input->getDims()[2];
ic = input->getDims()[3];
} else {
h = input->getDims()[2];
w = input->getDims()[3];
ic = input->getDims()[1];
}
wc = weight->getDims()[1];
auto f = weight->getDims()[0];
auto r = weight->getDims()[2];
auto s = weight->getDims()[3];
int on = n, oc = f;
int oh = 0, ow = 0;
// For NCHW/NHWC + FCRS layouts, C of input must be divisible by C of weight
if (input->getDims()[1] % weight->getDims()[1] != 0)
if (ic % wc != 0)
return {};
// Set padding size
if (padding == PaddingMode::Other) {
@@ -116,6 +139,9 @@ optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
oh = (h - (r - sh) * dh + ph * 2) / sh;
ow = (w - (s - sw) * dw + pw * 2) / sw;
}
if (NHWC_layout)
return {{{on, oh, ow, oc}}};
return {{{on, oc, oh, ow}}};
}
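A worked example for the NHWC branch, using the shapes of the first FSRCNN conv from the test below (batch size left symbolic):
// input  {n, 32, 32, 1} (NHWC), weight {56, 1, 5, 5} (FCRS)
// ph = pw = 2, sh = sw = dh = dw = 1, PaddingMode::Other
// oh = (h - (r - sh) * dh + ph * 2) / sh = (32 - 4 + 4) / 1 = 32
// ow = (w - (s - sw) * dw + pw * 2) / sw = (32 - 4 + 4) / 1 = 32
// => output {n, 32, 32, 56} in NHWC order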

View File

@@ -3,10 +3,11 @@
namespace infini {
Conv2dReduceBase::Conv2dReduceBase(OpType opType, Tensor input, Tensor bias,
Tensor output, bool PReLU_, int ph_, int pw_,
int sh_, int sw_, int dh_, int dw_)
Tensor output, bool PReLU_, float paramReLU_,
int ph_, int pw_, int sh_, int sw_, int dh_,
int dw_)
: OperatorObj(opType, {input, bias}, {output}), ph(ph_), pw(pw_), sh(sh_),
sw(sw_), dh(dh_), dw(dw_), PReLU(PReLU_) {
sw(sw_), dh(dh_), dw(dw_), PReLU(PReLU_), paramReLU(paramReLU_) {
// expected input shape: (n, h, w, f, r, s)
auto inputShape = input->getDims();
IT_ASSERT(inputShape.size() == 6);
@@ -54,10 +55,10 @@ std::vector<int> Conv2dReduceBase::getOpAttrVector() const {
}
Conv2dReduce::Conv2dReduce(GraphObj *graph, Tensor input, Tensor bias,
Tensor output, bool PReLU_, int ph_, int pw_,
int sh_, int sw_, int dh_, int dw_)
: Conv2dReduceBase(OpType::Conv2dReduce, input, bias, output, PReLU_, ph_,
pw_, sh_, sw_, dh_, dw_) {
Tensor output, bool PReLU_, float paramReLU_,
int ph_, int pw_, int sh_, int sw_, int dh_, int dw_)
: Conv2dReduceBase(OpType::Conv2dReduce, input, bias, output, PReLU_,
paramReLU_, ph_, pw_, sh_, sw_, dh_, dw_) {
IT_ASSERT(checkValid(graph));
}
@@ -73,10 +74,11 @@ Conv2dReduce::inferShape(const TensorVec &inputs) const {
Conv2dReduceTranspose::Conv2dReduceTranspose(GraphObj *graph, Tensor input,
Tensor bias, Tensor output,
bool PReLU_, int ph_, int pw_,
int sh_, int sw_, int dh_, int dw_)
bool PReLU_, float paramReLU_,
int ph_, int pw_, int sh_, int sw_,
int dh_, int dw_)
: Conv2dReduceBase(OpType::ConvTranspose2dReduce, input, bias, output,
PReLU_, ph_, pw_, sh_, sw_, dh_, dw_) {
PReLU_, paramReLU_, ph_, pw_, sh_, sw_, dh_, dw_) {
IT_ASSERT(dh_ == 1);
IT_ASSERT(dw_ == 1);
IT_ASSERT(checkValid(graph));

View File

@@ -3,6 +3,7 @@
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/bias2prelu.h"
#include "operators/conv.h"
#include "operators/conv2dreduce.h"
#include "operators/element_wise.h"
@@ -16,7 +17,7 @@ namespace infini {
Graph createGraph(Ref<CudaRuntimeObj> cuda, int batchSize) {
Graph g = make_ref<GraphObj>(cuda);
auto input = g->addTensor({batchSize, 1, 32, 32}); // NCHW
auto input = g->addTensor({batchSize, 32, 32, 1}); // NHWC
// auto input = g->addTensor({16, 32, 32, 1}); // change to NHWC format
vector<tuple<string, int, int, bool>> configs = {
{"Conv", 56, 5, true}, {"Conv", 12, 1, true},
@@ -26,13 +27,17 @@ Graph createGraph(Ref<CudaRuntimeObj> cuda, int batchSize) {
auto x = input;
for (auto &[op, f, r, pRelu] : configs) {
if (r == 5 && op == "Conv") { // for the first conv
auto w = g->addTensor({f, x->getDims()[1], r, r});
x = g->addOp<ConvObj>(x, w, nullptr, r / 2, r / 2)->getOutput();
if (pRelu) {
// TODO: Conv_nhwc + Bias+PRelu
// Alternative: Conv_nchw + Transpose(NCHW->NHWC)+Bias+PRelu
x = g->addOp<ReluObj>(x, nullptr)->getOutput();
}
auto w = g->addTensor({f, x->getDims()[3], r, r});
x = g->addOp<ConvObj>(x, w, nullptr, r / 2, r / 2, 1, 1, 1, 1,
nullptr, ActType::None, true)
->getOutput();
// if (pRelu) {
// // TODO: Conv_nhwc + Bias+PRelu
// // Alternative: Conv_nchw + Transpose(NCHW->NHWC)+Bias+PRelu
// x = g->addOp<ReluObj>(x, nullptr)->getOutput();
// }
auto bias = g->addTensor({x->getDims()[3]});
x = g->addOp<BiasPReLU>(x, bias, nullptr, pRelu, 0.1)->getOutput();
continue;
}
@@ -46,7 +51,8 @@ Graph createGraph(Ref<CudaRuntimeObj> cuda, int batchSize) {
->getOutput();
auto bias = g->addTensor({f});
if (op == "Conv") {
x = g->addOp<Conv2dReduce>(x, bias, nullptr, pRelu, r / 2, r / 2)
x = g->addOp<Conv2dReduce>(x, bias, nullptr, pRelu, 0.1, r / 2,
r / 2)
->getOutput();
} else if (op == "ConvTranposed") {
IT_ASSERT(r == 9);
@@ -54,8 +60,8 @@ Graph createGraph(Ref<CudaRuntimeObj> cuda, int batchSize) {
// 1,
// 1, 1)
// ->getOutput();
x = g->addOp<Conv2dReduceTranspose>(x, bias, nullptr, pRelu, 3, 3,
4, 4)
x = g->addOp<Conv2dReduceTranspose>(x, bias, nullptr, pRelu, 0.1, 3,
3, 4, 4)
->getOutput();
} else
IT_ASSERT(false);
@@ -69,7 +75,7 @@ Graph createGraph(Ref<CudaRuntimeObj> cuda, int batchSize) {
TEST(Case, fsrcnn_direct_run) {
auto cuda = make_ref<CudaRuntimeObj>();
auto g = createGraph(cuda, 16);
auto g = createGraph(cuda, 1);
cudaProfilerStart();
printf("E2E time %.3lf\n",
timeit([&]() { cuda->runWithoutSync(g); }, [&]() { cuda->sync(); }));
@@ -78,7 +84,7 @@ TEST(Case, fsrcnn_direct_run) {
TEST(Case, fsrcnn_cuda_graph) {
auto cuda = make_ref<CudaRuntimeObj>();
auto g = createGraph(cuda, 16);
auto g = createGraph(cuda, 11);
cudaGraph_t graph;
cudaGraphExec_t instance;

View File

@@ -40,12 +40,14 @@ void testConv2dReduce(
x = gCuda->addOp<ReshapeObj>(x, nullptr, Shape{n, h, w, f, r, r})
->getOutput();
if (mode == "conv") {
x = gCuda->addOp<Conv2dReduce>(x, b0Cuda, nullptr, false, r / 2, r / 2)
x = gCuda
->addOp<Conv2dReduce>(x, b0Cuda, nullptr, false, 0.1, r / 2,
r / 2)
->getOutput();
} else {
x = gCuda
->addOp<Conv2dReduceTranspose>(x, b0Cuda, nullptr, false, r / 2,
r / 2, 2, 2)
->addOp<Conv2dReduceTranspose>(x, b0Cuda, nullptr, false, 0.1,
r / 2, r / 2, 2, 2)
->getOutput();
}