forked from jiuyuan/InfiniTensor
Merge branch 'NNET_e2e' into NNET_gcn
This commit is contained in:
commit
18d6ba4022
|
@ -1 +1 @@
|
|||
Subproject commit 3bb9240cb15459768adb3e7d963a20e1523a6294
|
||||
Subproject commit f30744bcf726ea3735df7ecf9e9de9ddac540283
|
|
@ -1 +1 @@
|
|||
Subproject commit b796f7d44681514f58a683a3a71ff17c94edb0c1
|
||||
Subproject commit e2239ee6043f73722e7aa812a459f54a28552929
|
|
@ -1 +1 @@
|
|||
Subproject commit 13132dd361c8c5b5753983d5186cf54f689d90f9
|
||||
Subproject commit 6aebf09233951e4ce30a63919186a70b2b195756
|
|
@ -1 +1 @@
|
|||
Subproject commit 0bd8896a4010f2d91b2340570c24fa08606ec406
|
||||
Subproject commit 1e3400b6742288429f2069aaf5febf92d0662dae
|
|
@ -57,6 +57,11 @@ class GraphHandlerObj {
|
|||
Tensor convTransposed2d(Tensor input, Tensor weight, Tensor output, int ph,
|
||||
int pw, int sh, int sw, int dh, int dw, int oph,
|
||||
int opw);
|
||||
Tensor convNHWC(Tensor input, Tensor weight, Tensor output, int ph, int pw,
|
||||
int sh, int sw, int dh, int dw);
|
||||
Tensor convTransposed2dNHWC(Tensor input, Tensor weight, Tensor output,
|
||||
int ph, int pw, int sh, int sw, int dh, int dw,
|
||||
int oph, int opw);
|
||||
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
|
||||
Tensor bias, ActType act);
|
||||
Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
|
||||
|
|
|
@ -11,6 +11,7 @@ enum class OpType {
|
|||
Matmul,
|
||||
ConvTrans,
|
||||
ConvTransNHWC,
|
||||
ConvNHWC,
|
||||
G2BMM,
|
||||
GBMM,
|
||||
Pad,
|
||||
|
@ -102,7 +103,10 @@ enum class OpType {
|
|||
Dropout,
|
||||
//
|
||||
MemBound = 300,
|
||||
Any,
|
||||
//
|
||||
Conv2dReduce = 400,
|
||||
Conv2dReduceTranspose,
|
||||
Any
|
||||
};
|
||||
|
||||
using KernelAttrs = std::tuple<Device, OpType, DataType>;
|
||||
|
@ -123,6 +127,7 @@ class OpRegistry {
|
|||
FOP(Matmul);
|
||||
FOP(ConvTrans);
|
||||
FOP(ConvTransNHWC);
|
||||
FOP(ConvNHWC);
|
||||
FOP(G2BMM);
|
||||
FOP(GBMM);
|
||||
FOP(Pad);
|
||||
|
@ -210,6 +215,9 @@ class OpRegistry {
|
|||
FOP(BitRightShift);
|
||||
//
|
||||
FOP(MemBound);
|
||||
//
|
||||
FOP(Conv2dReduce);
|
||||
FOP(Conv2dReduceTranspose);
|
||||
FOP(Any);
|
||||
default:
|
||||
IT_ASSERT(false, "Unknown OpType " +
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
#pragma once
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "nnet/dbg.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
|
|
@ -66,6 +66,9 @@ class NMutator : public Mutator {
|
|||
Graph transformGbmm(Operator op);
|
||||
Graph transformDialtedConv(Operator _op);
|
||||
Graph transformConv1xk(Operator op);
|
||||
// Graph transformConv1xk(Operator op);
|
||||
Graph transformConvToGEMMReduce(Operator _op);
|
||||
Graph transformConvTranposeToGEMMReduce(Operator _op);
|
||||
|
||||
Tensor splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
|
||||
Tensor output = nullptr);
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
namespace infini {
|
||||
|
||||
Graph getGANGraph(int batch, Runtime runtime, int nLayers, int modelId);
|
||||
Graph getFSRCNNGraph(int batch, Runtime runtime);
|
||||
Graph getLongformer(Runtime runtime, int bs);
|
||||
vector<Tensor> runInfoGAN(int nLayers);
|
||||
Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId);
|
||||
|
|
|
@ -111,7 +111,7 @@ class ConvBaseObj : public OperatorObj {
|
|||
auto getNCHWFRS() const { return tuple(n, c, h, w, f, r, s); }
|
||||
auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
|
||||
int getChannelPerGroup() const {
|
||||
if (type == OpType::ConvTransNHWC) {
|
||||
if (type == OpType::ConvTransNHWC || type == OpType::ConvNHWC) {
|
||||
return inputs[1]->getDims()[3];
|
||||
} else {
|
||||
return inputs[1]->getDims()[1];
|
||||
|
@ -149,6 +149,25 @@ class ConvObj : public ConvBaseObj {
|
|||
void setAuxilaryAttributes(PaddingMode mode) override;
|
||||
};
|
||||
|
||||
class ConvNHWCObj : public ConvBaseObj {
|
||||
public:
|
||||
ConvNHWCObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph,
|
||||
int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1,
|
||||
Tensor bias = nullptr, ActType act = ActType::None);
|
||||
// Constructors for setting padding mode
|
||||
ConvNHWCObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
|
||||
PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1,
|
||||
int dh = 1, int dw = 1, Tensor bias = nullptr,
|
||||
ActType act = ActType::None);
|
||||
OP_CLONE(ConvNHWCObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
int getNumGroups() const override { return c / getChannelPerGroup(); }
|
||||
|
||||
private:
|
||||
void setAuxilaryAttributes(PaddingMode mode) override;
|
||||
};
|
||||
|
||||
class ConvBackwardFilterObj : public ConvBaseObj {
|
||||
private:
|
||||
ActType act;
|
||||
|
@ -220,6 +239,7 @@ class ConvTransposed2dNHWCObj : public ConvBaseObj {
|
|||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
int getNumGroups() const override { return group; }
|
||||
std::pair<int, int> getOutputPadding() const { return {oph, opw}; }
|
||||
|
||||
private:
|
||||
void setAuxilaryAttributes(PaddingMode mode) override;
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class Conv2dReduceBase : public OperatorObj {
|
||||
protected:
|
||||
Tensor bias;
|
||||
int ph, pw;
|
||||
int sh, sw;
|
||||
int dh, dw;
|
||||
int n, h, w, f, r, s; // c has been reduced
|
||||
bool PReLU;
|
||||
float paramReLU;
|
||||
|
||||
public:
|
||||
Conv2dReduceBase(OpType opType, Tensor input, Tensor bias, Tensor output,
|
||||
bool PReLU_, float paramReLU_, int ph_, int pw_,
|
||||
int sh_ = 1, int sw_ = 1, int dh_ = 1, int dw_ = 1);
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 2; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
int getDh() const { return dh; }
|
||||
int getDw() const { return dw; }
|
||||
int getPh() const { return ph; }
|
||||
int getPw() const { return pw; }
|
||||
int getSh() const { return sh; }
|
||||
int getSw() const { return sw; }
|
||||
bool getPReLU() const { return PReLU; }
|
||||
float getParamReLU() const { return paramReLU; }
|
||||
|
||||
Tensor getBias() const { return bias; }
|
||||
|
||||
// optional<vector<Shape>> inferShape(const TensorVec &inputs) const
|
||||
// override;
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
class Conv2dReduce : public Conv2dReduceBase {
|
||||
public:
|
||||
Conv2dReduce(GraphObj *graph, Tensor input, Tensor bias, Tensor output,
|
||||
bool PReLU_, float paramReLU_, int ph_, int pw_, int sh_ = 1,
|
||||
int sw_ = 1, int dh_ = 1, int dw_ = 1);
|
||||
OP_CLONE(Conv2dReduce);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
};
|
||||
|
||||
class Conv2dReduceTranspose : public Conv2dReduceBase {
|
||||
public:
|
||||
Conv2dReduceTranspose(GraphObj *graph, Tensor input, Tensor bias,
|
||||
Tensor output, bool PReLU_, float paramReLU_, int ph_,
|
||||
int pw_, int sh_ = 1, int sw_ = 1, int dh_ = 1,
|
||||
int dw_ = 1);
|
||||
OP_CLONE(Conv2dReduceTranspose);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
};
|
||||
} // namespace infini
|
|
@ -709,7 +709,7 @@ class OnnxStub:
|
|||
ctx.push_output(f"{name}_{i}_{it.guid()}", it)
|
||||
for (i, it) in enumerate(op.outputs())
|
||||
]
|
||||
if ty == backend.OpType.Conv:
|
||||
if ty == backend.OpType.Conv or ty == backend.OpType.ConvNHWC:
|
||||
ph, pw, dh, dw, sh, sw = backend.conv_attrs_of(op)
|
||||
ctx.push_node(
|
||||
make_node(
|
||||
|
@ -723,7 +723,7 @@ class OnnxStub:
|
|||
group=op.inputs()[0].shape()[1] // op.inputs()[1].shape()[1],
|
||||
)
|
||||
)
|
||||
elif ty == backend.OpType.ConvTrans:
|
||||
elif ty == backend.OpType.ConvTrans or ty == backend.OpType.ConvTransNHWC:
|
||||
ph, pw, sh, sw, dh, dw, oph, opw = backend.conv_trans_attrs_of(op)
|
||||
ctx.push_node(
|
||||
make_node(
|
||||
|
@ -895,6 +895,26 @@ class OnnxStub:
|
|||
domain="nnet",
|
||||
)
|
||||
)
|
||||
elif ty == backend.OpType.Conv2dReduce:
|
||||
ctx.push_node(
|
||||
make_node(
|
||||
ty.name,
|
||||
inputs,
|
||||
outputs,
|
||||
name,
|
||||
domain="nnet",
|
||||
)
|
||||
)
|
||||
elif ty == backend.OpType.Conv2dReduceTranspose:
|
||||
ctx.push_node(
|
||||
make_node(
|
||||
ty.name,
|
||||
inputs,
|
||||
outputs,
|
||||
name,
|
||||
domain="nnet",
|
||||
)
|
||||
)
|
||||
elif ty == backend.OpType.MemBound:
|
||||
ctx.push_node(
|
||||
make_node(
|
||||
|
|
|
@ -57,6 +57,38 @@ Tensor GraphHandlerObj::convTransposed2d(Tensor input, Tensor weight,
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::convNHWC(Tensor input, Tensor weight, Tensor output,
|
||||
int ph, int pw, int sh, int sw, int dh,
|
||||
int dw) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<ConvNHWCObj>(std::move(input), std::move(weight),
|
||||
output, ph, pw, sh, sw, dh, dw);
|
||||
return output;
|
||||
} else {
|
||||
return g
|
||||
->addOp<ConvNHWCObj>(std::move(input), std::move(weight), output,
|
||||
ph, pw, sh, sw, dh, dw)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::convTransposed2dNHWC(Tensor input, Tensor weight,
|
||||
Tensor output, int ph, int pw,
|
||||
int sh, int sw, int dh, int dw,
|
||||
int oph, int opw) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<ConvTransposed2dNHWCObj>(
|
||||
std::move(input), std::move(weight), output, ph, pw, sh, sw, dh, dw,
|
||||
oph, opw);
|
||||
return output;
|
||||
} else {
|
||||
return g->addOp<ConvTransposed2dNHWCObj>(std::move(input),
|
||||
std::move(weight), output, ph,
|
||||
pw, sh, sw, dh, dw, oph, opw)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
|
||||
bool transB, Tensor bias, ActType act) {
|
||||
if (y) {
|
||||
|
|
|
@ -26,7 +26,8 @@ bool OperatorObj::isConcatOp() const { return type == OpType::Concat; }
|
|||
bool OperatorObj::isComputeOp() const {
|
||||
return type == OpType::Conv || type == OpType::Matmul ||
|
||||
type == OpType::ConvTrans || type == OpType::ConvTransNHWC ||
|
||||
type == OpType::G2BMM || type == OpType::GBMM;
|
||||
type == OpType::G2BMM || type == OpType::GBMM ||
|
||||
type == OpType::ConvNHWC;
|
||||
}
|
||||
|
||||
bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "core/search_engine.h"
|
||||
#include "core/hash.h"
|
||||
#include "core/runtime.h"
|
||||
#include "ffi/ffi_callback.h"
|
||||
#include "nnet/dbg.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
@ -348,8 +349,8 @@ std::vector<Graph> SearchEngine::searchMutation(const MetaGraph &metaGraph) {
|
|||
// // HACK: only try the first one for debug
|
||||
// if (mutatedGraphs.size() > 2)
|
||||
// mutatedGraphs.resize(2);
|
||||
if (mutatedGraphs.size() >= 2)
|
||||
mutatedGraphs = {mutatedGraphs[1]};
|
||||
// if (mutatedGraphs.size() >= 2)
|
||||
// mutatedGraphs = {mutatedGraphs[1]};
|
||||
for (auto graph : graphs) {
|
||||
for (auto mutatedGraph : mutatedGraphs) {
|
||||
std::vector<Operator> ops;
|
||||
|
@ -455,6 +456,9 @@ std::vector<Graph> SearchEngine::partitionGraph(const Graph graph) {
|
|||
}
|
||||
|
||||
double SearchEngine::getEstimatedGraphPerf(Graph graph) {
|
||||
// dbg(graph);
|
||||
// // hkz
|
||||
// callback::exportONNX(graph, "a.onnx");
|
||||
return runtimeExec->getPerfTime(graph, false, true, true);
|
||||
}
|
||||
|
||||
|
@ -502,6 +506,7 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
|
|||
|
||||
auto bestGraph = make_ref<GraphObj>(runtimeExec, chainOps);
|
||||
// Eliminate transpose and reshape operators
|
||||
// FIXME: current Relu only support 3D and 4D tensors
|
||||
if (auto eliminatedGraph = mutator->eliminateVertically(
|
||||
make_ref<GraphObj>(runtimeExec, chainOps)))
|
||||
bestGraph = eliminatedGraph;
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include "core/perf_engine.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda_profiler_api.h"
|
||||
#include "nnet/dbg.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/matmul.h"
|
||||
#ifdef INFINI_USE_TVM
|
||||
|
|
|
@ -69,6 +69,7 @@ void export_values(py::module &m) {
|
|||
.VALUE(OpType, Matmul)
|
||||
.VALUE(OpType, ConvTrans)
|
||||
.VALUE(OpType, ConvTransNHWC)
|
||||
.VALUE(OpType, ConvNHWC)
|
||||
.VALUE(OpType, G2BMM)
|
||||
.VALUE(OpType, GBMM)
|
||||
.VALUE(OpType, Pad)
|
||||
|
@ -100,6 +101,8 @@ void export_values(py::module &m) {
|
|||
.VALUE(OpType, Abs)
|
||||
.VALUE(OpType, Resize)
|
||||
.VALUE(OpType, Dropout)
|
||||
.VALUE(OpType, Conv2dReduce)
|
||||
.VALUE(OpType, Conv2dReduceTranspose)
|
||||
.VALUE(OpType, MemBound)
|
||||
.VALUE(OpType, Any)
|
||||
.export_values();
|
||||
|
@ -144,17 +147,32 @@ static Ref<RuntimeObj> intelcpu_runtime() { return make_ref<MklRuntimeObj>(); }
|
|||
#endif
|
||||
|
||||
static std::tuple<int, int, int, int, int, int> conv_attrs_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Conv);
|
||||
auto conv = dynamic_cast<const ConvObj *>(op.get());
|
||||
IT_ASSERT(op->getOpType() == OpType::Conv ||
|
||||
op->getOpType() == OpType::ConvNHWC);
|
||||
auto conv = dynamic_cast<const ConvBaseObj *>(op.get());
|
||||
return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(),
|
||||
conv->getDw(), conv->getSh(), conv->getSw());
|
||||
}
|
||||
|
||||
static std::tuple<int, int, int, int, int, int, int, int>
|
||||
conv_trans_attrs_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::ConvTrans);
|
||||
auto conv = dynamic_cast<const ConvTransposed2dObj *>(op.get());
|
||||
auto [oph, opw] = conv->getOutputPadding();
|
||||
IT_ASSERT(op->getOpType() == OpType::ConvTrans ||
|
||||
op->getOpType() == OpType::ConvTransNHWC);
|
||||
auto conv = dynamic_cast<const ConvBaseObj *>(op.get());
|
||||
int oph, opw;
|
||||
|
||||
if (op->getOpType() == OpType::ConvTrans) {
|
||||
auto _conv = dynamic_cast<const ConvTransposed2dObj *>(op.get());
|
||||
auto output_pad = _conv->getOutputPadding();
|
||||
oph = output_pad.first;
|
||||
opw = output_pad.second;
|
||||
} else {
|
||||
auto _conv = dynamic_cast<const ConvTransposed2dNHWCObj *>(op.get());
|
||||
auto output_pad = _conv->getOutputPadding();
|
||||
oph = output_pad.first;
|
||||
opw = output_pad.second;
|
||||
}
|
||||
|
||||
return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(),
|
||||
conv->getDw(), conv->getSh(), conv->getSw(), oph,
|
||||
opw);
|
||||
|
@ -328,6 +346,9 @@ void init_graph_builder(py::module &m) {
|
|||
"tensor_type"_a = TensorType::Other)
|
||||
.def("conv", &Handler::conv, policy::move)
|
||||
.def("convTransposed2d", &Handler::convTransposed2d, policy::move)
|
||||
.def("convNHWC", &Handler::convNHWC, policy::move)
|
||||
.def("convtransposed2dNHWC", &Handler::convTransposed2dNHWC,
|
||||
policy::move)
|
||||
.def("matmul", &Handler::matmul, policy::move)
|
||||
.def("batchNorm", &Handler::batchNorm, policy::move)
|
||||
.def("maxPool", &Handler::maxPool, policy::move)
|
||||
|
@ -386,6 +407,7 @@ void export_test_model(py::module &m) {
|
|||
#ifdef USE_CUDA
|
||||
m.def("runInfoGAN", &runInfoGAN)
|
||||
.def("getGANGraph", &getGANGraph)
|
||||
.def("getFSRCNNGraph", &getFSRCNNGraph)
|
||||
.def("getLongformer", &getLongformer)
|
||||
.def("getConvtransposedNHWC", &getConvtransposedNHWC)
|
||||
.def("optimizeGraph", &optimizeGraph, "graph"_a, "runtime"_a,
|
||||
|
|
|
@ -52,7 +52,7 @@ class convCudnn : public Kernel {
|
|||
cudnnFilterDescriptor_t, cudnnTensorDescriptor_t,
|
||||
cudnnConvolutionDescriptor_t, cudnnActivationDescriptor_t,
|
||||
cudnnTensorDescriptor_t>
|
||||
createCuDNNDescriptor(const Ref<ConvObj> &op,
|
||||
createCuDNNDescriptor(const Ref<ConvBaseObj> &op,
|
||||
const ConvCuDnnPerfRecord &record) const {
|
||||
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
|
@ -68,15 +68,23 @@ class convCudnn : public Kernel {
|
|||
|
||||
int channelsPerGrp = cpg, channels = c;
|
||||
|
||||
// set input format
|
||||
cudnnTensorFormat_t tensorFormat = (op->getOpType() == OpType::ConvNHWC)
|
||||
? CUDNN_TENSOR_NHWC
|
||||
: CUDNN_TENSOR_NCHW;
|
||||
|
||||
// get inputs
|
||||
cudnnTensorDescriptor_t inDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, channels, h, w));
|
||||
inDesc, tensorFormat, CUDNN_DATA_FLOAT, n, channels, h, w));
|
||||
|
||||
// get kernels
|
||||
cudnnFilterDescriptor_t knDesc;
|
||||
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
|
||||
// FIXME: filter data layout is not changed with input data layout
|
||||
// since FCRS shows better performance for NHWC inputs in some cases.
|
||||
// This should be tunable.
|
||||
checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
|
||||
CUDNN_TENSOR_NCHW, f,
|
||||
channelsPerGrp, r, s));
|
||||
|
@ -84,7 +92,7 @@ class convCudnn : public Kernel {
|
|||
cudnnTensorDescriptor_t biasDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
biasDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, f, 1, 1));
|
||||
biasDesc, tensorFormat, CUDNN_DATA_FLOAT, 1, f, 1, 1));
|
||||
|
||||
// get convlution descriptor
|
||||
cudnnConvolutionDescriptor_t convDesc;
|
||||
|
@ -125,18 +133,25 @@ class convCudnn : public Kernel {
|
|||
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
|
||||
cudnnTensorDescriptor_t outDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
|
||||
CUDNN_DATA_FLOAT, outn, outc,
|
||||
outh, outw));
|
||||
IT_ASSERT((vector{outn, outc, outh, outw}) ==
|
||||
op->getOutput()->getDims(),
|
||||
"cuDNN output shape mismatches with OP output shape");
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
outDesc, tensorFormat, CUDNN_DATA_FLOAT, outn, outc, outh, outw));
|
||||
|
||||
if (op->getOpType() == OpType::ConvNHWC) {
|
||||
IT_ASSERT((vector{outn, outh, outw, outc}) ==
|
||||
op->getOutput()->getDims(),
|
||||
"cuDNN output shape mismatches with OP output shape");
|
||||
} else {
|
||||
IT_ASSERT((vector{outn, outc, outh, outw}) ==
|
||||
op->getOutput()->getDims(),
|
||||
"cuDNN output shape mismatches with OP output shape");
|
||||
}
|
||||
|
||||
return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
|
||||
convDesc, actDesc, outDesc);
|
||||
}
|
||||
|
||||
bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
|
||||
bool cuDNNUnfused(const Ref<ConvBaseObj> &op,
|
||||
const ConvCuDnnPerfRecord &record,
|
||||
const CudaRuntimeObj *context) const {
|
||||
cudnnStatus_t stat;
|
||||
|
||||
|
@ -220,11 +235,12 @@ class convCudnn : public Kernel {
|
|||
ConvCuDnnPerfRecordObj ret;
|
||||
ret.time = std::numeric_limits<double>::max();
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
auto op = as<ConvObj>(_op);
|
||||
auto op = as<ConvBaseObj>(_op);
|
||||
int try_algo = op->getOpType() == OpType::ConvNHWC ? 2 : N_ALGO;
|
||||
// Both modes have the same performance. Only run cross-correlation.
|
||||
for (int mode = 1; mode < 2; mode++) {
|
||||
// Try every possible algorithm of convolution
|
||||
for (int algo = 0; algo < N_ALGO; algo++) {
|
||||
for (int algo = 0; algo < try_algo; algo++) {
|
||||
auto recordRef = make_ref<ConvCuDnnPerfRecordObj>();
|
||||
auto &record = *recordRef;
|
||||
record.mode = mode;
|
||||
|
@ -283,7 +299,7 @@ class convCudnn : public Kernel {
|
|||
|
||||
void compute(const Operator &_op, const PerfRecord &_record,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConvObj>(_op);
|
||||
auto op = as<ConvBaseObj>(_op);
|
||||
auto record = as<ConvCuDnnPerfRecordObj>(_record);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
bool success = cuDNNUnfused(op, record, context);
|
||||
|
@ -294,5 +310,8 @@ class convCudnn : public Kernel {
|
|||
REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float32, convCudnn,
|
||||
"Conv_cuDNN_CUDA_Float32");
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::ConvNHWC, DataType::Float32, convCudnn,
|
||||
"ConvNHWC_cuDNN_CUDA_Float32");
|
||||
|
||||
REGISTER_CONSTRUCTOR(1, ConvCuDnnPerfRecordObj::from_json);
|
||||
} // namespace infini
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
#include "operators/conv2dreduce.h"
|
||||
#include "cuda/cuda_conv2dreduce.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class Conv2dReduceCuda : public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op, const RuntimeObj *_context) const {
|
||||
auto op = as<Conv2dReduceBase>(_op);
|
||||
float *const input = (op->getInputs(0)->getRawDataPtr<float *>());
|
||||
float *const bias =
|
||||
op->getBias() ? (op->getBias()->getRawDataPtr<float *>()) : nullptr;
|
||||
float *const output = (op->getOutput()->getRawDataPtr<float *>());
|
||||
|
||||
auto dim = op->getInputs(0)->getDims();
|
||||
int n = dim[0], h = dim[1], w = dim[2], f = dim[3], r = dim[4],
|
||||
s = dim[5];
|
||||
int dh = op->getDh(), dw = op->getDw();
|
||||
int sh = op->getSh(), sw = op->getSw();
|
||||
int ph = op->getPh(), pw = op->getPw();
|
||||
auto odim = op->getOutput()->getDims();
|
||||
int oh = odim[1], ow = odim[2];
|
||||
bool PReLU = op->getPReLU();
|
||||
// float paramReLU = op->getParamReLU();
|
||||
|
||||
auto opType = op->getOpType();
|
||||
|
||||
if (opType == OpType::Conv2dReduce) {
|
||||
conv2dreduce_kernel(input, bias, output, PReLU, n, h, w, f, r, s,
|
||||
oh, ow, ph, pw, sh, sw, dh, dw);
|
||||
} else {
|
||||
convTranspose2dreduce_kernel(input, bias, output, PReLU, n, h, w, f,
|
||||
r, s, oh, ow, ph, pw, sh, sw, dh, dw);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Conv2dReduce, DataType::Float32,
|
||||
Conv2dReduceCuda, "Conv2dReduce_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Conv2dReduceTranspose, DataType::Float32,
|
||||
Conv2dReduceCuda, "Conv2dReduceTranspose_CUDA_Float32");
|
||||
|
||||
} // namespace infini
|
|
@ -1,4 +1,5 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include "nnet/dbg.h"
|
||||
|
||||
using dtype = float;
|
||||
|
||||
|
@ -40,18 +41,71 @@ __global__ void conv2dreduce_kernel_(float *__restrict__ input,
|
|||
output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm;
|
||||
}
|
||||
}
|
||||
__global__ void convTranspose2dreduce_kernel2_(
|
||||
float *__restrict__ input, float *__restrict__ bias,
|
||||
float *__restrict__ output, const bool PReLU, const int n, const int f,
|
||||
const int h, const int w, const int oh, const int ow, const int r,
|
||||
const int s, const int ph, const int pw, const int dh, const int dw,
|
||||
const int sh, const int sw) {
|
||||
int warp_id = (blockDim.x / 32) * blockIdx.x + threadIdx.x / 32;
|
||||
int lane = threadIdx.x % 32;
|
||||
int nid = warp_id / (f * oh * ow);
|
||||
int fid = (warp_id - nid * (f * oh * ow)) / (oh * ow);
|
||||
int hid = (warp_id - nid * (f * oh * ow) - fid * (oh * ow)) / ow;
|
||||
int wid = warp_id % ow;
|
||||
if (hid >= oh || wid >= ow || nid > n || fid > f)
|
||||
return;
|
||||
|
||||
const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk,
|
||||
nchunck = h * hchunk;
|
||||
float *nfinput = input + nid * nchunck + fid * fchunck;
|
||||
// view as conv, the true ph and pw
|
||||
int tph = r - ph - 1, tpw = s - pw - 1;
|
||||
int th = (h - 1) * sh + 1, tw = (w - 1) * sw + 1;
|
||||
|
||||
float imm = 0.0;
|
||||
int ihst = hid - tph;
|
||||
int iwst = wid - tpw;
|
||||
for (int idx = lane; idx < r * s; idx += 32) {
|
||||
int ri = idx / s;
|
||||
int si = idx % s;
|
||||
int ihid = ihst + r - ri - 1;
|
||||
int iwid = iwst + s - si - 1;
|
||||
if (ihid >= 0 && ihid < th && iwid >= 0 && iwid < tw &&
|
||||
(ihid % sh == 0) && (iwid % sw == 0)) {
|
||||
imm += *(nfinput + (ihid / sh) * hchunk + (iwid / sw) * wchunk +
|
||||
ri * s + si);
|
||||
}
|
||||
}
|
||||
|
||||
for (int k = 16; k > 0; k >>= 1) {
|
||||
imm += __shfl_down_sync(0xffffffff, imm, k); // sum
|
||||
}
|
||||
if (lane == 0) {
|
||||
if (bias) {
|
||||
imm += bias[fid];
|
||||
}
|
||||
if (PReLU) {
|
||||
imm = imm > 0.0 ? imm : 0.0;
|
||||
}
|
||||
output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void convTranspose2dreduce_kernel_(
|
||||
float *__restrict__ input, float *__restrict__ bias,
|
||||
float *__restrict__ output, const bool PReLU, const int n, const int f,
|
||||
const int h, const int w, const int oh, const int ow, const int r,
|
||||
const int s, const int ph, const int pw, const int dh, const int dw,
|
||||
const int sh, const int sw) {
|
||||
const int sh, const int sw, const int block_x_num, const int block_y_num) {
|
||||
// assert dh = dw = 1
|
||||
int nid = blockIdx.x, fid = blockIdx.y;
|
||||
int hid = threadIdx.x, wid = threadIdx.y;
|
||||
int nid = blockIdx.x / block_x_num, fid = blockIdx.y / block_y_num;
|
||||
int hid = (blockIdx.x % block_x_num) * blockDim.x + threadIdx.x,
|
||||
wid = (blockIdx.y % block_y_num) * blockDim.y + threadIdx.y;
|
||||
if (hid >= oh || wid >= ow)
|
||||
return;
|
||||
const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk,
|
||||
nchunck = n * hchunk;
|
||||
nchunck = h * hchunk;
|
||||
float *nfinput = input + nid * nchunck + fid * fchunck;
|
||||
// view as conv, the true ph and pw
|
||||
int tph = r - ph - 1, tpw = s - pw - 1;
|
||||
|
@ -162,8 +216,22 @@ void convTranspose2dreduce_kernel(float *input, float *bias, float *output,
|
|||
reduce_4x4<<<(M * N + 127) / 128, 128>>>(input, output, act, n, f, oh,
|
||||
ow, h, w);
|
||||
} else {
|
||||
puts("why use this conv2dreduce");
|
||||
convTranspose2dreduce_kernel_<<<grid, block, 0>>>(
|
||||
// puts("why use this conv2dreduce");
|
||||
// block.x = 32;
|
||||
// block.y = 32;
|
||||
// int block_x_num = (oh + block.x - 1) / block.x;
|
||||
// int block_y_num = (ow + block.y - 1) / block.y;
|
||||
// grid.x = n * (block_x_num);
|
||||
// grid.y = f * (block_y_num);
|
||||
// convTranspose2dreduce_kernel_<<<grid, block, 0>>>(
|
||||
// input, bias, output, (bool)act, n, f, h, w, oh, ow, r, s, ph, pw,
|
||||
// dh, dw, sh, sw, block_x_num, block_y_num);
|
||||
|
||||
block.x = 128;
|
||||
block.y = 1;
|
||||
grid.x = (n * f * ow * oh + block.x / 32 - 1) / (block.x / 32);
|
||||
grid.y = 1;
|
||||
convTranspose2dreduce_kernel2_<<<grid, block, 0>>>(
|
||||
input, bias, output, (bool)act, n, f, h, w, oh, ow, r, s, ph, pw,
|
||||
dh, dw, sh, sw);
|
||||
}
|
||||
|
|
|
@ -31,8 +31,9 @@ class ActivationCudnn : public CudaKernelWithoutConfig {
|
|||
n = dim[0], c = dim[1], h = dim[2], w = dim[3];
|
||||
} else if (dim.size() == 3) {
|
||||
n = 1, c = dim[0], h = dim[1], w = dim[2];
|
||||
} else
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
||||
// get inputs
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
|
||||
|
|
|
@ -86,6 +86,60 @@ Graph getGANGraph(int batch, Runtime runtime, int nLayers, int modelId) {
|
|||
return g;
|
||||
}
|
||||
|
||||
// NHWC
|
||||
Graph getFSRCNNGraph(int batch, Runtime runtime) {
|
||||
// n, c, h, w, f, r, s, stride, pad, dilation, has_pReLU
|
||||
const DetailedConfigs fsrcnn_config = {
|
||||
{batch, 1, 32, 32, 56, 5, 5, 1, 2, 1, true},
|
||||
{batch, 56, 32, 32, 12, 1, 1, 1, 0, 1, true},
|
||||
{batch, 12, 32, 32, 12, 3, 3, 1, 1, 1, false},
|
||||
{batch, 12, 32, 32, 12, 3, 3, 1, 1, 1, false},
|
||||
{batch, 12, 32, 32, 12, 3, 3, 1, 1, 1, false},
|
||||
{batch, 12, 32, 32, 12, 3, 3, 1, 1, 1, true},
|
||||
{batch, 12, 32, 32, 56, 1, 1, 1, 0, 1, true},
|
||||
{batch, 56, 32, 32, 1, 9, 9, 1, 3, 4, false} // ConvTransNHWC
|
||||
// n, f, h, w, c, r, s, stride, pad, dilation, has_pReLU
|
||||
};
|
||||
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
|
||||
Tensor input;
|
||||
{
|
||||
auto &[n, c, h, w, f, r, s, stride, pad, dilation, has_pReLU] =
|
||||
fsrcnn_config[0];
|
||||
input = g->addTensor({batch, h, w, c}, DataType::Float32,
|
||||
TensorType::Input);
|
||||
}
|
||||
|
||||
for (int i = 0; i < (int)fsrcnn_config.size() - 1; ++i) {
|
||||
// auto [channel, kernelSize, pad, stride, tanh] = configs[i];
|
||||
auto &[n, c, h, w, f, r, s, stride, pad, dilation, has_pReLU] =
|
||||
fsrcnn_config[i];
|
||||
IT_ASSERT(input->getDims()[3] == c);
|
||||
auto weight = g->addTensor({f, r, s, c}, DataType::Float32,
|
||||
TensorType::Initialized); // f, r, s, c
|
||||
input = g->addOp<ConvNHWCObj>(input, weight, nullptr, pad, pad, stride,
|
||||
stride, 1, 1)
|
||||
->getOutput();
|
||||
if (has_pReLU) {
|
||||
input = g->addOp<ReluObj>(input, nullptr)->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
// last operator is a ConvTransNHWC
|
||||
{
|
||||
auto &[n, f, h, w, c, r, s, stride, pad, dilation, has_pReLU] =
|
||||
fsrcnn_config[fsrcnn_config.size() - 1];
|
||||
IT_ASSERT(input->getDims()[3] == f);
|
||||
auto weight = g->addTensor({f, r, s, c}, DataType::Float32,
|
||||
TensorType::Initialized); // f, r, s, c
|
||||
input = g->addOp<ConvTransposed2dNHWCObj>(input, weight, nullptr, pad,
|
||||
pad, stride, stride, 1, 1)
|
||||
->getOutput();
|
||||
}
|
||||
return g;
|
||||
}
|
||||
|
||||
Graph getLongformer(Runtime runtime, int bs) {
|
||||
const int seqlen = 10000, w = 1000, featlen = 512, heads = 8, d = 4;
|
||||
const int hidden = featlen, hiddenPerHead = hidden / heads;
|
||||
|
@ -165,8 +219,8 @@ Graph getLongformer(Runtime runtime, int bs) {
|
|||
g->addOpWithOutputs<ReshapeObj>(q0, q1);
|
||||
g->addOpWithOutputs<ReshapeObj>(k0, k1);
|
||||
g->addOpWithOutputs<ReshapeObj>(v0, v1);
|
||||
// For example, when perm=(1, 0, 2), given an input tensor of shape (1, 2,
|
||||
// 3), the output shape will be (2, 1, 3).
|
||||
// For example, when perm=(1, 0, 2), given an input tensor of shape (1,
|
||||
// 2, 3), the output shape will be (2, 1, 3).
|
||||
g->addOpWithOutputs<TransposeObj>(q1, q2, vector{0, 2, 1, 3});
|
||||
g->addOpWithOutputs<TransposeObj>(k1, k2, vector{0, 2, 1, 3});
|
||||
g->addOpWithOutputs<TransposeObj>(v1, v2, vector{0, 2, 1, 3});
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
#include "operators/GBMM.h"
|
||||
#include "operators/any.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/conv2dreduce.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/membound.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
|
@ -98,6 +99,10 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
|
|||
out_graphs.emplace_back(g);
|
||||
return;
|
||||
}
|
||||
if (infini::Graph g = transformConv1xk(computeOps[0])) {
|
||||
out_graphs.emplace_back(g);
|
||||
return;
|
||||
}
|
||||
if (Graph g = transformG2bmm(computeOps[0])) {
|
||||
out_graphs.emplace_back(g);
|
||||
return;
|
||||
|
@ -110,7 +115,12 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
|
|||
out_graphs.emplace_back(g);
|
||||
return;
|
||||
}
|
||||
if (infini::Graph g = transformConv1xk(computeOps[0])) {
|
||||
if (infini::Graph g = transformConvToGEMMReduce(computeOps[0])) {
|
||||
out_graphs.emplace_back(g);
|
||||
return;
|
||||
}
|
||||
|
||||
if (infini::Graph g = transformConvTranposeToGEMMReduce(computeOps[0])) {
|
||||
out_graphs.emplace_back(g);
|
||||
return;
|
||||
}
|
||||
|
@ -522,7 +532,9 @@ Graph NMutator::transformConvtransposed1x1(Operator _op) {
|
|||
if (h != 1 || w != 1)
|
||||
return {};
|
||||
IT_ASSERT_TODO(ph == pw);
|
||||
IT_ASSERT_TODO(tie(sh, sw) == tuple(1, 1));
|
||||
if (tie(sh, sw) != tuple(1, 1)) {
|
||||
return nullptr;
|
||||
}
|
||||
IT_ASSERT_TODO(tie(dh, dw) == tuple(1, 1));
|
||||
auto g = make_ref<GraphObj>(runtime);
|
||||
// NHWF
|
||||
|
@ -543,6 +555,80 @@ Graph NMutator::transformConvtransposed1x1(Operator _op) {
|
|||
return g;
|
||||
}
|
||||
|
||||
Graph NMutator::transformConvToGEMMReduce(Operator _op) {
|
||||
auto op = as<ConvNHWCObj>(_op);
|
||||
if (!op)
|
||||
return nullptr;
|
||||
const auto &A = op->getInputs()[0];
|
||||
const auto &W = op->getInputs()[1];
|
||||
const auto &[n, c, h, w, f, r, s] = op->getNCHWFRS();
|
||||
const auto &[ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
const Shape inputDims = op->getInputs(0)->getDims();
|
||||
const Shape weightDims = op->getInputs(1)->getDims();
|
||||
const Shape outputDims = op->getOutput()->getDims();
|
||||
IT_ASSERT(weightDims[0] == f);
|
||||
IT_ASSERT(weightDims[1] == r);
|
||||
IT_ASSERT(weightDims[2] == s);
|
||||
IT_ASSERT(weightDims[3] == c);
|
||||
IT_ASSERT(inputDims[0] == n);
|
||||
IT_ASSERT(inputDims[1] == h);
|
||||
IT_ASSERT(inputDims[2] == w);
|
||||
IT_ASSERT(inputDims[3] == c);
|
||||
const DataType dtype = A->getDType();
|
||||
auto g = make_ref<GraphObj>(runtime);
|
||||
auto newA = g->addTensor(
|
||||
{inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]}, dtype);
|
||||
|
||||
// // If use Matmul with transpose 0,0
|
||||
// auto newW = g->addTensor(
|
||||
// {weightDims[3], weightDims[0] * weightDims[1] * weightDims[2]}, dtype);
|
||||
|
||||
// If use Matmul with transpose 0, 1
|
||||
auto newW = g->addTensor(
|
||||
{weightDims[0] * weightDims[1] * weightDims[2], weightDims[3]},
|
||||
dtype);
|
||||
g->addOpWithOutputs<ReshapeObj>(g->cloneTensor(A), newA, newA->getDims());
|
||||
g->addOpWithOutputs<ReshapeObj>(g->cloneTensor(W), newW, newW->getDims());
|
||||
Tensor newO = g->addOp<MatmulObj>(newA, newW, nullptr, 0, 1)->getOutput();
|
||||
auto new1 = g->addTensor({n, h, w, f, r, s}, dtype);
|
||||
g->addOpWithOutputs<ReshapeObj>(newO, new1, new1->getDims());
|
||||
g->addOpWithOutputs<Conv2dReduce>(
|
||||
new1, nullptr, g->cloneTensor(op->getOutput()), false, 0.f, ph, pw);
|
||||
return g;
|
||||
}
|
||||
|
||||
Graph NMutator::transformConvTranposeToGEMMReduce(Operator _op) {
|
||||
auto op = as<ConvTransposed2dNHWCObj>(_op);
|
||||
if (!op)
|
||||
return nullptr;
|
||||
const auto &A = op->getInputs()[0];
|
||||
const auto &W = op->getInputs()[1];
|
||||
// f is the de-facto input channel for ConvTranspose
|
||||
const auto &[n, c, h, w, f, r, s] = op->getNCHWFRS();
|
||||
const auto &[ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
const Shape inputDims = op->getInputs(0)->getDims();
|
||||
const Shape weightDims = op->getInputs(1)->getDims();
|
||||
const Shape outputDims = op->getOutput()->getDims();
|
||||
const DataType dtype = A->getDType();
|
||||
auto g = make_ref<GraphObj>(runtime);
|
||||
auto newA = g->addTensor( // [N,H,W,F]
|
||||
{inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]}, dtype);
|
||||
auto newW = g->addTensor( // [F, CRS]
|
||||
{weightDims[0], weightDims[1] * weightDims[2] * weightDims[3]},
|
||||
dtype); // HACK: this should be a transpose
|
||||
|
||||
g->addOpWithOutputs<ReshapeObj>(g->cloneTensor(A), newA, newA->getDims());
|
||||
g->addOpWithOutputs<ReshapeObj>(g->cloneTensor(W), newW, newW->getDims());
|
||||
// newO [NHW, CRS]
|
||||
Tensor newO = g->addOp<MatmulObj>(newA, newW, nullptr, 0, 0)->getOutput();
|
||||
auto new1 = g->addTensor({n, h, w, c, r, s}, dtype);
|
||||
g->addOpWithOutputs<ReshapeObj>(newO, new1, new1->getDims());
|
||||
// [NHW, CRS] -> [N,H,W,C]
|
||||
g->addOpWithOutputs<Conv2dReduceTranspose>(
|
||||
new1, nullptr, g->cloneTensor(op->getOutput()), false, 0.f, ph, pw);
|
||||
return g;
|
||||
}
|
||||
|
||||
// Graph NMutator::transformConvtransposed(Operator _op) {
|
||||
// auto op = as<ConvTransposed2dNHWCObj>(_op);
|
||||
// if (!op)
|
||||
|
|
|
@ -114,6 +114,75 @@ optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
|
|||
return {{{on, oc, oh, ow}}};
|
||||
}
|
||||
|
||||
void ConvNHWCObj::setAuxilaryAttributes(PaddingMode mode) {
|
||||
const Tensor &input = inputs[0];
|
||||
const Tensor &weight = inputs[1];
|
||||
n = input->getDims()[0], c = input->getDims()[3], h = input->getDims()[1],
|
||||
w = input->getDims()[2], f = weight->getDims()[0], r = weight->getDims()[1],
|
||||
s = weight->getDims()[2];
|
||||
if (mode == PaddingMode::Same) {
|
||||
int oh = h / sh;
|
||||
int ow = w / sw;
|
||||
ph = (h - oh * sh + (r - sh) * dh) / 2;
|
||||
pw = (w - ow * sw + (s - sw) * dw) / 2;
|
||||
} else if (mode == PaddingMode::Valid) {
|
||||
ph = pw = 0;
|
||||
}
|
||||
}
|
||||
|
||||
ConvNHWCObj::ConvNHWCObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
|
||||
int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias,
|
||||
ActType act)
|
||||
: ConvBaseObj(OpType::ConvNHWC, {input, weight}, output, ph, pw, sh, sw, dh, dw,
|
||||
input, weight, act) {
|
||||
if (bias)
|
||||
IT_TODO_HALT();
|
||||
setAuxilaryAttributes(PaddingMode::Other);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
ConvNHWCObj::ConvNHWCObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
|
||||
PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias,
|
||||
ActType act)
|
||||
: ConvBaseObj(OpType::ConvNHWC, {input, weight}, output, mode, sh, sw, dh, dw,
|
||||
input, weight, act) {
|
||||
if (bias)
|
||||
IT_TODO_HALT();
|
||||
setAuxilaryAttributes(mode);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> ConvNHWCObj::inferShape(const TensorVec &inputs) const {
|
||||
const auto &input = inputs[0], &weight = inputs[1];
|
||||
auto n = input->getDims()[0];
|
||||
auto h = input->getDims()[1];
|
||||
auto w = input->getDims()[2];
|
||||
auto f = weight->getDims()[0];
|
||||
auto r = weight->getDims()[1];
|
||||
auto s = weight->getDims()[2];
|
||||
int on = n, oc = f;
|
||||
int oh = 0, ow = 0;
|
||||
// For NCHW+FCRS layout, C of input is divisable by C of weight
|
||||
if (input->getDims()[3] % weight->getDims()[3] != 0)
|
||||
return {};
|
||||
// Set padding size
|
||||
if (padding == PaddingMode::Other) {
|
||||
oh = (h - (r - sh) * dh + ph * 2) / sh;
|
||||
ow = (w - (s - sw) * dw + pw * 2) / sw;
|
||||
} else if (padding == PaddingMode::Same) {
|
||||
oh = h / sh;
|
||||
ow = w / sw;
|
||||
// ph = (h - oh * sh + (r - sh) * dh) / 2;
|
||||
// pw = (w - ow * sw + (s - sw) * dw) / 2;
|
||||
} else if (padding == PaddingMode::Valid) {
|
||||
int ph = 0;
|
||||
int pw = 0;
|
||||
oh = (h - (r - sh) * dh + ph * 2) / sh;
|
||||
ow = (w - (s - sw) * dw + pw * 2) / sw;
|
||||
}
|
||||
return {{{on, oh, ow, oc}}};
|
||||
}
|
||||
|
||||
ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
|
||||
Tensor weight, Tensor output, int ph,
|
||||
int pw, int sh, int sw, int dh, int dw,
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
#include "operators/conv2dreduce.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
Conv2dReduceBase::Conv2dReduceBase(OpType opType, Tensor input, Tensor bias_,
|
||||
Tensor output, bool PReLU_, float paramReLU_,
|
||||
int ph_, int pw_, int sh_, int sw_, int dh_,
|
||||
int dw_)
|
||||
: OperatorObj(opType, {input}, {output}), bias(bias_), ph(ph_), pw(pw_),
|
||||
sh(sh_), sw(sw_), dh(dh_), dw(dw_), PReLU(PReLU_), paramReLU(paramReLU_) {
|
||||
// expect input shape is (n, h, w, f, r, s)
|
||||
auto inputShape = input->getDims();
|
||||
IT_ASSERT(inputShape.size() == 6);
|
||||
n = inputShape[0];
|
||||
h = inputShape[1];
|
||||
w = inputShape[2];
|
||||
f = inputShape[3];
|
||||
r = inputShape[4];
|
||||
s = inputShape[5];
|
||||
|
||||
if (bias) {
|
||||
auto biasShape = bias->getDims();
|
||||
IT_ASSERT(biasShape.size() == 1);
|
||||
IT_ASSERT(biasShape[0] == f);
|
||||
}
|
||||
}
|
||||
|
||||
std::string Conv2dReduceBase::toString() const {
|
||||
std::ostringstream os;
|
||||
os << OpRegistry::getOpName(getOpType()) << "[" << getGuid() << "]";
|
||||
os << "(";
|
||||
if (inputs.size() == 2) {
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
os << vecToString(inputs[1]->getDims()) << ",";
|
||||
} else {
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
}
|
||||
os << "p=[" << ph << "," << pw << "],";
|
||||
os << "s=[" << sh << "," << sw << "],";
|
||||
os << "d=[" << dh << "," << dw << "],";
|
||||
os << "PReLU=" << (PReLU ? "true" : "false") << ",";
|
||||
// os << "act=" << enum_to_underlying(act) << ",";
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
if (bias != nullptr) {
|
||||
os << "bias=" << bias->getGuid() << ",";
|
||||
}
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
std::vector<int> Conv2dReduceBase::getWorkloadVector() const {
|
||||
return {enum_to_underlying(type), n, h, w, f, r, s, ph, pw, sh, sw, dh, dw};
|
||||
}
|
||||
|
||||
std::vector<int> Conv2dReduceBase::getOpAttrVector() const {
|
||||
return {enum_to_underlying(type), ph, pw, sh, sw, dh, dw};
|
||||
}
|
||||
|
||||
Conv2dReduce::Conv2dReduce(GraphObj *graph, Tensor input, Tensor bias,
|
||||
Tensor output, bool PReLU_, float paramReLU_,
|
||||
int ph_, int pw_, int sh_, int sw_, int dh_, int dw_)
|
||||
: Conv2dReduceBase(OpType::Conv2dReduce, input, bias, output, PReLU_,
|
||||
paramReLU_, ph_, pw_, sh_, sw_, dh_, dw_) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
Conv2dReduce::inferShape(const TensorVec &inputs) const {
|
||||
// const auto &input = inputs[0], &bias = inputs[1];
|
||||
int on = n, of = f;
|
||||
int oh = (h + ph * 2 - dh * (r - 1) - 1) / sh + 1;
|
||||
int ow = (w + pw * 2 - dw * (s - 1) - 1) / sw + 1;
|
||||
|
||||
return {{{on, oh, ow, of}}};
|
||||
}
|
||||
|
||||
Conv2dReduceTranspose::Conv2dReduceTranspose(GraphObj *graph, Tensor input,
|
||||
Tensor bias, Tensor output,
|
||||
bool PReLU_, float paramReLU_,
|
||||
int ph_, int pw_, int sh_, int sw_,
|
||||
int dh_, int dw_)
|
||||
: Conv2dReduceBase(OpType::Conv2dReduceTranspose, input, bias, output,
|
||||
PReLU_, paramReLU_, ph_, pw_, sh_, sw_, dh_, dw_) {
|
||||
IT_ASSERT(dh_ == 1);
|
||||
IT_ASSERT(dw_ == 1);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
Conv2dReduceTranspose::inferShape(const TensorVec &inputs) const {
|
||||
// const auto &input = inputs[0], &bias = inputs[1];
|
||||
int on = n, of = f;
|
||||
int oh = (h - 1) * sh - 2 * ph + dh * (r - 1) + 1;
|
||||
int ow = (w - 1) * sw - 2 * pw + dw * (s - 1) + 1;
|
||||
|
||||
return {{{on, oh, ow, of}}};
|
||||
}
|
||||
} // namespace infini
|
|
@ -43,6 +43,42 @@ void testConvCudnn(
|
|||
gCuda->print();
|
||||
}
|
||||
|
||||
void testConvNHWCCudnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
vector<float> ansVec) {
|
||||
// Construct Runtime and graph for CPU and CUDA
|
||||
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({1, 4, 4, 3}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(generator);
|
||||
w0Cpu->setData(generator);
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv =
|
||||
gCuda->addOp<ConvNHWCObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 2, 1, 1, 2);
|
||||
// allocate CUDA memory
|
||||
gCuda->dataMalloc();
|
||||
// Execute on CUDA
|
||||
cuda->run(gCuda);
|
||||
// copy output from CUDA to CPU
|
||||
auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
|
||||
o0Cpu->print();
|
||||
o0Cpu->printData();
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(o0Cpu->equalData(ansVec));
|
||||
// print a tensor/operator/graph by print()
|
||||
gCuda->print();
|
||||
}
|
||||
|
||||
TEST(cuDNN_Conv, run) {
|
||||
testConvCudnn(OneGenerator(),
|
||||
vector<float>{12, 12, 18, 18, 12, 12, 18, 18});
|
||||
|
@ -51,6 +87,14 @@ TEST(cuDNN_Conv, run) {
|
|||
vector<float>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
|
||||
}
|
||||
|
||||
TEST(cuDNN_Conv, runNHWC) {
|
||||
testConvNHWCCudnn(OneGenerator(),
|
||||
vector<float>{12., 12., 12., 12., 18., 18., 18., 18.});
|
||||
testConvNHWCCudnn(
|
||||
IncrementalGenerator(),
|
||||
vector<float>{3350, 7562, 2306, 5546, 9480, 24546, 7185, 20793});
|
||||
}
|
||||
|
||||
TEST(cuDNN_Conv, tune) {
|
||||
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
|
|
|
@ -30,7 +30,7 @@ def run_and_evaluate(runtime, g):
|
|||
runtime.run(g, True)
|
||||
print(f'getPerfTime = {runtime.getPerfTime(g, True, False, False)}')
|
||||
print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 1000, 1000)}')
|
||||
print(f'Cuda graph time = {runtime.timeWithCudaGraph(g)}')
|
||||
print(f'Cuda graph time = {runtime.timeWithCudaGraph(g, 100)}')
|
||||
|
||||
|
||||
def run_graph_get_output_as_torch_tensor(runtime, g):
|
||||
|
@ -111,6 +111,23 @@ def construct_conv(runtime, n, c, h, w, f, r, s, pad, stride, dilation):
|
|||
handler.conv(input, w, None, pad, pad, stride, stride, dilation, dilation)
|
||||
return handler.getGraph()
|
||||
|
||||
def construct_conv_nhwc(runtime, n, c, h, w, f, r, s, pad, stride, dilation):
|
||||
handler = ft.GraphHandler(runtime)
|
||||
# input = handler.tensor([1, 56, 32, 32], tensor_type=ft.TensorType.Input)
|
||||
# w = handler.tensor([12, 56, 1, 1], tensor_type=ft.TensorType.Initialized)
|
||||
# handler.conv(input, w, None, 0, 0, 1, 1, 1, 1)
|
||||
input = handler.tensor([n, h, w, c], tensor_type=ft.TensorType.Input)
|
||||
w = handler.tensor([f, r, s, c], tensor_type=ft.TensorType.Initialized)
|
||||
handler.convNHWC(input, w, None, pad, pad, stride, stride, dilation, dilation)
|
||||
return handler.getGraph()
|
||||
|
||||
def construct_convtranposed_nhwc(runtime, n, c, h, w, f, r, s, pad, stride, dilation):
|
||||
handler = ft.GraphHandler(runtime)
|
||||
input = handler.tensor([n, h, w, c], tensor_type=ft.TensorType.Input)
|
||||
w = handler.tensor([f, r, s, c], tensor_type=ft.TensorType.Initialized)
|
||||
handler.convtransposed2dNHWC(input, w, None, pad, pad, stride, stride, dilation, dilation)
|
||||
return handler.getGraph()
|
||||
|
||||
|
||||
def export_op_level_onnx(runtime):
|
||||
graphs = [
|
||||
|
@ -134,9 +151,13 @@ if __name__ == "__main__":
|
|||
# (construct_conv(runtime, 1, 12, 32, 32, 12, 3, 3, 1, 1, 1), 'conv3x3'), # FSRCNN Conv_4 3x3
|
||||
# ft.getGANGraph(batch, runtime, 5, 1)
|
||||
# (ft.getLongformer(runtime, 1), 'longformer.bs1'),
|
||||
(ft.getLongformer(runtime, 16), 'longformer.bs16'),
|
||||
# (ft.getLongformer(runtime, 16), 'longformer.bs16'),
|
||||
# construct_convTranspose2d(runtime)
|
||||
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/fsrcnn.bs1.onnx'), 'fsrcnn.bs1'),
|
||||
# (ft.getFSRCNNGraph(1, runtime), "fsrcnn.bs1"),
|
||||
# (ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16"),
|
||||
# (construct_conv_nhwc(runtime, 1, 56, 32, 32, 12, 1, 1, 0, 1, 1), 'conv1x1')
|
||||
(load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs1.onnx'), 'gcn.bs1'),
|
||||
]
|
||||
|
||||
for original_g, name in graphs:
|
||||
|
|
Loading…
Reference in New Issue