forked from jiuyuan/InfiniTensor
add ConvNHWC and FSRCNN graph
This commit is contained in:
parent 225a42f22d
commit ff97c879fb

@@ -11,6 +11,7 @@ enum class OpType {
     Matmul,
     ConvTrans,
     ConvTransNHWC,
+    ConvNHWC,
     G2BMM,
     GBMM,
     Pad,

@@ -122,6 +123,7 @@ class OpRegistry {
         FOP(Matmul);
         FOP(ConvTrans);
         FOP(ConvTransNHWC);
+        FOP(ConvNHWC);
         FOP(G2BMM);
         FOP(GBMM);
         FOP(Pad);

@@ -6,6 +6,7 @@
 namespace infini {

 Graph getGANGraph(int batch, Runtime runtime, int nLayers, int modelId);
+Graph getFSRCNNGraph(int batch, Runtime runtime);
 vector<Tensor> runInfoGAN(int nLayers);
 Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId);
 Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode,

@@ -111,7 +111,7 @@ class ConvBaseObj : public OperatorObj {
     auto getNCHWFRS() const { return tuple(n, c, h, w, f, r, s); }
     auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
     int getChannelPerGroup() const {
-        if (type == OpType::ConvTransNHWC) {
+        if (type == OpType::ConvTransNHWC || type == OpType::ConvNHWC) {
             return inputs[1]->getDims()[3];
         } else {
             return inputs[1]->getDims()[1];

@@ -149,6 +149,25 @@ class ConvObj : public ConvBaseObj {
     void setAuxilaryAttributes(PaddingMode mode) override;
 };

+class ConvNHWCObj : public ConvBaseObj {
+  public:
+    ConvNHWCObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
+                int ph, int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1,
+                Tensor bias = nullptr, ActType act = ActType::None);
+    // Constructor for setting the padding mode
+    ConvNHWCObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
+                PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1,
+                int dh = 1, int dw = 1, Tensor bias = nullptr,
+                ActType act = ActType::None);
+    OP_CLONE(ConvNHWCObj);
+
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+    int getNumGroups() const override { return c / getChannelPerGroup(); }
+
+  private:
+    void setAuxilaryAttributes(PaddingMode mode) override;
+};
+
 class ConvBackwardFilterObj : public ConvBaseObj {
   private:
     ActType act;

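For reference, a minimal usage sketch of the new operator (shapes and the `runtime` variable are illustrative, not from this commit; the GraphObj API is the same one used by getFSRCNNGraph below):

    // Hypothetical example: one 3x3 NHWC convolution at stride 1.
    Graph g = make_ref<GraphObj>(runtime);
    // NHWC input: batch 1, 32x32 spatial, 12 channels.
    Tensor x = g->addTensor({1, 32, 32, 12}, DataType::Float32, TensorType::Input);
    // FRSC weight: 12 output channels, 3x3 kernel, 12 input channels.
    Tensor w = g->addTensor({12, 3, 3, 12}, DataType::Float32, TensorType::Initialized);
    // ph = pw = 1 preserves the 32x32 spatial size; the output is NHWC {1, 32, 32, 12}.
    auto conv = g->addOp<ConvNHWCObj>(x, w, nullptr, /*ph=*/1, /*pw=*/1);
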
@@ -220,6 +239,7 @@ class ConvTransposed2dNHWCObj : public ConvBaseObj {

     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
     int getNumGroups() const override { return group; }
+    std::pair<int, int> getOutputPadding() const { return {oph, opw}; }

   private:
     void setAuxilaryAttributes(PaddingMode mode) override;

@@ -709,7 +709,7 @@ class OnnxStub:
                 ctx.push_output("{}_{}".format(name, i), it)
                 for (i, it) in enumerate(op.outputs())
             ]
-            if ty == backend.OpType.Conv:
+            if ty == backend.OpType.Conv or ty == backend.OpType.ConvNHWC:
                 ph, pw, dh, dw, sh, sw = backend.conv_attrs_of(op)
                 ctx.push_node(
                     make_node(

@@ -723,7 +723,7 @@ class OnnxStub:
                         group=op.inputs()[0].shape()[1] // op.inputs()[1].shape()[1],
                     )
                 )
-            elif ty == backend.OpType.ConvTrans:
+            elif ty == backend.OpType.ConvTrans or ty == backend.OpType.ConvTransNHWC:
                 ph, pw, sh, sw, dh, dw, oph, opw = backend.conv_trans_attrs_of(op)
                 ctx.push_node(
                     make_node(

@@ -69,6 +69,7 @@ void export_values(py::module &m) {
         .VALUE(OpType, Matmul)
         .VALUE(OpType, ConvTrans)
         .VALUE(OpType, ConvTransNHWC)
+        .VALUE(OpType, ConvNHWC)
         .VALUE(OpType, G2BMM)
         .VALUE(OpType, GBMM)
         .VALUE(OpType, Pad)

@@ -143,17 +144,28 @@ static Ref<RuntimeObj> intelcpu_runtime() { return make_ref<MklRuntimeObj>(); }
 #endif

 static std::tuple<int, int, int, int, int, int> conv_attrs_of(Operator op) {
-    IT_ASSERT(op->getOpType() == OpType::Conv);
-    auto conv = dynamic_cast<const ConvObj *>(op.get());
+    IT_ASSERT(op->getOpType() == OpType::Conv || op->getOpType() == OpType::ConvNHWC);
+    auto conv = dynamic_cast<const ConvBaseObj *>(op.get());
     return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(),
                            conv->getDw(), conv->getSh(), conv->getSw());
 }

 static std::tuple<int, int, int, int, int, int, int, int>
 conv_trans_attrs_of(Operator op) {
-    IT_ASSERT(op->getOpType() == OpType::ConvTrans);
-    auto conv = dynamic_cast<const ConvTransposed2dObj *>(op.get());
-    auto [oph, opw] = conv->getOutputPadding();
+    IT_ASSERT(op->getOpType() == OpType::ConvTrans ||
+              op->getOpType() == OpType::ConvTransNHWC);
+    auto conv = dynamic_cast<const ConvBaseObj *>(op.get());
+    int oph, opw;
+
+    if (op->getOpType() == OpType::ConvTrans) {
+        auto _conv = dynamic_cast<const ConvTransposed2dObj *>(op.get());
+        auto output_pad = _conv->getOutputPadding();
+        oph = output_pad.first;
+        opw = output_pad.second;
+    } else {
+        auto _conv = dynamic_cast<const ConvTransposed2dNHWCObj *>(op.get());
+        auto output_pad = _conv->getOutputPadding();
+        oph = output_pad.first;
+        opw = output_pad.second;
+    }
+
     return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(),
                            conv->getDw(), conv->getSh(), conv->getSw(), oph,
                            opw);

@@ -385,6 +397,7 @@ void export_test_model(py::module &m) {
 #ifdef USE_CUDA
     m.def("runInfoGAN", &runInfoGAN)
         .def("getGANGraph", &getGANGraph)
+        .def("getFSRCNNGraph", &getFSRCNNGraph)
         .def("getConvtransposedNHWC", &getConvtransposedNHWC)
         .def("optimizeGraph", &optimizeGraph, "graph"_a, "runtime"_a,
              "tuning"_a = false, "mode"_a = NMutator::Mode::Normal,

@@ -52,7 +52,7 @@ class convCudnn : public Kernel {
                cudnnFilterDescriptor_t, cudnnTensorDescriptor_t,
                cudnnConvolutionDescriptor_t, cudnnActivationDescriptor_t,
                cudnnTensorDescriptor_t>
-    createCuDNNDescriptor(const Ref<ConvObj> &op,
+    createCuDNNDescriptor(const Ref<ConvBaseObj> &op,
                           const ConvCuDnnPerfRecord &record) const {
         void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());

@@ -68,23 +68,28 @@ class convCudnn : public Kernel {

         int channelsPerGrp = cpg, channels = c;

+        // Set the input format: NHWC for ConvNHWC, NCHW otherwise
+        cudnnTensorFormat_t tensorFormat =
+            (op->getOpType() == OpType::ConvNHWC) ? CUDNN_TENSOR_NHWC
+                                                  : CUDNN_TENSOR_NCHW;
+
         // get inputs
         cudnnTensorDescriptor_t inDesc;
         checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
         checkCudnnError(cudnnSetTensor4dDescriptor(
-            inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, channels, h, w));
+            inDesc, tensorFormat, CUDNN_DATA_FLOAT, n, channels, h, w));

         // get kernels
         cudnnFilterDescriptor_t knDesc;
         checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
         checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
-                                                   CUDNN_TENSOR_NCHW, f,
+                                                   tensorFormat, f,
                                                    channelsPerGrp, r, s));
         // get bias
         cudnnTensorDescriptor_t biasDesc;
         checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
         checkCudnnError(cudnnSetTensor4dDescriptor(
-            biasDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, f, 1, 1));
+            biasDesc, tensorFormat, CUDNN_DATA_FLOAT, 1, f, 1, 1));

         // get convolution descriptor
         cudnnConvolutionDescriptor_t convDesc;

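Note that the same n/channels/h/w values can be passed for both layouts because cudnnSetTensor4dDescriptor takes logical (n, c, h, w) extents regardless of the format argument; CUDNN_TENSOR_NHWC only changes the implied memory strides. A sketch of the equivalent explicit-stride call for the NHWC input descriptor (assuming float data, as above):

    // NHWC strides for logical (n, c, h, w) extents: channels are innermost.
    checkCudnnError(cudnnSetTensor4dDescriptorEx(
        inDesc, CUDNN_DATA_FLOAT, n, channels, h, w,
        /*nStride=*/h * w * channels, /*cStride=*/1,
        /*hStride=*/w * channels, /*wStride=*/channels));
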
@@ -125,18 +130,25 @@ class convCudnn : public Kernel {
             convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
         cudnnTensorDescriptor_t outDesc;
         checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
-        checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
+        checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, tensorFormat,
                                                    CUDNN_DATA_FLOAT, outn, outc,
                                                    outh, outw));
-        IT_ASSERT((vector{outn, outc, outh, outw}) ==
-                      op->getOutput()->getDims(),
-                  "cuDNN output shape mismatches with OP output shape");
+
+        if (op->getOpType() == OpType::ConvNHWC) {
+            IT_ASSERT((vector{outn, outh, outw, outc}) ==
+                          op->getOutput()->getDims(),
+                      "cuDNN output shape mismatches with OP output shape");
+        } else {
+            IT_ASSERT((vector{outn, outc, outh, outw}) ==
+                          op->getOutput()->getDims(),
+                      "cuDNN output shape mismatches with OP output shape");
+        }

         return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
                      convDesc, actDesc, outDesc);
     }

-    bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
+    bool cuDNNUnfused(const Ref<ConvBaseObj> &op, const ConvCuDnnPerfRecord &record,
                       const CudaRuntimeObj *context) const {
         cudnnStatus_t stat;

@@ -220,11 +232,12 @@ class convCudnn : public Kernel {
         ConvCuDnnPerfRecordObj ret;
         ret.time = std::numeric_limits<double>::max();
         auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
-        auto op = as<ConvObj>(_op);
+        auto op = as<ConvBaseObj>(_op);
+        // For ConvNHWC, only try the first two forward algorithms (the
+        // IMPLICIT_GEMM variants); the remaining ones do not support NHWC.
+        int try_algo = op->getOpType() == OpType::ConvNHWC ? 2 : N_ALGO;
         // Both modes have the same performance. Only run cross-correlation.
         for (int mode = 1; mode < 2; mode++) {
             // Try every possible algorithm of convolution
-            for (int algo = 0; algo < N_ALGO; algo++) {
+            for (int algo = 0; algo < try_algo; algo++) {
                 auto recordRef = make_ref<ConvCuDnnPerfRecordObj>();
                 auto &record = *recordRef;
                 record.mode = mode;

@@ -283,7 +296,7 @@ class convCudnn : public Kernel {

     void compute(const Operator &_op, const PerfRecord &_record,
                  const RuntimeObj *_context) const override {
-        auto op = as<ConvObj>(_op);
+        auto op = as<ConvBaseObj>(_op);
         auto record = as<ConvCuDnnPerfRecordObj>(_record);
         auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
         bool success = cuDNNUnfused(op, record, context);

@@ -294,5 +307,8 @@ class convCudnn : public Kernel {
 REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float32, convCudnn,
                 "Conv_cuDNN_CUDA_Float32");

+REGISTER_KERNEL(Device::CUDA, OpType::ConvNHWC, DataType::Float32, convCudnn,
+                "ConvNHWC_cuDNN_CUDA_Float32");
+
 REGISTER_CONSTRUCTOR(1, ConvCuDnnPerfRecordObj::from_json);
 } // namespace infini

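With this registration, OpType::Conv and OpType::ConvNHWC dispatch to the same convCudnn kernel class; the layout difference is resolved at runtime via op->getOpType() when the descriptors are created.
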
@@ -79,6 +79,58 @@ Graph getGANGraph(int batch, Runtime runtime, int nLayers, int modelId) {
     return g;
 }

+// NHWC
+Graph getFSRCNNGraph(int batch, Runtime runtime) {
+    // All entries but the last: n, c, h, w, f, r, s, stride, pad, dilation, has_pReLU
+    // Last entry (ConvTransNHWC): n, f, h, w, c, r, s, stride, pad, dilation, has_pReLU
+    const DetailedConfigs fsrcnn_config = {
+        {batch, 1, 32, 32, 56, 5, 5, 1, 2, 1, true},
+        {batch, 56, 32, 32, 12, 1, 1, 1, 0, 1, true},
+        {batch, 12, 32, 32, 12, 3, 3, 1, 1, 1, false},
+        {batch, 12, 32, 32, 12, 3, 3, 1, 1, 1, false},
+        {batch, 12, 32, 32, 12, 3, 3, 1, 1, 1, false},
+        {batch, 12, 32, 32, 12, 3, 3, 1, 1, 1, true},
+        {batch, 12, 32, 32, 56, 32, 32, 1, 0, 1, true},
+        {batch, 56, 32, 32, 1, 9, 9, 4, 3, 1, false}, // ConvTransNHWC
+    };
+
+    Graph g = make_ref<GraphObj>(runtime);
+
+    Tensor input;
+    {
+        auto &[n, c, h, w, f, r, s, stride, pad, dilation, has_pReLU] =
+            fsrcnn_config[0];
+        input = g->addTensor({batch, h, w, c}, DataType::Float32,
+                             TensorType::Input);
+    }
+
+    for (int i = 0; i < (int)fsrcnn_config.size() - 1; ++i) {
+        auto &[n, c, h, w, f, r, s, stride, pad, dilation, has_pReLU] =
+            fsrcnn_config[i];
+        IT_ASSERT(input->getDims()[3] == c);
+        auto weight = g->addTensor({f, r, s, c}, DataType::Float32,
+                                   TensorType::Initialized); // f, r, s, c
+        input = g->addOp<ConvNHWCObj>(input, weight, nullptr, pad, pad,
+                                      stride, stride, 1, 1)
+                    ->getOutput();
+        if (has_pReLU) {
+            input = g->addOp<ReluObj>(input, nullptr)->getOutput();
+        }
+    }
+
+    // The last operator is a ConvTransNHWC
+    {
+        auto &[n, f, h, w, c, r, s, stride, pad, dilation, has_pReLU] =
+            fsrcnn_config[fsrcnn_config.size() - 1];
+        IT_ASSERT(input->getDims()[3] == f);
+        auto weight = g->addTensor({f, r, s, c}, DataType::Float32,
+                                   TensorType::Initialized); // f, r, s, c
+        input = g->addOp<ConvTransposed2dNHWCObj>(input, weight, nullptr, pad,
+                                                  pad, stride, stride, 1, 1)
+                    ->getOutput();
+    }
+
+    return g;
+}
+
 Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId) {
     IT_ASSERT(0 <= layerId && layerId < 5);
     Graph g = make_ref<GraphObj>(runtime);

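A minimal driver sketch for the new builder (assuming a CUDA runtime, as in the tests later in this commit):

    // Build FSRCNN at batch size 16, allocate tensor memory, and run it once.
    Runtime cuda = make_ref<CudaRuntimeObj>();
    Graph g = getFSRCNNGraph(/*batch=*/16, cuda);
    g->dataMalloc();
    cuda->run(g);
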
@@ -114,6 +114,75 @@ optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
     return {{{on, oc, oh, ow}}};
 }

+void ConvNHWCObj::setAuxilaryAttributes(PaddingMode mode) {
+    const Tensor &input = inputs[0];
+    const Tensor &weight = inputs[1];
+    n = input->getDims()[0], c = input->getDims()[3], h = input->getDims()[1],
+    w = input->getDims()[2], f = weight->getDims()[0], r = weight->getDims()[1],
+    s = weight->getDims()[2];
+    if (mode == PaddingMode::Same) {
+        int oh = h / sh;
+        int ow = w / sw;
+        ph = (h - oh * sh + (r - sh) * dh) / 2;
+        pw = (w - ow * sw + (s - sw) * dw) / 2;
+    } else if (mode == PaddingMode::Valid) {
+        ph = pw = 0;
+    }
+}
+
+ConvNHWCObj::ConvNHWCObj(GraphObj *graph, Tensor input, Tensor weight,
+                         Tensor output, int ph, int pw, int sh, int sw, int dh,
+                         int dw, Tensor bias, ActType act)
+    : ConvBaseObj(OpType::ConvNHWC, {input, weight}, output, ph, pw, sh, sw,
+                  dh, dw, input, weight, act) {
+    if (bias)
+        IT_TODO_HALT();
+    setAuxilaryAttributes(PaddingMode::Other);
+    IT_ASSERT(checkValid(graph));
+}
+
+ConvNHWCObj::ConvNHWCObj(GraphObj *graph, Tensor input, Tensor weight,
+                         Tensor output, PaddingMode mode, int sh, int sw,
+                         int dh, int dw, Tensor bias, ActType act)
+    : ConvBaseObj(OpType::ConvNHWC, {input, weight}, output, mode, sh, sw, dh,
+                  dw, input, weight, act) {
+    if (bias)
+        IT_TODO_HALT();
+    setAuxilaryAttributes(mode);
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>> ConvNHWCObj::inferShape(const TensorVec &inputs) const {
+    const auto &input = inputs[0], &weight = inputs[1];
+    auto n = input->getDims()[0];
+    auto h = input->getDims()[1];
+    auto w = input->getDims()[2];
+    auto f = weight->getDims()[0];
+    auto r = weight->getDims()[1];
+    auto s = weight->getDims()[2];
+    int on = n, oc = f;
+    int oh = 0, ow = 0;
+    // For NHWC+FRSC layout, C of the input must be divisible by C of the weight
+    if (input->getDims()[3] % weight->getDims()[3] != 0)
+        return {};
+    // Set padding size
+    if (padding == PaddingMode::Other) {
+        oh = (h - (r - sh) * dh + ph * 2) / sh;
+        ow = (w - (s - sw) * dw + pw * 2) / sw;
+    } else if (padding == PaddingMode::Same) {
+        oh = h / sh;
+        ow = w / sw;
+        // ph = (h - oh * sh + (r - sh) * dh) / 2;
+        // pw = (w - ow * sw + (s - sw) * dw) / 2;
+    } else if (padding == PaddingMode::Valid) {
+        int ph = 0;
+        int pw = 0;
+        oh = (h - (r - sh) * dh + ph * 2) / sh;
+        ow = (w - (s - sw) * dw + pw * 2) / sw;
+    }
+    return {{{on, oh, ow, oc}}};
+}
+
 ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
                                          Tensor weight, Tensor output, int ph,
                                          int pw, int sh, int sw, int dh, int dw,

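As a worked example of the PaddingMode::Other branch above: for h = 32, r = 3, sh = 1, dh = 1, ph = 1, the formula gives oh = (32 - (3 - 1) * 1 + 1 * 2) / 1 = 32, so "same"-style padding preserves the spatial size, and the result is returned in NHWC order as {on, oh, ow, oc}.
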
@@ -43,6 +43,42 @@ void testConvCudnn(
     gCuda->print();
 }

+void testConvNHWCCudnn(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    vector<float> ansVec) {
+    // Construct runtimes and graphs for CPU and CUDA
+    Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPU runtime is a singleton
+    Graph gCpu = make_ref<GraphObj>(cpu);
+    Runtime cuda = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cuda);
+    // Set input data on CPU in a CPU graph
+    Tensor i0Cpu = gCpu->addTensor({1, 4, 4, 3}, DataType::Float32);
+    Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    gCpu->dataMalloc();
+    i0Cpu->setData(generator);
+    w0Cpu->setData(generator);
+
+    // Copy input tensors from CPU to CUDA
+    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
+    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
+    // Build the CUDA graph
+    auto conv =
+        gCuda->addOp<ConvNHWCObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 2, 1, 1, 2);
+    // Allocate CUDA memory
+    gCuda->dataMalloc();
+    // Execute on CUDA
+    cuda->run(gCuda);
+    // Copy the output from CUDA to CPU
+    auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
+    o0Cpu->print();
+    o0Cpu->printData();
+    // Check results on CPU
+    EXPECT_TRUE(o0Cpu->equalData(ansVec));
+    // Print a tensor/operator/graph by print()
+    gCuda->print();
+}
+
 TEST(cuDNN_Conv, run) {
     testConvCudnn(OneGenerator(),
                   vector<float>{12, 12, 18, 18, 12, 12, 18, 18});

@@ -51,6 +87,14 @@ TEST(cuDNN_Conv, run) {
         vector<float>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
 }

+TEST(cuDNN_Conv, runNHWC) {
+    testConvNHWCCudnn(OneGenerator(),
+                      vector<float>{12., 12., 12., 12., 18., 18., 18., 18.});
+    testConvNHWCCudnn(
+        IncrementalGenerator(),
+        vector<float>{3350, 7562, 2306, 5546, 9480, 24546, 7185, 20793});
+}
+
 TEST(cuDNN_Conv, tune) {
     Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPU runtime is a singleton
     Graph gCpu = make_ref<GraphObj>(cpu);

@@ -111,7 +111,8 @@ if __name__ == "__main__":
         # (construct_conv(runtime, 1, 12, 32, 32, 12, 3, 3, 1, 1, 1), 'conv3x3'), # FSRCNN Conv_4 3x3
         # ft.getGANGraph(batch, runtime, 5, 1)
         # construct_convTranspose2d(runtime)
-        (load_onnx(runtime, '/mnt/auxHome/models/einnet/fsrcnn.bs1.onnx'), 'fsrcnn.bs1'),
+        # (load_onnx(runtime, '/mnt/auxHome/models/einnet/fsrcnn.bs1.onnx'), 'fsrcnn.bs1'),
+        (ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16")
     ]

     for original_g, name in graphs: