forked from jiuyuan/InfiniTensor
Fix: PerfRecord in shared pointers (#31)
* Fix: PerfData in a shared pointer * Add: abstraction for kernels without configuration Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
This commit is contained in:
parent
6ac106cba4
commit
d39328afce
|
@ -34,4 +34,7 @@
|
|||
build/
|
||||
build_debug/
|
||||
|
||||
.vscode/
|
||||
.vscode/
|
||||
|
||||
# python
|
||||
*.pyc
|
|
@ -7,13 +7,14 @@ namespace infini {
|
|||
|
||||
class RuntimeObj; // Forward declaration for Kernel::compute
|
||||
|
||||
struct PerfRecord {
|
||||
PerfRecord(){};
|
||||
PerfRecord(double time) : time(time){};
|
||||
virtual ~PerfRecord() {}
|
||||
struct PerfRecordObj {
|
||||
PerfRecordObj(){};
|
||||
PerfRecordObj(double time) : time(time){};
|
||||
virtual ~PerfRecordObj() {}
|
||||
|
||||
double time = 0; // in milliseconds
|
||||
};
|
||||
using PerfRecord = Ref<PerfRecordObj>;
|
||||
|
||||
class Kernel {
|
||||
public:
|
||||
|
@ -73,6 +74,21 @@ class KernelRegistry {
|
|||
}
|
||||
};
|
||||
|
||||
class CpuKernelWithoutConfig : public Kernel {
|
||||
public:
|
||||
void compute(const Operator &op, const PerfRecord &record,
|
||||
const RuntimeObj *context) const override {
|
||||
compute(op, context);
|
||||
}
|
||||
virtual void compute(const Operator &op,
|
||||
const RuntimeObj *context) const = 0;
|
||||
// Premise: op is idempotent since it is called multiple times.
|
||||
virtual PerfRecord tune(const Operator &op,
|
||||
const RuntimeObj *context) const override {
|
||||
return make_ref<PerfRecordObj>(timeit([&]() { compute(op, context); }));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
|
||||
#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, cnt) \
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
#pragma once
|
||||
#include "core/kernel.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class CudaKernelWithoutConfig : public Kernel {
|
||||
public:
|
||||
virtual void compute(const Operator &op, const PerfRecord &record,
|
||||
const RuntimeObj *context) const {
|
||||
compute(op, context);
|
||||
}
|
||||
virtual void compute(const Operator &op,
|
||||
const RuntimeObj *context) const = 0;
|
||||
// Premise: op is idempotent since it is called multiple times.
|
||||
virtual PerfRecord tune(const Operator &op,
|
||||
const RuntimeObj *_context) const {
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
return make_ref<PerfRecordObj>(timeit([&]() { compute(op, _context); },
|
||||
[&]() { context->sync(); }));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -83,7 +83,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
|
|||
} else
|
||||
record = *perfData;
|
||||
|
||||
double t = record.time;
|
||||
double t = record->time;
|
||||
totalTime += t;
|
||||
if (profiling) {
|
||||
op->print();
|
||||
|
|
|
@ -30,7 +30,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
|||
} else
|
||||
record = *perfData;
|
||||
|
||||
double t = record.time;
|
||||
double t = record->time;
|
||||
totalTime += t;
|
||||
|
||||
if (profiling) {
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
template <typename T> class NaiveConv : public Kernel {
|
||||
void compute(const Operator &_op, const PerfRecord &record,
|
||||
template <typename T> class NaiveConv : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
auto op = as<ConvObj>(_op);
|
||||
T *iptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
|
@ -45,15 +45,6 @@ template <typename T> class NaiveConv : public Kernel {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
compute(op, {}, context);
|
||||
}
|
||||
|
||||
PerfRecord tune(const Operator &op,
|
||||
const RuntimeObj *context) const override {
|
||||
return PerfRecord(timeit([&]() { compute(op, context); }));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::UInt32,
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
#include "core/kernel.h"
|
||||
|
||||
namespace infini {
|
||||
template <typename T> class NativeElementWise : public Kernel {
|
||||
template <typename T> class NativeElementWise : public CpuKernelWithoutConfig {
|
||||
virtual T doCompute(T val0, T val1) const = 0;
|
||||
void compute(const Operator &_op, const PerfRecord &record,
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
T *inptr0 = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
|
@ -24,16 +24,6 @@ template <typename T> class NativeElementWise : public Kernel {
|
|||
outptr[offset] = doCompute(inptr0[offset], inptr1[offset]);
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
compute(op, {}, context);
|
||||
}
|
||||
|
||||
PerfRecord tune(const Operator &op,
|
||||
const RuntimeObj *context) const override {
|
||||
PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
|
||||
return perfrcd;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class NaiveAdd : public NativeElementWise<T> {
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
template <typename T> class NaiveMatmul : public Kernel {
|
||||
void compute(const Operator &_op, const PerfRecord &record,
|
||||
template <typename T> class NaiveMatmul : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
auto op = as<MatmulObj>(_op);
|
||||
IT_ASSERT(op->getInputs().size() == 2, "Bias is not supported yet.");
|
||||
|
@ -24,17 +24,6 @@ template <typename T> class NaiveMatmul : public Kernel {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
compute(op, {}, context);
|
||||
}
|
||||
|
||||
PerfRecord tune(const Operator &op,
|
||||
const RuntimeObj *context) const override {
|
||||
PerfRecord ret;
|
||||
ret.time = timeit([&]() { compute(op, context); });
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::UInt32,
|
||||
|
|
|
@ -73,7 +73,7 @@ class MemboundInterpreter : public Kernel {
|
|||
|
||||
PerfRecord tune(const Operator &op,
|
||||
const RuntimeObj *context) const override {
|
||||
return PerfRecord(
|
||||
return make_ref<PerfRecordObj>(
|
||||
timeit([&]() { compute(op, context); }, []() {}, 0, 1));
|
||||
}
|
||||
};
|
||||
|
|
|
@ -2,10 +2,10 @@
|
|||
#include "core/kernel.h"
|
||||
|
||||
namespace infini {
|
||||
template <typename T> class NativePooling : public Kernel {
|
||||
template <typename T> class NativePooling : public CpuKernelWithoutConfig {
|
||||
virtual T getPoolingValue(int kh, int kw, int posh, int posw, int ih,
|
||||
int iw, T *inptr) const = 0;
|
||||
void compute(const Operator &_op, const PerfRecord &record,
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
auto op = as<PoolingObj>(_op);
|
||||
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
|
@ -32,16 +32,6 @@ template <typename T> class NativePooling : public Kernel {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
compute(op, {}, context);
|
||||
}
|
||||
|
||||
PerfRecord tune(const Operator &op,
|
||||
const RuntimeObj *context) const override {
|
||||
PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
|
||||
return perfrcd;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class NaiveMaxPool : public NativePooling<T> {
|
||||
|
|
|
@ -3,9 +3,9 @@
|
|||
#include "core/kernel.h"
|
||||
|
||||
namespace infini {
|
||||
template <typename T> class NativeUnary : public Kernel {
|
||||
template <typename T> class NativeUnary : public CpuKernelWithoutConfig {
|
||||
virtual T doCompute(T val) const = 0;
|
||||
void compute(const Operator &_op, const PerfRecord &record,
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
|
@ -17,20 +17,10 @@ template <typename T> class NativeUnary : public Kernel {
|
|||
outptr[offset] = doCompute(inptr[offset]);
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
compute(op, {}, context);
|
||||
}
|
||||
|
||||
PerfRecord tune(const Operator &op,
|
||||
const RuntimeObj *context) const override {
|
||||
PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
|
||||
return perfrcd;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class NaiveSoftmax : public Kernel {
|
||||
void compute(const Operator &_op, const PerfRecord &record,
|
||||
template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
|
@ -46,16 +36,6 @@ template <typename T> class NaiveSoftmax : public Kernel {
|
|||
outptr[offset] = pow(E_CONSTANT, inptr[offset]) / sum;
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
compute(op, {}, context);
|
||||
}
|
||||
|
||||
PerfRecord tune(const Operator &op,
|
||||
const RuntimeObj *context) const override {
|
||||
PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
|
||||
return perfrcd;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class NaiveRelu : public NativeUnary<T> {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#include "operators/G2BMM.h"
|
||||
#include "core/kernel.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "custom_ops.h"
|
||||
#include <chrono>
|
||||
|
@ -7,7 +7,7 @@
|
|||
#include <tuple>
|
||||
namespace infini {
|
||||
|
||||
class G2BMMCudnn : public Kernel {
|
||||
class G2BMMCudnn : public CudaKernelWithoutConfig {
|
||||
|
||||
bool g2bmmKernel(const Ref<G2BMMObj> &op,
|
||||
const CudaRuntimeObj *context) const {
|
||||
|
@ -25,31 +25,27 @@ class G2BMMCudnn : public Kernel {
|
|||
return true;
|
||||
}
|
||||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
PerfRecord record;
|
||||
compute(op, record, context);
|
||||
}
|
||||
|
||||
PerfRecord tune(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
PerfRecord record;
|
||||
auto op = as<G2BMMObj>(_op);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
|
||||
record.time = std::numeric_limits<double>::max();
|
||||
auto record =
|
||||
make_ref<PerfRecordObj>(std::numeric_limits<double>::max());
|
||||
const auto [warmupRounds, timingRounds] =
|
||||
op->getB() > 100 ? tuple{1, 3} : tuple{5, 15};
|
||||
double tmp =
|
||||
timeit([&]() { g2bmmKernel(op, context); },
|
||||
[&]() { context->sync(); }, warmupRounds, timingRounds);
|
||||
if (tmp < record.time)
|
||||
record.time = tmp;
|
||||
IT_ASSERT(record.time < std::numeric_limits<double>::max(),
|
||||
if (tmp < record->time)
|
||||
record->time = tmp;
|
||||
IT_ASSERT(record->time < std::numeric_limits<double>::max(),
|
||||
"Error occured "
|
||||
"during runtime");
|
||||
return record;
|
||||
}
|
||||
void compute(const Operator &_op, const PerfRecord &_record,
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<G2BMMObj>(_op);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#include "operators/GBMM.h"
|
||||
#include "core/kernel.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "custom_ops.h"
|
||||
#include <chrono>
|
||||
|
@ -8,7 +8,7 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
class GBMMCudnn : public Kernel {
|
||||
class GBMMCudnn : public CudaKernelWithoutConfig {
|
||||
|
||||
bool gbmmKernel(const Ref<GBMMObj> &op,
|
||||
const CudaRuntimeObj *context) const {
|
||||
|
@ -25,32 +25,28 @@ class GBMMCudnn : public Kernel {
|
|||
// checkCudaError(cudaDeviceSynchronize());
|
||||
return true;
|
||||
}
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
PerfRecord record;
|
||||
compute(op, record, context);
|
||||
}
|
||||
|
||||
PerfRecord tune(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
PerfRecord record;
|
||||
auto op = as<GBMMObj>(_op);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
|
||||
record.time = std::numeric_limits<double>::max();
|
||||
auto record =
|
||||
make_ref<PerfRecordObj>(std::numeric_limits<double>::max());
|
||||
const auto [warmupRounds, timingRounds] =
|
||||
op->getB() > 100 ? tuple{1, 3} : tuple{5, 15};
|
||||
double tmp =
|
||||
timeit([&]() { gbmmKernel(op, context); },
|
||||
[&]() { context->sync(); }, warmupRounds, timingRounds);
|
||||
if (tmp < record.time)
|
||||
record.time = tmp;
|
||||
IT_ASSERT(record.time < std::numeric_limits<double>::max(),
|
||||
if (tmp < record->time)
|
||||
record->time = tmp;
|
||||
IT_ASSERT(record->time < std::numeric_limits<double>::max(),
|
||||
"Error occured "
|
||||
"during runtime");
|
||||
return record;
|
||||
}
|
||||
|
||||
void compute(const Operator &_op, const PerfRecord &_record,
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<GBMMObj>(_op);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
|
|
|
@ -21,12 +21,13 @@ static constexpr int N_MODE = 2;
|
|||
static constexpr cudnnConvolutionMode_t MODES[N_MODE] = {
|
||||
CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION};
|
||||
|
||||
struct ConvCuDnnPerfRecord : public PerfRecord {
|
||||
struct ConvCuDnnPerfRecordObj : public PerfRecordObj {
|
||||
int algo = 0; // cudnnConvolutionFwdAlgo_t
|
||||
int mode = 1;
|
||||
size_t workspaceSize = 100000;
|
||||
bool fuseAct = false;
|
||||
};
|
||||
using ConvCuDnnPerfRecord = Ref<ConvCuDnnPerfRecordObj>;
|
||||
|
||||
class convCudnn : public Kernel {
|
||||
|
||||
|
@ -73,7 +74,7 @@ class convCudnn : public Kernel {
|
|||
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
|
||||
// TODO: CUDNN_CONVOLUTION is a tunable argument
|
||||
checkCudnnError(cudnnSetConvolution2dDescriptor(
|
||||
convDesc, ph, pw, sh, sw, dh, dw, MODES[record.mode],
|
||||
convDesc, ph, pw, sh, sw, dh, dw, MODES[record->mode],
|
||||
CUDNN_DATA_FLOAT));
|
||||
if (g > 1) {
|
||||
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
|
||||
|
@ -125,13 +126,13 @@ class convCudnn : public Kernel {
|
|||
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
|
||||
convDesc, actDesc, outDesc] =
|
||||
createCuDNNDescriptor(op, record);
|
||||
size_t wsSize = record.workspaceSize;
|
||||
size_t wsSize = record->workspaceSize;
|
||||
CudaPtr wsData = context->getWorkspace(wsSize);
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
|
||||
stat = cudnnConvolutionForward(context->cudnnHandle(), &alpha, inDesc,
|
||||
inData, knDesc, knData, convDesc,
|
||||
ALGOS[record.algo], wsData, wsSize,
|
||||
ALGOS[record->algo], wsData, wsSize,
|
||||
&beta, outDesc, outData);
|
||||
if (stat != CUDNN_STATUS_SUCCESS)
|
||||
return false;
|
||||
|
@ -192,13 +193,14 @@ class convCudnn : public Kernel {
|
|||
}
|
||||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
ConvCuDnnPerfRecord record; // with paramters in default ctor
|
||||
auto record = make_ref<ConvCuDnnPerfRecordObj>(); // with paramters in
|
||||
// default ctor
|
||||
compute(op, record, context);
|
||||
}
|
||||
|
||||
PerfRecord tune(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
ConvCuDnnPerfRecord ret;
|
||||
ConvCuDnnPerfRecordObj ret;
|
||||
ret.time = std::numeric_limits<double>::max();
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
auto op = as<ConvObj>(_op);
|
||||
|
@ -206,13 +208,14 @@ class convCudnn : public Kernel {
|
|||
for (int mode = 1; mode < 2; mode++) {
|
||||
// Try every possible algorithm of convolution
|
||||
for (int algo = 0; algo < N_ALGO; algo++) {
|
||||
ConvCuDnnPerfRecord record;
|
||||
auto recordRef = make_ref<ConvCuDnnPerfRecordObj>();
|
||||
auto &record = *recordRef;
|
||||
record.mode = mode;
|
||||
record.algo = algo;
|
||||
cudnnStatus_t stat;
|
||||
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
|
||||
convDesc, actDesc, outDesc] =
|
||||
createCuDNNDescriptor(op, record);
|
||||
createCuDNNDescriptor(op, recordRef);
|
||||
|
||||
// get workspace
|
||||
stat = cudnnGetConvolutionForwardWorkspaceSize(
|
||||
|
@ -257,13 +260,13 @@ class convCudnn : public Kernel {
|
|||
IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
|
||||
"algorithm "
|
||||
"found");
|
||||
return ret;
|
||||
return make_ref<ConvCuDnnPerfRecordObj>(ret);
|
||||
}
|
||||
|
||||
void compute(const Operator &_op, const PerfRecord &_record,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConvObj>(_op);
|
||||
auto &record = dynamic_cast<const ConvCuDnnPerfRecord &>(_record);
|
||||
auto record = as<ConvCuDnnPerfRecordObj>(_record);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
bool success = cuDNNUnfused(op, record, context);
|
||||
IT_ASSERT(success);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#include "operators/element_wise.h"
|
||||
#include "core/kernel.h"
|
||||
#include "cuda/cuda_element_wise.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
@ -66,11 +66,9 @@ class ElementWiseCudnn : public Kernel {
|
|||
// Premise: op is idempotent since it is called multiple times.
|
||||
PerfRecord tune(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
PerfRecord ret;
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
ret.time = timeit([&]() { compute(_op, _context); },
|
||||
[&]() { context->sync(); });
|
||||
return ret;
|
||||
return make_ref<PerfRecordObj>(timeit([&]() { compute(_op, _context); },
|
||||
[&]() { context->sync(); }));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -89,24 +87,10 @@ class MulCudnn : public ElementWiseCudnn {
|
|||
cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_MUL; }
|
||||
};
|
||||
|
||||
class ElementWiseCuda : public Kernel {
|
||||
void compute(const Operator &_op, const PerfRecord &record,
|
||||
const RuntimeObj *_context) const override {
|
||||
element_wise_kernel(_op);
|
||||
}
|
||||
|
||||
class ElementWiseCuda : public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
compute(_op, {}, _context);
|
||||
}
|
||||
// Premise: op is idempotent since it is called multiple times.
|
||||
PerfRecord tune(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
PerfRecord ret;
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
ret.time = timeit([&]() { compute(_op, _context); },
|
||||
[&]() { context->sync(); });
|
||||
return ret;
|
||||
element_wise_kernel(_op);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -5,9 +5,10 @@
|
|||
#include <functional>
|
||||
|
||||
namespace infini {
|
||||
struct MatmulCudnnPerfRecord : public PerfRecord {
|
||||
struct MatmulCudnnPerfRecordObj : public PerfRecordObj {
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||
};
|
||||
using MatmulCudnnPerfRecord = Ref<MatmulCudnnPerfRecordObj>;
|
||||
constexpr int N_ALGO = 24;
|
||||
constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
|
||||
CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, CUBLAS_GEMM_ALGO2,
|
||||
|
@ -28,7 +29,7 @@ class matmulCublas : public Kernel {
|
|||
void *const inAData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const inBData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto record = dynamic_cast<const MatmulCudnnPerfRecord &>(_record);
|
||||
auto record = as<MatmulCudnnPerfRecordObj>(_record);
|
||||
|
||||
const auto [b, m, n, k] = op->getBMNK();
|
||||
auto opA =
|
||||
|
@ -43,12 +44,12 @@ class matmulCublas : public Kernel {
|
|||
stat = cublasGemmStridedBatchedEx(
|
||||
context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
|
||||
CUDA_R_32F, ldb, k * n, inAData, CUDA_R_32F, lda, m * k, &beta,
|
||||
outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F, record.algo);
|
||||
outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F, record->algo);
|
||||
} else {
|
||||
stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
|
||||
&alpha, inBData, CUDA_R_32F, ldb, inAData,
|
||||
CUDA_R_32F, lda, &beta, outData, CUDA_R_32F,
|
||||
ldc, CUDA_R_32F, record.algo);
|
||||
ldc, CUDA_R_32F, record->algo);
|
||||
}
|
||||
return (stat == CUBLAS_STATUS_SUCCESS);
|
||||
}
|
||||
|
@ -59,7 +60,8 @@ class matmulCublas : public Kernel {
|
|||
}
|
||||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
MatmulCudnnPerfRecord record; // use default record;
|
||||
auto record =
|
||||
make_ref<MatmulCudnnPerfRecordObj>(); // use default record;
|
||||
compute(op, record, context);
|
||||
}
|
||||
|
||||
|
@ -67,21 +69,21 @@ class matmulCublas : public Kernel {
|
|||
const RuntimeObj *_context) const override {
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
auto op = as<MatmulObj>(_op);
|
||||
MatmulCudnnPerfRecord ret;
|
||||
ret.time = std::numeric_limits<double>::max();
|
||||
auto ret = make_ref<MatmulCudnnPerfRecordObj>();
|
||||
ret->time = std::numeric_limits<double>::max();
|
||||
for (int i = 0; i < N_ALGO; i++) {
|
||||
MatmulCudnnPerfRecord rcd;
|
||||
rcd.algo = ALGOS[i];
|
||||
auto rcd = make_ref<MatmulCudnnPerfRecordObj>();
|
||||
rcd->algo = ALGOS[i];
|
||||
if (!do_compute(_op, rcd, _context))
|
||||
continue;
|
||||
rcd.time = timeit([&]() { do_compute(_op, rcd, _context); },
|
||||
[&]() { context->sync(); });
|
||||
if (rcd.time < ret.time)
|
||||
rcd->time = timeit([&]() { do_compute(_op, rcd, _context); },
|
||||
[&]() { context->sync(); });
|
||||
if (rcd->time < ret->time)
|
||||
ret = rcd;
|
||||
}
|
||||
IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
|
||||
"algorithm "
|
||||
"found");
|
||||
IT_ASSERT(ret->time < std::numeric_limits<double>::max(), "No valid "
|
||||
"algorithm "
|
||||
"found");
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
#include "operators/pooling.h"
|
||||
#include "core/kernel.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class poolingCudnn : public Kernel {
|
||||
class poolingCudnn : public CudaKernelWithoutConfig {
|
||||
virtual cudnnPoolingMode_t getPoolingMode() const = 0;
|
||||
void compute(const Operator &_op, const PerfRecord &record,
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PoolingObj>(_op);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
|
@ -54,20 +54,6 @@ class poolingCudnn : public Kernel {
|
|||
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
||||
checkCudnnError(cudnnDestroyPoolingDescriptor(poolingDesc));
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
compute(_op, {}, _context);
|
||||
}
|
||||
// Premise: op is idempotent since it is called multiple times.
|
||||
PerfRecord tune(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
PerfRecord ret;
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
ret.time = timeit([&]() { compute(_op, _context); },
|
||||
[&]() { context->sync(); });
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
class maxPoolCudnn : public poolingCudnn {
|
||||
|
|
|
@ -1,35 +1,21 @@
|
|||
#include "operators/unary.h"
|
||||
#include "core/kernel.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_unary.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class UnaryCuda : public Kernel {
|
||||
void compute(const Operator &_op, const PerfRecord &record,
|
||||
class UnaryCuda : public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
unary_kernel(_op);
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
compute(_op, {}, _context);
|
||||
}
|
||||
// Premise: op is idempotent since it is called multiple times.
|
||||
PerfRecord tune(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
PerfRecord ret;
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
ret.time = timeit([&]() { compute(_op, _context); },
|
||||
[&]() { context->sync(); });
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
class ActivationCudnn : public Kernel {
|
||||
class ActivationCudnn : public CudaKernelWithoutConfig {
|
||||
virtual cudnnActivationMode_t getOpType() const = 0;
|
||||
virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
|
||||
void compute(const Operator &_op, const PerfRecord &record,
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
|
@ -72,27 +58,13 @@ class ActivationCudnn : public Kernel {
|
|||
checkCudnnError(cudnnDestroyTensorDescriptor(outputDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(inputDesc));
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
compute(_op, {}, _context);
|
||||
}
|
||||
// Premise: op is idempotent since it is called multiple times.
|
||||
PerfRecord tune(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
PerfRecord ret;
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
ret.time = timeit([&]() { compute(_op, _context); },
|
||||
[&]() { context->sync(); });
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
class SoftmaxCudnn : public Kernel {
|
||||
class SoftmaxCudnn : public CudaKernelWithoutConfig {
|
||||
virtual cudnnSoftmaxAlgorithm_t getAlgorithmType() const = 0;
|
||||
virtual cudnnSoftmaxMode_t getModeType() const = 0;
|
||||
virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
|
||||
void compute(const Operator &_op, const PerfRecord &record,
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
|
@ -128,20 +100,6 @@ class SoftmaxCudnn : public Kernel {
|
|||
checkCudnnError(cudnnDestroyTensorDescriptor(inputDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(outputDesc));
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
compute(_op, {}, _context);
|
||||
}
|
||||
// Premise: op is idempotent since it is called multiple times.
|
||||
PerfRecord tune(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
PerfRecord ret;
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
ret.time = timeit([&]() { compute(_op, _context); },
|
||||
[&]() { context->sync(); });
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
class ReluCudnn : public ActivationCudnn {
|
||||
|
|
Loading…
Reference in New Issue