Fix: PerfRecord in shared pointers (#31)

* Fix: PerfData in a shared pointer

* Add: abstraction for kernels without configuration

Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
This commit is contained in:
zhengly123 2022-09-18 20:27:18 +08:00 committed by GitHub
parent 6ac106cba4
commit d39328afce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 125 additions and 217 deletions

5
.gitignore vendored
View File

@ -34,4 +34,7 @@
build/
build_debug/
.vscode/
.vscode/
# python
*.pyc

View File

@ -7,13 +7,14 @@ namespace infini {
class RuntimeObj; // Forward declaration for Kernel::compute
struct PerfRecord {
PerfRecord(){};
PerfRecord(double time) : time(time){};
virtual ~PerfRecord() {}
struct PerfRecordObj {
PerfRecordObj(){};
PerfRecordObj(double time) : time(time){};
virtual ~PerfRecordObj() {}
double time = 0; // in milliseconds
};
using PerfRecord = Ref<PerfRecordObj>;
class Kernel {
public:
@ -73,6 +74,21 @@ class KernelRegistry {
}
};
class CpuKernelWithoutConfig : public Kernel {
public:
void compute(const Operator &op, const PerfRecord &record,
const RuntimeObj *context) const override {
compute(op, context);
}
virtual void compute(const Operator &op,
const RuntimeObj *context) const = 0;
// Premise: op is idempotent since it is called multiple times.
virtual PerfRecord tune(const Operator &op,
const RuntimeObj *context) const override {
return make_ref<PerfRecordObj>(timeit([&]() { compute(op, context); }));
}
};
} // namespace infini
#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, cnt) \

View File

@ -0,0 +1,24 @@
#pragma once
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class CudaKernelWithoutConfig : public Kernel {
public:
virtual void compute(const Operator &op, const PerfRecord &record,
const RuntimeObj *context) const {
compute(op, context);
}
virtual void compute(const Operator &op,
const RuntimeObj *context) const = 0;
// Premise: op is idempotent since it is called multiple times.
virtual PerfRecord tune(const Operator &op,
const RuntimeObj *_context) const {
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
return make_ref<PerfRecordObj>(timeit([&]() { compute(op, _context); },
[&]() { context->sync(); }));
}
};
} // namespace infini

View File

@ -83,7 +83,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
} else
record = *perfData;
double t = record.time;
double t = record->time;
totalTime += t;
if (profiling) {
op->print();

View File

@ -30,7 +30,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
} else
record = *perfData;
double t = record.time;
double t = record->time;
totalTime += t;
if (profiling) {

View File

@ -3,8 +3,8 @@
namespace infini {
template <typename T> class NaiveConv : public Kernel {
void compute(const Operator &_op, const PerfRecord &record,
template <typename T> class NaiveConv : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
auto op = as<ConvObj>(_op);
T *iptr = op->getInputs(0)->getRawDataPtr<T *>();
@ -45,15 +45,6 @@ template <typename T> class NaiveConv : public Kernel {
}
}
}
void compute(const Operator &op, const RuntimeObj *context) const override {
compute(op, {}, context);
}
PerfRecord tune(const Operator &op,
const RuntimeObj *context) const override {
return PerfRecord(timeit([&]() { compute(op, context); }));
}
};
REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::UInt32,

View File

@ -2,9 +2,9 @@
#include "core/kernel.h"
namespace infini {
template <typename T> class NativeElementWise : public Kernel {
template <typename T> class NativeElementWise : public CpuKernelWithoutConfig {
virtual T doCompute(T val0, T val1) const = 0;
void compute(const Operator &_op, const PerfRecord &record,
void compute(const Operator &_op,
const RuntimeObj *context) const override {
auto op = as<ElementWiseObj>(_op);
T *inptr0 = op->getInputs(0)->getRawDataPtr<T *>();
@ -24,16 +24,6 @@ template <typename T> class NativeElementWise : public Kernel {
outptr[offset] = doCompute(inptr0[offset], inptr1[offset]);
}
}
void compute(const Operator &op, const RuntimeObj *context) const override {
compute(op, {}, context);
}
PerfRecord tune(const Operator &op,
const RuntimeObj *context) const override {
PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
return perfrcd;
}
};
template <typename T> class NaiveAdd : public NativeElementWise<T> {

View File

@ -3,8 +3,8 @@
namespace infini {
template <typename T> class NaiveMatmul : public Kernel {
void compute(const Operator &_op, const PerfRecord &record,
template <typename T> class NaiveMatmul : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
auto op = as<MatmulObj>(_op);
IT_ASSERT(op->getInputs().size() == 2, "Bias is not supported yet.");
@ -24,17 +24,6 @@ template <typename T> class NaiveMatmul : public Kernel {
}
}
}
void compute(const Operator &op, const RuntimeObj *context) const override {
compute(op, {}, context);
}
PerfRecord tune(const Operator &op,
const RuntimeObj *context) const override {
PerfRecord ret;
ret.time = timeit([&]() { compute(op, context); });
return ret;
}
};
REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::UInt32,

View File

@ -73,7 +73,7 @@ class MemboundInterpreter : public Kernel {
PerfRecord tune(const Operator &op,
const RuntimeObj *context) const override {
return PerfRecord(
return make_ref<PerfRecordObj>(
timeit([&]() { compute(op, context); }, []() {}, 0, 1));
}
};

View File

@ -2,10 +2,10 @@
#include "core/kernel.h"
namespace infini {
template <typename T> class NativePooling : public Kernel {
template <typename T> class NativePooling : public CpuKernelWithoutConfig {
virtual T getPoolingValue(int kh, int kw, int posh, int posw, int ih,
int iw, T *inptr) const = 0;
void compute(const Operator &_op, const PerfRecord &record,
void compute(const Operator &_op,
const RuntimeObj *context) const override {
auto op = as<PoolingObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
@ -32,16 +32,6 @@ template <typename T> class NativePooling : public Kernel {
}
}
}
void compute(const Operator &op, const RuntimeObj *context) const override {
compute(op, {}, context);
}
PerfRecord tune(const Operator &op,
const RuntimeObj *context) const override {
PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
return perfrcd;
}
};
template <typename T> class NaiveMaxPool : public NativePooling<T> {

View File

@ -3,9 +3,9 @@
#include "core/kernel.h"
namespace infini {
template <typename T> class NativeUnary : public Kernel {
template <typename T> class NativeUnary : public CpuKernelWithoutConfig {
virtual T doCompute(T val) const = 0;
void compute(const Operator &_op, const PerfRecord &record,
void compute(const Operator &_op,
const RuntimeObj *context) const override {
auto op = as<UnaryObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
@ -17,20 +17,10 @@ template <typename T> class NativeUnary : public Kernel {
outptr[offset] = doCompute(inptr[offset]);
}
}
void compute(const Operator &op, const RuntimeObj *context) const override {
compute(op, {}, context);
}
PerfRecord tune(const Operator &op,
const RuntimeObj *context) const override {
PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
return perfrcd;
}
};
template <typename T> class NaiveSoftmax : public Kernel {
void compute(const Operator &_op, const PerfRecord &record,
template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
auto op = as<UnaryObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
@ -46,16 +36,6 @@ template <typename T> class NaiveSoftmax : public Kernel {
outptr[offset] = pow(E_CONSTANT, inptr[offset]) / sum;
}
}
void compute(const Operator &op, const RuntimeObj *context) const override {
compute(op, {}, context);
}
PerfRecord tune(const Operator &op,
const RuntimeObj *context) const override {
PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
return perfrcd;
}
};
template <typename T> class NaiveRelu : public NativeUnary<T> {

View File

@ -1,5 +1,5 @@
#include "operators/G2BMM.h"
#include "core/kernel.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "custom_ops.h"
#include <chrono>
@ -7,7 +7,7 @@
#include <tuple>
namespace infini {
class G2BMMCudnn : public Kernel {
class G2BMMCudnn : public CudaKernelWithoutConfig {
bool g2bmmKernel(const Ref<G2BMMObj> &op,
const CudaRuntimeObj *context) const {
@ -25,31 +25,27 @@ class G2BMMCudnn : public Kernel {
return true;
}
void compute(const Operator &op, const RuntimeObj *context) const override {
PerfRecord record;
compute(op, record, context);
}
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
PerfRecord record;
auto op = as<G2BMMObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
record.time = std::numeric_limits<double>::max();
auto record =
make_ref<PerfRecordObj>(std::numeric_limits<double>::max());
const auto [warmupRounds, timingRounds] =
op->getB() > 100 ? tuple{1, 3} : tuple{5, 15};
double tmp =
timeit([&]() { g2bmmKernel(op, context); },
[&]() { context->sync(); }, warmupRounds, timingRounds);
if (tmp < record.time)
record.time = tmp;
IT_ASSERT(record.time < std::numeric_limits<double>::max(),
if (tmp < record->time)
record->time = tmp;
IT_ASSERT(record->time < std::numeric_limits<double>::max(),
"Error occured "
"during runtime");
return record;
}
void compute(const Operator &_op, const PerfRecord &_record,
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<G2BMMObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);

View File

@ -1,5 +1,5 @@
#include "operators/GBMM.h"
#include "core/kernel.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "custom_ops.h"
#include <chrono>
@ -8,7 +8,7 @@
namespace infini {
class GBMMCudnn : public Kernel {
class GBMMCudnn : public CudaKernelWithoutConfig {
bool gbmmKernel(const Ref<GBMMObj> &op,
const CudaRuntimeObj *context) const {
@ -25,32 +25,28 @@ class GBMMCudnn : public Kernel {
// checkCudaError(cudaDeviceSynchronize());
return true;
}
void compute(const Operator &op, const RuntimeObj *context) const override {
PerfRecord record;
compute(op, record, context);
}
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
PerfRecord record;
auto op = as<GBMMObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
record.time = std::numeric_limits<double>::max();
auto record =
make_ref<PerfRecordObj>(std::numeric_limits<double>::max());
const auto [warmupRounds, timingRounds] =
op->getB() > 100 ? tuple{1, 3} : tuple{5, 15};
double tmp =
timeit([&]() { gbmmKernel(op, context); },
[&]() { context->sync(); }, warmupRounds, timingRounds);
if (tmp < record.time)
record.time = tmp;
IT_ASSERT(record.time < std::numeric_limits<double>::max(),
if (tmp < record->time)
record->time = tmp;
IT_ASSERT(record->time < std::numeric_limits<double>::max(),
"Error occured "
"during runtime");
return record;
}
void compute(const Operator &_op, const PerfRecord &_record,
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<GBMMObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);

View File

@ -21,12 +21,13 @@ static constexpr int N_MODE = 2;
static constexpr cudnnConvolutionMode_t MODES[N_MODE] = {
CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION};
struct ConvCuDnnPerfRecord : public PerfRecord {
struct ConvCuDnnPerfRecordObj : public PerfRecordObj {
int algo = 0; // cudnnConvolutionFwdAlgo_t
int mode = 1;
size_t workspaceSize = 100000;
bool fuseAct = false;
};
using ConvCuDnnPerfRecord = Ref<ConvCuDnnPerfRecordObj>;
class convCudnn : public Kernel {
@ -73,7 +74,7 @@ class convCudnn : public Kernel {
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
// TODO: CUDNN_CONVOLUTION is a tunable argument
checkCudnnError(cudnnSetConvolution2dDescriptor(
convDesc, ph, pw, sh, sw, dh, dw, MODES[record.mode],
convDesc, ph, pw, sh, sw, dh, dw, MODES[record->mode],
CUDNN_DATA_FLOAT));
if (g > 1) {
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
@ -125,13 +126,13 @@ class convCudnn : public Kernel {
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc] =
createCuDNNDescriptor(op, record);
size_t wsSize = record.workspaceSize;
size_t wsSize = record->workspaceSize;
CudaPtr wsData = context->getWorkspace(wsSize);
float alpha = 1.f, beta = 0.f;
stat = cudnnConvolutionForward(context->cudnnHandle(), &alpha, inDesc,
inData, knDesc, knData, convDesc,
ALGOS[record.algo], wsData, wsSize,
ALGOS[record->algo], wsData, wsSize,
&beta, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS)
return false;
@ -192,13 +193,14 @@ class convCudnn : public Kernel {
}
void compute(const Operator &op, const RuntimeObj *context) const override {
ConvCuDnnPerfRecord record; // with paramters in default ctor
auto record = make_ref<ConvCuDnnPerfRecordObj>(); // with paramters in
// default ctor
compute(op, record, context);
}
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
ConvCuDnnPerfRecord ret;
ConvCuDnnPerfRecordObj ret;
ret.time = std::numeric_limits<double>::max();
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
auto op = as<ConvObj>(_op);
@ -206,13 +208,14 @@ class convCudnn : public Kernel {
for (int mode = 1; mode < 2; mode++) {
// Try every possible algorithm of convolution
for (int algo = 0; algo < N_ALGO; algo++) {
ConvCuDnnPerfRecord record;
auto recordRef = make_ref<ConvCuDnnPerfRecordObj>();
auto &record = *recordRef;
record.mode = mode;
record.algo = algo;
cudnnStatus_t stat;
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc] =
createCuDNNDescriptor(op, record);
createCuDNNDescriptor(op, recordRef);
// get workspace
stat = cudnnGetConvolutionForwardWorkspaceSize(
@ -257,13 +260,13 @@ class convCudnn : public Kernel {
IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
"algorithm "
"found");
return ret;
return make_ref<ConvCuDnnPerfRecordObj>(ret);
}
void compute(const Operator &_op, const PerfRecord &_record,
const RuntimeObj *_context) const override {
auto op = as<ConvObj>(_op);
auto &record = dynamic_cast<const ConvCuDnnPerfRecord &>(_record);
auto record = as<ConvCuDnnPerfRecordObj>(_record);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
bool success = cuDNNUnfused(op, record, context);
IT_ASSERT(success);

View File

@ -1,6 +1,6 @@
#include "operators/element_wise.h"
#include "core/kernel.h"
#include "cuda/cuda_element_wise.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
namespace infini {
@ -66,11 +66,9 @@ class ElementWiseCudnn : public Kernel {
// Premise: op is idempotent since it is called multiple times.
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
PerfRecord ret;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
ret.time = timeit([&]() { compute(_op, _context); },
[&]() { context->sync(); });
return ret;
return make_ref<PerfRecordObj>(timeit([&]() { compute(_op, _context); },
[&]() { context->sync(); }));
}
};
@ -89,24 +87,10 @@ class MulCudnn : public ElementWiseCudnn {
cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_MUL; }
};
class ElementWiseCuda : public Kernel {
void compute(const Operator &_op, const PerfRecord &record,
const RuntimeObj *_context) const override {
element_wise_kernel(_op);
}
class ElementWiseCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
compute(_op, {}, _context);
}
// Premise: op is idempotent since it is called multiple times.
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
PerfRecord ret;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
ret.time = timeit([&]() { compute(_op, _context); },
[&]() { context->sync(); });
return ret;
element_wise_kernel(_op);
}
};

View File

@ -5,9 +5,10 @@
#include <functional>
namespace infini {
struct MatmulCudnnPerfRecord : public PerfRecord {
struct MatmulCudnnPerfRecordObj : public PerfRecordObj {
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
};
using MatmulCudnnPerfRecord = Ref<MatmulCudnnPerfRecordObj>;
constexpr int N_ALGO = 24;
constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, CUBLAS_GEMM_ALGO2,
@ -28,7 +29,7 @@ class matmulCublas : public Kernel {
void *const inAData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const inBData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
auto record = dynamic_cast<const MatmulCudnnPerfRecord &>(_record);
auto record = as<MatmulCudnnPerfRecordObj>(_record);
const auto [b, m, n, k] = op->getBMNK();
auto opA =
@ -43,12 +44,12 @@ class matmulCublas : public Kernel {
stat = cublasGemmStridedBatchedEx(
context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
CUDA_R_32F, ldb, k * n, inAData, CUDA_R_32F, lda, m * k, &beta,
outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F, record.algo);
outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F, record->algo);
} else {
stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
&alpha, inBData, CUDA_R_32F, ldb, inAData,
CUDA_R_32F, lda, &beta, outData, CUDA_R_32F,
ldc, CUDA_R_32F, record.algo);
ldc, CUDA_R_32F, record->algo);
}
return (stat == CUBLAS_STATUS_SUCCESS);
}
@ -59,7 +60,8 @@ class matmulCublas : public Kernel {
}
void compute(const Operator &op, const RuntimeObj *context) const override {
MatmulCudnnPerfRecord record; // use default record;
auto record =
make_ref<MatmulCudnnPerfRecordObj>(); // use default record;
compute(op, record, context);
}
@ -67,21 +69,21 @@ class matmulCublas : public Kernel {
const RuntimeObj *_context) const override {
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
auto op = as<MatmulObj>(_op);
MatmulCudnnPerfRecord ret;
ret.time = std::numeric_limits<double>::max();
auto ret = make_ref<MatmulCudnnPerfRecordObj>();
ret->time = std::numeric_limits<double>::max();
for (int i = 0; i < N_ALGO; i++) {
MatmulCudnnPerfRecord rcd;
rcd.algo = ALGOS[i];
auto rcd = make_ref<MatmulCudnnPerfRecordObj>();
rcd->algo = ALGOS[i];
if (!do_compute(_op, rcd, _context))
continue;
rcd.time = timeit([&]() { do_compute(_op, rcd, _context); },
[&]() { context->sync(); });
if (rcd.time < ret.time)
rcd->time = timeit([&]() { do_compute(_op, rcd, _context); },
[&]() { context->sync(); });
if (rcd->time < ret->time)
ret = rcd;
}
IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
"algorithm "
"found");
IT_ASSERT(ret->time < std::numeric_limits<double>::max(), "No valid "
"algorithm "
"found");
return ret;
}
};

View File

@ -1,11 +1,11 @@
#include "operators/pooling.h"
#include "core/kernel.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class poolingCudnn : public Kernel {
class poolingCudnn : public CudaKernelWithoutConfig {
virtual cudnnPoolingMode_t getPoolingMode() const = 0;
void compute(const Operator &_op, const PerfRecord &record,
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<PoolingObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
@ -54,20 +54,6 @@ class poolingCudnn : public Kernel {
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyPoolingDescriptor(poolingDesc));
}
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
compute(_op, {}, _context);
}
// Premise: op is idempotent since it is called multiple times.
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
PerfRecord ret;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
ret.time = timeit([&]() { compute(_op, _context); },
[&]() { context->sync(); });
return ret;
}
};
class maxPoolCudnn : public poolingCudnn {

View File

@ -1,35 +1,21 @@
#include "operators/unary.h"
#include "core/kernel.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_unary.h"
namespace infini {
class UnaryCuda : public Kernel {
void compute(const Operator &_op, const PerfRecord &record,
class UnaryCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
unary_kernel(_op);
}
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
compute(_op, {}, _context);
}
// Premise: op is idempotent since it is called multiple times.
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
PerfRecord ret;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
ret.time = timeit([&]() { compute(_op, _context); },
[&]() { context->sync(); });
return ret;
}
};
class ActivationCudnn : public Kernel {
class ActivationCudnn : public CudaKernelWithoutConfig {
virtual cudnnActivationMode_t getOpType() const = 0;
virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
void compute(const Operator &_op, const PerfRecord &record,
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
@ -72,27 +58,13 @@ class ActivationCudnn : public Kernel {
checkCudnnError(cudnnDestroyTensorDescriptor(outputDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(inputDesc));
}
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
compute(_op, {}, _context);
}
// Premise: op is idempotent since it is called multiple times.
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
PerfRecord ret;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
ret.time = timeit([&]() { compute(_op, _context); },
[&]() { context->sync(); });
return ret;
}
};
class SoftmaxCudnn : public Kernel {
class SoftmaxCudnn : public CudaKernelWithoutConfig {
virtual cudnnSoftmaxAlgorithm_t getAlgorithmType() const = 0;
virtual cudnnSoftmaxMode_t getModeType() const = 0;
virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
void compute(const Operator &_op, const PerfRecord &record,
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
@ -128,20 +100,6 @@ class SoftmaxCudnn : public Kernel {
checkCudnnError(cudnnDestroyTensorDescriptor(inputDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(outputDesc));
}
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
compute(_op, {}, _context);
}
// Premise: op is idempotent since it is called multiple times.
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
PerfRecord ret;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
ret.time = timeit([&]() { compute(_op, _context); },
[&]() { context->sync(); });
return ret;
}
};
class ReluCudnn : public ActivationCudnn {