forked from jiuyuan/InfiniTensor
Add: ConvTransposed (#33)
* Add: convTransposed2d operator * Fix: IT_ASSERT namespace * Add: nullptr check in as for Ref * Fix: conv transpose operator and kernel * Fix: makes PerfEngine singleton * Add: ConvTransposed test * Fix: rebase to master (PerfRecord shared_ptr) * Revert: Ref with nullptr check Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
This commit is contained in:
parent
d39328afce
commit
8f67a5cc76
|
@ -42,7 +42,7 @@ using HashType = uint64_t; // compatible with std::hash
|
|||
#define _IT_ASSERT_2(name, info) \
|
||||
(static_cast<bool>(name) \
|
||||
? void(0) \
|
||||
: throw infini::Exception( \
|
||||
: throw ::infini::Exception( \
|
||||
std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \
|
||||
"] Assertion failed (" + #name + "): " + #info))
|
||||
#define _IT_ASSERT_1(name) _IT_ASSERT_2(name, "");
|
||||
|
|
|
@ -9,6 +9,10 @@ class PerfEngine {
|
|||
// TODO: Key should be OpPerfKey + Context(maybe implicat) to support
|
||||
// multiple candiate kernels.
|
||||
using Key = std::pair<KernelAttrs, OpPerfKey>;
|
||||
PerfEngine() = default;
|
||||
// PerfEngine is singleton
|
||||
PerfEngine(PerfEngine &other) = delete;
|
||||
PerfEngine &operator=(PerfEngine const &) = delete;
|
||||
|
||||
private:
|
||||
map<Key, PerfRecord> data;
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
#pragma once
|
||||
#include "core/common.h"
|
||||
#include <functional> // hash
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
class ConvObj : public OperatorObj {
|
||||
class ConvBaseObj : public OperatorObj {
|
||||
public:
|
||||
// When PaddingMode is Other, ConvObj will use padding size (ph, pw)
|
||||
// Otherwise, padding size (ph, pw) will be computed by padding mode
|
||||
|
@ -13,34 +13,33 @@ class ConvObj : public OperatorObj {
|
|||
Valid,
|
||||
};
|
||||
|
||||
private:
|
||||
protected:
|
||||
int ph, pw;
|
||||
int sh, sw;
|
||||
int dh, dw;
|
||||
ActType act;
|
||||
PaddingMode padding;
|
||||
// auxiliary attributes
|
||||
int n, c, h, w, f, r, s;
|
||||
// auxiliary attributes. Descripitions stand on a forward perspective,
|
||||
// i.e., convTransposed2d is not regarded as the backward of conv2d.
|
||||
int n; // batch size
|
||||
int c; // input/output channel for conv2d/convTransposed2d
|
||||
int h, w; // input shape (same for conv2d and convTranposed2d)
|
||||
int f; // output/input channel for conv2d/convTransposed2d
|
||||
int r, s; // weight shape
|
||||
|
||||
public:
|
||||
// Constructors for explicitly setting padding size
|
||||
ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph,
|
||||
int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1,
|
||||
Tensor bias = nullptr, ActType act = ActType::None);
|
||||
// Constructors for setting padding mode
|
||||
ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
|
||||
PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1,
|
||||
int dh = 1, int dw = 1, Tensor bias = nullptr,
|
||||
ActType act = ActType::None);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output, int ph, int pw,
|
||||
int sh, int sw, int dh, int dw, const Tensor &inputInConvFWD,
|
||||
const Tensor &weightInConvFWD);
|
||||
ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
|
||||
PaddingMode mode, int sh, int sw, int dh, int dw,
|
||||
const Tensor &inputInConvFWD, const Tensor &weightInConvFWD);
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 2; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
Tensor getBias() const { return inputs[2]; }
|
||||
ActType getAct() const { return act; }
|
||||
PaddingMode getPaddingMode() const { return padding; }
|
||||
pair<int, int> inferPaddingSize() const;
|
||||
|
||||
|
@ -53,7 +52,7 @@ class ConvObj : public OperatorObj {
|
|||
auto getNCHWFRS() const { return tuple(n, c, h, w, f, r, s); }
|
||||
auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
|
||||
int getChannelPerGroup() const { return inputs[1]->getDims()[1]; }
|
||||
int getNumGroups() const { return c / getChannelPerGroup(); }
|
||||
virtual int getNumGroups() const = 0;
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
|
@ -62,7 +61,56 @@ class ConvObj : public OperatorObj {
|
|||
* @brief Set the Auxilary Attributes: nchwrfs and padding (ph, pw) if
|
||||
* padding mode is set. This function should be called in constructor.
|
||||
*/
|
||||
void setAuxilaryAttributes(PaddingMode mode);
|
||||
virtual void setAuxilaryAttributes(PaddingMode mode) = 0;
|
||||
};
|
||||
|
||||
class ConvObj : public ConvBaseObj {
|
||||
private:
|
||||
ActType act;
|
||||
|
||||
public:
|
||||
ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph,
|
||||
int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1,
|
||||
Tensor bias = nullptr, ActType act = ActType::None);
|
||||
// Constructors for setting padding mode
|
||||
ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
|
||||
PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1,
|
||||
int dh = 1, int dw = 1, Tensor bias = nullptr,
|
||||
ActType act = ActType::None);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
ActType getAct() const { return act; }
|
||||
int getNumGroups() const override { return c / getChannelPerGroup(); }
|
||||
|
||||
private:
|
||||
void setAuxilaryAttributes(PaddingMode mode) override;
|
||||
};
|
||||
|
||||
class ConvTransposed2dObj : public ConvBaseObj {
|
||||
private:
|
||||
int oph, opw;
|
||||
int group;
|
||||
ActType act;
|
||||
|
||||
public:
|
||||
ConvTransposed2dObj(GraphObj *graph, Tensor input, Tensor weight,
|
||||
Tensor output, int ph, int pw, int sh = 1, int sw = 1,
|
||||
int dh = 1, int dw = 1, int oph = 0, int opw = 0,
|
||||
int group = 1, Tensor bias = nullptr,
|
||||
ActType act = ActType::None);
|
||||
// Constructors for setting padding mode
|
||||
ConvTransposed2dObj(GraphObj *graph, Tensor input, Tensor weight,
|
||||
Tensor output, PaddingMode mode = PaddingMode::Same,
|
||||
int sh = 1, int sw = 1, int dh = 1, int dw = 1,
|
||||
int oph = 0, int opw = 0, int group = 1,
|
||||
Tensor bias = nullptr, ActType act = ActType::None);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
ActType getAct() const { return act; }
|
||||
int getNumGroups() const override { return group; }
|
||||
|
||||
private:
|
||||
void setAuxilaryAttributes(PaddingMode mode) override;
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -15,7 +15,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
|||
if (!tune && profiling)
|
||||
IT_TODO_HALT();
|
||||
const auto &kernelRegistry = KernelRegistry::getInstance();
|
||||
auto perfEngine = PerfEngine::getInstance();
|
||||
auto &perfEngine = PerfEngine::getInstance();
|
||||
// Statistics
|
||||
double totalTime = 0;
|
||||
std::map<OpType, double> opTime;
|
||||
|
@ -63,7 +63,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
|||
|
||||
double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
|
||||
const auto &kernelRegistry = KernelRegistry::getInstance();
|
||||
auto perfEngine = PerfEngine::getInstance();
|
||||
auto &perfEngine = PerfEngine::getInstance();
|
||||
// Statistics
|
||||
double totalTime = 0;
|
||||
std::map<OpType, double> opTime;
|
||||
|
|
|
@ -7,7 +7,7 @@ namespace infini {
|
|||
void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
||||
bool profiling = false) const {
|
||||
const auto &kernelRegistry = KernelRegistry::getInstance();
|
||||
auto perfEngine = PerfEngine::getInstance();
|
||||
auto &perfEngine = PerfEngine::getInstance();
|
||||
double totalTime = 0;
|
||||
std::map<OpType, double> opTime;
|
||||
std::map<OpType, int> opCnt;
|
||||
|
|
|
@ -0,0 +1,289 @@
|
|||
#include "core/kernel.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "operators/conv.h"
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <tuple>
|
||||
namespace infini {
|
||||
|
||||
struct ConvTransposedCuDnnPerfRecordObj : public PerfRecordObj {
|
||||
int algo = 0; // cudnnConvolutionBwdDataAlgo_t
|
||||
int mode = 1;
|
||||
size_t workspaceSize = 100000;
|
||||
bool fuseAct = false;
|
||||
};
|
||||
using ConvTransposedCuDnnPerfRecord = Ref<ConvTransposedCuDnnPerfRecordObj>;
|
||||
|
||||
static constexpr int N_ALGO = 6;
|
||||
static_assert(N_ALGO == int(CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT),
|
||||
"Unsupported cuDNN version");
|
||||
static const cudnnConvolutionBwdDataAlgo_t ALGOS[N_ALGO] = {
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, /* non-deterministic */
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED};
|
||||
static const char algo_name[N_ALGO][50] = {
|
||||
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_0", /* non-deterministic */
|
||||
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_1",
|
||||
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT",
|
||||
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING",
|
||||
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD",
|
||||
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED"};
|
||||
static const char math_types[3][50] = {"CUDNN_DEFAULT_MATH",
|
||||
"CUDNN_TENSOR_OP_MATH",
|
||||
"CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION"};
|
||||
static constexpr int N_MODE = 2;
|
||||
static constexpr cudnnConvolutionMode_t MODES[N_MODE] = {
|
||||
CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION};
|
||||
|
||||
class convBackwardDataCudnn : public Kernel {
|
||||
|
||||
std::tuple<void *, void *, void *, cudnnTensorDescriptor_t,
|
||||
cudnnFilterDescriptor_t, cudnnTensorDescriptor_t,
|
||||
cudnnConvolutionDescriptor_t, cudnnActivationDescriptor_t,
|
||||
cudnnTensorDescriptor_t>
|
||||
createCuDNNDescriptor(
|
||||
const Ref<ConvTransposed2dObj> &op,
|
||||
const ConvTransposedCuDnnPerfRecordObj &record) const {
|
||||
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
if (op->getInputs().size() > 2) // Bias is not supported yet
|
||||
IT_TODO_HALT();
|
||||
// void *const biasData = (op->getInputs(2)->getRawDataPtr<void
|
||||
// *>());
|
||||
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
const auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
|
||||
const int channelsPerGrp = op->getChannelPerGroup();
|
||||
const int g = op->getNumGroups();
|
||||
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
// IT_ASSERT(g == 1, "Group convolution is not supported yet");
|
||||
|
||||
// get inputs
|
||||
cudnnTensorDescriptor_t inDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, f, h, w));
|
||||
|
||||
// get kernels
|
||||
cudnnFilterDescriptor_t knDesc;
|
||||
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
|
||||
checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
|
||||
CUDNN_TENSOR_NCHW, f,
|
||||
channelsPerGrp, r, s));
|
||||
// get bias
|
||||
cudnnTensorDescriptor_t biasDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
biasDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, f, 1, 1));
|
||||
|
||||
// get convlution descriptor
|
||||
cudnnConvolutionDescriptor_t convDesc;
|
||||
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
|
||||
// TODO: CUDNN_CONVOLUTION is a tunable argument
|
||||
checkCudnnError(cudnnSetConvolution2dDescriptor(
|
||||
convDesc, ph, pw, sh, sw, dh, dw, MODES[record.mode],
|
||||
CUDNN_DATA_FLOAT));
|
||||
if (g > 1) {
|
||||
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
|
||||
}
|
||||
|
||||
// get activation descriptor
|
||||
cudnnActivationDescriptor_t actDesc;
|
||||
checkCudnnError(cudnnCreateActivationDescriptor(&actDesc));
|
||||
// NOT_PROPAGATE_NAN is requierd by
|
||||
// cudnnConvolotionBiasActivationForward
|
||||
switch (op->getAct()) {
|
||||
case ActType::Relu:
|
||||
checkCudnnError(cudnnSetActivationDescriptor(
|
||||
actDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0));
|
||||
break;
|
||||
case ActType::Sigmoid:
|
||||
checkCudnnError(cudnnSetActivationDescriptor(
|
||||
actDesc, CUDNN_ACTIVATION_SIGMOID, CUDNN_NOT_PROPAGATE_NAN, 0));
|
||||
break;
|
||||
case ActType::None:
|
||||
checkCudnnError(
|
||||
cudnnSetActivationDescriptor(actDesc, CUDNN_ACTIVATION_IDENTITY,
|
||||
CUDNN_NOT_PROPAGATE_NAN, 0));
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
|
||||
const auto &outputShape = op->getOutput()->getDims();
|
||||
cudnnTensorDescriptor_t outDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
outDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, outputShape[0],
|
||||
outputShape[1], outputShape[2], outputShape[3]));
|
||||
return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
|
||||
convDesc, actDesc, outDesc);
|
||||
}
|
||||
|
||||
bool cuDNNUnfused(const Ref<ConvTransposed2dObj> &op,
|
||||
const ConvTransposedCuDnnPerfRecordObj &record,
|
||||
const CudaRuntimeObj *context) const {
|
||||
cudnnStatus_t stat;
|
||||
|
||||
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
|
||||
convDesc, actDesc, outDesc] =
|
||||
createCuDNNDescriptor(op, record);
|
||||
size_t wsSize = record.workspaceSize;
|
||||
CudaPtr wsData = context->getWorkspace(wsSize);
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
|
||||
stat = cudnnConvolutionBackwardData(
|
||||
context->cudnnHandle(), &alpha, knDesc, knData, inDesc, inData,
|
||||
convDesc, ALGOS[record.algo], wsData, wsSize, &beta, outDesc,
|
||||
outData);
|
||||
if (stat != CUDNN_STATUS_SUCCESS)
|
||||
return false;
|
||||
// TODO:
|
||||
// // bias
|
||||
// if (bias != nullptr) {
|
||||
// auto sz = op.getOutputs()[0]->size();
|
||||
// // TODO: element wise
|
||||
// t += sz * 2 / 400;
|
||||
// }
|
||||
// // act
|
||||
// if (act != None) {
|
||||
// stat = cudnnActivationForward(cudnnHandle(), actDesc,
|
||||
// &alpha, inDesc, inData,
|
||||
// &beta, outDesc, outData);
|
||||
// checkCudaError(cudaDeviceSynchronize());
|
||||
// end = ch::high_resolution_clock::now();
|
||||
// if (stat != CUDNN_STATUS_SUCCESS) {
|
||||
// durtime = INFINITY;
|
||||
// break;
|
||||
// }
|
||||
// t +=
|
||||
// ch::duration_cast<ch::duration<double>>(end -
|
||||
// beg).count() * 1000; // ms
|
||||
// }
|
||||
|
||||
// best = ConvResult{durtime, ALGOS[i], wsSize, false};
|
||||
|
||||
// // w/ bias & act
|
||||
// for (int j = 0; j < rounds + warmupRounds; ++j) {
|
||||
// cudnnStatus_t stat;
|
||||
// if (j == warmupRounds) {
|
||||
// checkCudaError(cudaDeviceSynchronize());
|
||||
// beg = ch::high_resolution_clock::now();
|
||||
// }
|
||||
// stat = cudnnConvolutionBiasActivationForward(
|
||||
// cudnnHandle(), &alpha, inDesc, inData, knDesc,
|
||||
// knData, convDesc, ALGOS[i], wsData, wsSize, &beta,
|
||||
// outDesc, outData, biasDesc, biasData, actDesc,
|
||||
// outDesc, outData);
|
||||
// if (stat != CUDNN_STATUS_SUCCESS) {
|
||||
// // checkCudnnError(stat);
|
||||
// // Do not checkCudnnError since not all algorithms
|
||||
// are
|
||||
// // supported
|
||||
// durtime_fuse = INFINITY;
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
|
||||
// Destories in CUDA does not require sync. But cuDNN does not
|
||||
// state whether sync is required before destories.
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
||||
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
|
||||
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
|
||||
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
|
||||
return true;
|
||||
}
|
||||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
// with paramters in default ctor
|
||||
auto record = make_ref<ConvTransposedCuDnnPerfRecordObj>();
|
||||
compute(op, record, context);
|
||||
}
|
||||
|
||||
PerfRecord tune(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
ConvTransposedCuDnnPerfRecordObj ret;
|
||||
ret.time = std::numeric_limits<double>::max();
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
auto op = as<ConvTransposed2dObj>(_op);
|
||||
// Both modes have the same performance. Only run
|
||||
// cross-correlation.
|
||||
for (int mode = 1; mode < 2; mode++) {
|
||||
// Try every possible algorithm of convolution
|
||||
for (int algo = 0; algo < N_ALGO; algo++) {
|
||||
ConvTransposedCuDnnPerfRecordObj record;
|
||||
record.mode = mode;
|
||||
record.algo = algo;
|
||||
cudnnStatus_t stat;
|
||||
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
|
||||
convDesc, actDesc, outDesc] =
|
||||
createCuDNNDescriptor(op, record);
|
||||
|
||||
// get workspace
|
||||
stat = cudnnGetConvolutionBackwardDataWorkspaceSize(
|
||||
context->cudnnHandle(), knDesc, inDesc, convDesc, outDesc,
|
||||
ALGOS[record.algo], &record.workspaceSize);
|
||||
if (stat != CUDNN_STATUS_SUCCESS)
|
||||
continue;
|
||||
|
||||
CudaPtr wsData = context->getWorkspace(record.workspaceSize);
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
|
||||
stat = cudnnConvolutionBackwardData(
|
||||
context->cudnnHandle(), &alpha, knDesc, knData, inDesc,
|
||||
inData, convDesc, ALGOS[record.algo], wsData,
|
||||
record.workspaceSize, &beta, outDesc, outData);
|
||||
if (stat != CUDNN_STATUS_SUCCESS)
|
||||
continue;
|
||||
record.time = timeit(
|
||||
[&]() {
|
||||
cudnnConvolutionBackwardData(
|
||||
context->cudnnHandle(), &alpha, knDesc, knData,
|
||||
inDesc, inData, convDesc, ALGOS[record.algo],
|
||||
wsData, record.workspaceSize, &beta, outDesc,
|
||||
outData);
|
||||
},
|
||||
[&]() { context->sync(); });
|
||||
// printf("mode:%d algo:%d :%.8lf\n", mode, algo,
|
||||
// record.time);
|
||||
|
||||
// Update the tune result
|
||||
if (ret.time > record.time)
|
||||
ret = record;
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
||||
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
|
||||
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
|
||||
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
|
||||
}
|
||||
}
|
||||
// printf("the best algo is %d, the best conv mode is %d\n",
|
||||
// ret.algo,
|
||||
// ret.mode);
|
||||
IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
|
||||
"algorithm "
|
||||
"found");
|
||||
return make_ref<ConvTransposedCuDnnPerfRecordObj>(ret);
|
||||
}
|
||||
|
||||
void compute(const Operator &_op, const PerfRecord &_record,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConvTransposed2dObj>(_op);
|
||||
auto record = as<ConvTransposedCuDnnPerfRecordObj>(_record);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
bool success = cuDNNUnfused(op, *record, context);
|
||||
IT_ASSERT(success);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::ConvTrans, DataType::Float32,
|
||||
convBackwardDataCudnn, "ConvTranposed_cuDNN_CUDA_Float32");
|
||||
|
||||
} // namespace infini
|
|
@ -2,28 +2,24 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
|
||||
ConvBaseObj::ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
|
||||
int ph, int pw, int sh, int sw, int dh, int dw,
|
||||
[[maybe_unused]] Tensor bias, ActType act)
|
||||
: OperatorObj(OpType::Conv, {input, weight}, {output}), ph(ph), pw(pw),
|
||||
sh(sh), sw(sw), dh(dh), dw(dw), act(act), padding(PaddingMode::Other) {
|
||||
setAuxilaryAttributes(PaddingMode::Other);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
|
||||
const Tensor &inputInConvFWD,
|
||||
const Tensor &weightInConvFWD)
|
||||
: OperatorObj(opType, inputs, {output}), ph(ph), pw(pw), sh(sh), sw(sw),
|
||||
dh(dh), dw(dw), padding(PaddingMode::Other) {}
|
||||
ConvBaseObj::ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
|
||||
PaddingMode mode, int sh, int sw, int dh, int dw,
|
||||
[[maybe_unused]] Tensor bias, ActType act)
|
||||
: OperatorObj(OpType::Conv, {input, weight}, {output}), ph(-1), pw(-1),
|
||||
sh(sh), sw(sw), dh(dh), dw(dw), act(act), padding(mode) {
|
||||
const Tensor &inputInConvFWD,
|
||||
const Tensor &weightInConvFWD)
|
||||
: OperatorObj(opType, inputs, {output}), ph(-1), pw(-1), sh(sh), sw(sw),
|
||||
dh(dh), dw(dw), padding(mode) {
|
||||
IT_ASSERT(mode != PaddingMode::Other);
|
||||
setAuxilaryAttributes(mode);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
string ConvObj::toString() const {
|
||||
string ConvBaseObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << "Conv[" << getGuid() << "]";
|
||||
os << OpRegistry::getOpName(getOpType()) << "[" << getGuid() << "]";
|
||||
os << "(";
|
||||
if (inputs.size() == 2) {
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
|
@ -32,13 +28,63 @@ string ConvObj::toString() const {
|
|||
os << "p=[" << ph << "," << pw << "],";
|
||||
os << "s=[" << sh << "," << sw << "],";
|
||||
os << "d=[" << dh << "," << dw << "],";
|
||||
os << "act=" << enum_to_underlying(act) << ",";
|
||||
// os << "act=" << enum_to_underlying(act) << ",";
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
os << "weight=" << inputs[1]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> ConvBaseObj::getWorkloadVector() const {
|
||||
return {
|
||||
enum_to_underlying(type), n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw};
|
||||
}
|
||||
|
||||
vector<int> ConvBaseObj::getOpAttrVector() const {
|
||||
IT_TODO_HALT(); // should padding mode / ph+pw be in attrs?
|
||||
return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw};
|
||||
}
|
||||
|
||||
void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
|
||||
const Tensor &input = inputs[0];
|
||||
const Tensor &weight = inputs[1];
|
||||
n = input->getDims()[0], c = input->getDims()[1], h = input->getDims()[2],
|
||||
w = input->getDims()[3], f = weight->getDims()[0], r = weight->getDims()[2],
|
||||
s = weight->getDims()[3];
|
||||
if (mode == PaddingMode::Same) {
|
||||
int oh = h / sh;
|
||||
int ow = w / sw;
|
||||
ph = (h - oh * sh + (r - sh) * dh) / 2;
|
||||
pw = (w - ow * sw + (s - sw) * dw) / 2;
|
||||
} else if (mode == PaddingMode::Valid) {
|
||||
ph = pw = 0;
|
||||
}
|
||||
}
|
||||
|
||||
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
|
||||
int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias,
|
||||
ActType act)
|
||||
: ConvBaseObj(OpType::Conv, {input, weight}, output, ph, pw, sh, sw, dh, dw,
|
||||
input, weight),
|
||||
act(act) {
|
||||
if (bias)
|
||||
IT_TODO_HALT();
|
||||
setAuxilaryAttributes(PaddingMode::Other);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
|
||||
PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias,
|
||||
ActType act)
|
||||
: ConvBaseObj(OpType::Conv, {input, weight}, output, mode, sh, sw, dh, dw,
|
||||
input, weight),
|
||||
act(act) {
|
||||
if (bias)
|
||||
IT_TODO_HALT();
|
||||
setAuxilaryAttributes(mode);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
|
||||
const auto &input = inputs[0], &weight = inputs[1];
|
||||
auto n = input->getDims()[0];
|
||||
|
@ -70,23 +116,60 @@ optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
|
|||
return {{{on, oc, oh, ow}}};
|
||||
}
|
||||
|
||||
vector<int> ConvObj::getWorkloadVector() const {
|
||||
return {
|
||||
enum_to_underlying(type), n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw,
|
||||
enum_to_underlying(act)};
|
||||
ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
|
||||
Tensor weight, Tensor output, int ph,
|
||||
int pw, int sh, int sw, int dh, int dw,
|
||||
int oph, int opw, int group,
|
||||
Tensor bias, ActType act)
|
||||
: ConvBaseObj(OpType::ConvTrans, {input, weight}, output, ph, pw, sh, sw,
|
||||
dh, dw, output, weight),
|
||||
oph(oph), opw(opw), group(group), act(act) {
|
||||
if (bias)
|
||||
IT_TODO_HALT();
|
||||
setAuxilaryAttributes(PaddingMode::Other);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
vector<int> ConvObj::getOpAttrVector() const {
|
||||
IT_TODO_HALT(); // should padding mode / ph+pw be in attrs?
|
||||
return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw,
|
||||
enum_to_underlying(act)};
|
||||
ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
|
||||
Tensor weight, Tensor output,
|
||||
PaddingMode mode, int sh, int sw,
|
||||
int dh, int dw, int oph, int opw,
|
||||
int group, Tensor bias, ActType act)
|
||||
: ConvBaseObj(OpType::ConvTrans, {input, weight}, output, mode, sh, sw, dh,
|
||||
dw, output, weight),
|
||||
oph(oph), opw(opw), group(group), act(act) {
|
||||
if (bias)
|
||||
IT_TODO_HALT();
|
||||
setAuxilaryAttributes(mode);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
|
||||
n = inputs[0]->getDims()[0], c = inputs[0]->getDims()[1],
|
||||
h = inputs[0]->getDims()[2], w = inputs[0]->getDims()[3],
|
||||
f = inputs[1]->getDims()[0], r = inputs[1]->getDims()[2],
|
||||
s = inputs[1]->getDims()[3];
|
||||
optional<vector<Shape>>
|
||||
ConvTransposed2dObj::inferShape(const TensorVec &inputs) const {
|
||||
const Tensor &input = inputs[0], &weight = inputs[1];
|
||||
auto n = input->getDims()[0];
|
||||
auto f = input->getDims()[1];
|
||||
auto h = input->getDims()[2];
|
||||
auto w = input->getDims()[3];
|
||||
auto c = weight->getDims()[1];
|
||||
auto r = weight->getDims()[2];
|
||||
auto s = weight->getDims()[3];
|
||||
if (f != weight->getDims()[0])
|
||||
return {};
|
||||
|
||||
int on = n, oc = c * group;
|
||||
int oh = 0, ow = 0;
|
||||
oh = (h - 1) * sh - 2 * ph + dh * (r - 1) + oph + 1;
|
||||
ow = (w - 1) * sw - 2 * pw + dw * (s - 1) + opw + 1;
|
||||
return {{{on, oc, oh, ow}}};
|
||||
}
|
||||
|
||||
void ConvTransposed2dObj::setAuxilaryAttributes(PaddingMode mode) {
|
||||
const Tensor &input = inputs[0];
|
||||
const Tensor &weight = inputs[1];
|
||||
n = input->getDims()[0], f = input->getDims()[1], h = input->getDims()[2],
|
||||
w = input->getDims()[3], c = weight->getDims()[0], r = weight->getDims()[2],
|
||||
s = weight->getDims()[3];
|
||||
if (mode == PaddingMode::Same) {
|
||||
int oh = h / sh;
|
||||
int ow = w / sw;
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/perf_engine.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/conv.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(ConvTransposed, ShapeInference) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
{ // No pad: InfoGAN ConvTranspose_0
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i0 = g->addTensor({1, 228, 1, 1});
|
||||
Tensor w0 = g->addTensor({228, 448, 2, 2});
|
||||
auto conv = g->addOp<ConvTransposed2dObj>(i0, w0, nullptr, 0, 0);
|
||||
EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 448, 2, 2}));
|
||||
}
|
||||
{ // Padded, Strided: InfoGAN ConvTranspose_3
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i0 = g->addTensor({1, 448, 2, 2});
|
||||
Tensor w0 = g->addTensor({448, 256, 4, 4});
|
||||
auto conv = g->addOp<ConvTransposed2dObj>(i0, w0, nullptr, 1, 1, 2, 2);
|
||||
EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 256, 4, 4}));
|
||||
}
|
||||
{ // With output padding: GCN ConvTranspose_224
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i0 = g->addTensor({1, 21, 7, 7});
|
||||
Tensor w0 = g->addTensor({21, 21, 3, 3});
|
||||
auto conv = g->addOp<ConvTransposed2dObj>(i0, w0, nullptr, 1, 1, 2, 2,
|
||||
1, 1, 1, 1);
|
||||
EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 21, 14, 14}));
|
||||
}
|
||||
}
|
||||
|
||||
void testConvTransposedCudnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
vector<float> ansVec) {
|
||||
const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4};
|
||||
const int stride = 1, padding = 0, dilation = 1;
|
||||
// Construct Runtime and graph for CPU and CUDA
|
||||
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({N, F, H, H}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({F, C, R, S}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(generator);
|
||||
w0Cpu->setData(generator);
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv = gCuda->addOp<ConvTransposed2dObj>(i0Cuda, w0Cuda, nullptr,
|
||||
padding, padding, stride,
|
||||
stride, dilation, dilation);
|
||||
gCuda->dataMalloc();
|
||||
// Execute on CUDA
|
||||
cuda->run(gCuda);
|
||||
// copy output from CUDA to CPU
|
||||
auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(o0Cpu->equalData(ansVec));
|
||||
}
|
||||
|
||||
TEST(ConvTransposed, cuDNN) {
|
||||
testConvTransposedCudnn(IncrementalGenerator(),
|
||||
vector<float>{0., 0., 1., 2., 3., 0., 6.,
|
||||
12., 18., 16., 8., 30., 36., 42.,
|
||||
32., 16., 54., 60., 66., 48., 24.,
|
||||
62., 67., 72., 45.});
|
||||
}
|
||||
|
||||
TEST(ConvTransposed, tune) {
|
||||
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({1, 448, 2, 2}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({448, 256, 4, 4}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(IncrementalGenerator());
|
||||
w0Cpu->setData(IncrementalGenerator());
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv = gCuda->addOp<ConvTransposed2dObj>(i0Cuda, w0Cuda, nullptr);
|
||||
// allocate CUDA memory
|
||||
gCuda->dataMalloc();
|
||||
// Execute on CUDA
|
||||
bool tune = true;
|
||||
cuda->run(gCuda, tune);
|
||||
// print a tensor/operator/graph by print()
|
||||
gCuda->print();
|
||||
// check record
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{Device::CUDA, conv->getOpType(), DataType::Float32};
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
|
||||
std::optional<PerfRecord> perfData =
|
||||
PerfEngine::getInstance().getPerfData(perfKey);
|
||||
ASSERT_TRUE(perfData.has_value());
|
||||
}
|
||||
|
||||
} // namespace infini
|
Loading…
Reference in New Issue