forked from jiuyuan/InfiniTensor
Operator serialization (#14)
Class CudaRuntimeObj implements the "tune" function and adds a corresponding test case. *Add: convCudnn::tune, convCudnn::cuDNNDescriptorAccess. *Add: test case Conv.tune. *Fix: a minor bug in the CPU runtime.
This commit is contained in:
parent
93f86d3f4d
commit
25f0c441d2
|
@ -14,7 +14,7 @@ set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
|
||||||
|
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
|
set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wno-error=deprecated-declarations")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wno-error=deprecated-declarations -Wno-error=sign-compare")
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
|
||||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
|
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
|
||||||
|
|
||||||
|
|
|
@ -71,8 +71,9 @@ template <typename T> std::string vecToString(const std::vector<T> &vec) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
double timeit(const std::function<void()> &func, int warmupRounds = 200,
|
double timeit(
|
||||||
int timingRounds = 200,
|
const std::function<void()> &func,
|
||||||
const std::function<void(void)> &sync = {});
|
const std::function<void(void)> &sync = []() {}, int warmupRounds = 200,
|
||||||
|
int timingRounds = 200);
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -21,6 +21,7 @@ class CudaRuntimeObj : public RuntimeObj {
|
||||||
workspace = alloc(workspaceSize);
|
workspace = alloc(workspaceSize);
|
||||||
}
|
}
|
||||||
virtual ~CudaRuntimeObj() {
|
virtual ~CudaRuntimeObj() {
|
||||||
|
dealloc(workspace);
|
||||||
checkCudnnError(cudnnDestroy(cudnn));
|
checkCudnnError(cudnnDestroy(cudnn));
|
||||||
checkCublasError(cublasDestroy(cublas));
|
checkCublasError(cublasDestroy(cublas));
|
||||||
}
|
}
|
||||||
|
@ -57,6 +58,6 @@ class CudaRuntimeObj : public RuntimeObj {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void runWithoutSync(const Graph &graph) const;
|
void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
|
||||||
};
|
};
|
||||||
} // namespace infini
|
} // namespace infini
|
|
@ -4,8 +4,9 @@
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
double timeit(const std::function<void()> &func, int warmupRounds,
|
double timeit(const std::function<void()> &func,
|
||||||
int timingRounds, const std::function<void(void)> &sync) {
|
const std::function<void(void)> &sync, int warmupRounds,
|
||||||
|
int timingRounds) {
|
||||||
for (int i = 0; i < warmupRounds; ++i)
|
for (int i = 0; i < warmupRounds; ++i)
|
||||||
func();
|
func();
|
||||||
if (sync)
|
if (sync)
|
||||||
|
|
|
@ -40,17 +40,18 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
||||||
// Tune the kernel if there is no record
|
// Tune the kernel if there is no record
|
||||||
if (!perfData) {
|
if (!perfData) {
|
||||||
// TODO: record is not used
|
// TODO: record is not used
|
||||||
|
// printf("no record data\n");
|
||||||
record = kernel->tune(op, this);
|
record = kernel->tune(op, this);
|
||||||
perfEngine.setPerfData(perfKey, record);
|
perfEngine.setPerfData(perfKey, record);
|
||||||
} else
|
} else
|
||||||
record = *perfData;
|
record = *perfData;
|
||||||
|
|
||||||
if (!profiling) {
|
if (!profiling) {
|
||||||
kernel->compute(op, *perfData, this);
|
kernel->compute(op, record, this);
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
double t =
|
double t = timeit([&]() { kernel->compute(op, record, this); },
|
||||||
timeit([&]() { kernel->compute(op, *perfData, this); }, 1, 1);
|
[]() {}, 1, 1);
|
||||||
op->print();
|
op->print();
|
||||||
printf(" op_time %lf\n", t);
|
printf(" op_time %lf\n", t);
|
||||||
totalTime += t;
|
totalTime += t;
|
||||||
|
|
|
@ -4,10 +4,13 @@
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
|
void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
||||||
|
bool profiling = false) const {
|
||||||
const auto &kernelRegistry = KernelRegistry::getInstance();
|
const auto &kernelRegistry = KernelRegistry::getInstance();
|
||||||
auto perfEngine = PerfEngine::getInstance();
|
auto perfEngine = PerfEngine::getInstance();
|
||||||
|
double totalTime = 0;
|
||||||
|
std::map<OpType, double> opTime;
|
||||||
|
std::map<OpType, int> opCnt;
|
||||||
for (auto &op : graph->getOperators()) {
|
for (auto &op : graph->getOperators()) {
|
||||||
// HACK: set correct data type
|
// HACK: set correct data type
|
||||||
auto kernelAttrs =
|
auto kernelAttrs =
|
||||||
|
@ -15,17 +18,40 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
|
||||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||||
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
|
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
|
||||||
if (perfData)
|
if (!perfData && !tune) {
|
||||||
kernel->compute(op, *perfData, this);
|
|
||||||
else
|
|
||||||
kernel->compute(op, this);
|
kernel->compute(op, this);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
PerfRecord record;
|
||||||
|
|
||||||
|
if (!perfData) {
|
||||||
|
record = kernel->tune(op, this);
|
||||||
|
perfEngine.setPerfData(perfKey, record);
|
||||||
|
|
||||||
|
} else
|
||||||
|
record = *perfData;
|
||||||
|
|
||||||
|
double t = record.time;
|
||||||
|
totalTime += t;
|
||||||
|
|
||||||
|
if (profiling) {
|
||||||
|
double t = timeit([&]() { kernel->compute(op, record, this); },
|
||||||
|
[&]() { sync(); }, 1, 1);
|
||||||
|
op->print();
|
||||||
|
printf(" op_time on cuda %lf\n", t);
|
||||||
|
totalTime += t;
|
||||||
|
opTime[op->getOpType()] += t;
|
||||||
|
opCnt[op->getOpType()]++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudaRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
void CudaRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
||||||
if (tune || profiling)
|
if (profiling)
|
||||||
IT_TODO_HALT();
|
IT_TODO_HALT();
|
||||||
runWithoutSync(graph);
|
|
||||||
|
runWithoutSync(graph, tune);
|
||||||
sync();
|
sync();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
#include "operators/conv.h"
|
#include "operators/conv.h"
|
||||||
#include "core/kernel.h"
|
#include "core/kernel.h"
|
||||||
#include "cuda/cuda_runtime.h"
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include <chrono>
|
||||||
|
#include <functional>
|
||||||
|
#include <limits>
|
||||||
|
#include <tuple>
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
static constexpr int N_ALGO = 8;
|
static constexpr int N_ALGO = 8;
|
||||||
|
@ -14,18 +17,25 @@ static constexpr cudnnConvolutionFwdAlgo_t ALGOS[N_ALGO] = {
|
||||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED};
|
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED};
|
||||||
|
static constexpr int N_MODE = 2;
|
||||||
|
static constexpr cudnnConvolutionMode_t MODES[N_MODE] = {
|
||||||
|
CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION};
|
||||||
|
|
||||||
struct ConvCuDnnPerfRecord : public PerfRecord {
|
struct ConvCuDnnPerfRecord : public PerfRecord {
|
||||||
int algo = 0; // cudnnConvolutionFwdAlgo_t
|
int algo = 0; // cudnnConvolutionFwdAlgo_t
|
||||||
|
int mode = 1;
|
||||||
size_t workspaceSize = 100000;
|
size_t workspaceSize = 100000;
|
||||||
bool fuseAct = false;
|
bool fuseAct = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
class convCudnn : public Kernel {
|
class convCudnn : public Kernel {
|
||||||
|
|
||||||
bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
|
std::tuple<void *, void *, void *, cudnnTensorDescriptor_t,
|
||||||
const CudaRuntimeObj *context) const {
|
cudnnFilterDescriptor_t, cudnnTensorDescriptor_t,
|
||||||
cudnnStatus_t stat;
|
cudnnConvolutionDescriptor_t, cudnnActivationDescriptor_t,
|
||||||
|
cudnnTensorDescriptor_t>
|
||||||
|
cuDNNDescriptorAccess(const Ref<ConvObj> &op,
|
||||||
|
const ConvCuDnnPerfRecord &record) const {
|
||||||
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
|
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||||
void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
|
void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||||
if (op->getInputs().size() > 2) // Bias is not supported yet
|
if (op->getInputs().size() > 2) // Bias is not supported yet
|
||||||
|
@ -52,7 +62,6 @@ class convCudnn : public Kernel {
|
||||||
checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
|
checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
|
||||||
CUDNN_TENSOR_NCHW, f,
|
CUDNN_TENSOR_NCHW, f,
|
||||||
channelsPerGrp, r, s));
|
channelsPerGrp, r, s));
|
||||||
|
|
||||||
// get bias
|
// get bias
|
||||||
cudnnTensorDescriptor_t biasDesc;
|
cudnnTensorDescriptor_t biasDesc;
|
||||||
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
|
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
|
||||||
|
@ -64,7 +73,7 @@ class convCudnn : public Kernel {
|
||||||
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
|
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
|
||||||
// TODO: CUDNN_CONVOLUTION is a tunable argument
|
// TODO: CUDNN_CONVOLUTION is a tunable argument
|
||||||
checkCudnnError(cudnnSetConvolution2dDescriptor(
|
checkCudnnError(cudnnSetConvolution2dDescriptor(
|
||||||
convDesc, ph, pw, sh, sw, dh, dw, CUDNN_CROSS_CORRELATION,
|
convDesc, ph, pw, sh, sw, dh, dw, MODES[record.mode],
|
||||||
CUDNN_DATA_FLOAT));
|
CUDNN_DATA_FLOAT));
|
||||||
if (g > 1) {
|
if (g > 1) {
|
||||||
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
|
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
|
||||||
|
@ -93,7 +102,6 @@ class convCudnn : public Kernel {
|
||||||
assert(false);
|
assert(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// get outputs
|
|
||||||
int outn, outc, outh, outw;
|
int outn, outc, outh, outw;
|
||||||
checkCudnnError(cudnnGetConvolution2dForwardOutputDim(
|
checkCudnnError(cudnnGetConvolution2dForwardOutputDim(
|
||||||
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
|
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
|
||||||
|
@ -106,6 +114,15 @@ class convCudnn : public Kernel {
|
||||||
op->getOutput()->getDims(),
|
op->getOutput()->getDims(),
|
||||||
"cuDNN output shape mismatches with OP output shape");
|
"cuDNN output shape mismatches with OP output shape");
|
||||||
|
|
||||||
|
return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
|
||||||
|
convDesc, actDesc, outDesc);
|
||||||
|
}
|
||||||
|
bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
|
||||||
|
const CudaRuntimeObj *context) const {
|
||||||
|
cudnnStatus_t stat;
|
||||||
|
|
||||||
|
auto [inData, knData, outData, inDesc, knDesc, biasDesc, convDesc,
|
||||||
|
actDesc, outDesc] = cuDNNDescriptorAccess(op, record);
|
||||||
// get workspace
|
// get workspace
|
||||||
size_t wsSize = record.workspaceSize;
|
size_t wsSize = record.workspaceSize;
|
||||||
stat = cudnnGetConvolutionForwardWorkspaceSize(
|
stat = cudnnGetConvolutionForwardWorkspaceSize(
|
||||||
|
@ -113,9 +130,7 @@ class convCudnn : public Kernel {
|
||||||
ALGOS[record.algo], &wsSize);
|
ALGOS[record.algo], &wsSize);
|
||||||
if (stat != CUDNN_STATUS_SUCCESS)
|
if (stat != CUDNN_STATUS_SUCCESS)
|
||||||
return false;
|
return false;
|
||||||
// assert(wsSize < (size_t)3 * 1024 * 1024 * 1024);
|
|
||||||
// if (wsSize >= (size_t)10 * 1024 * 1024 * 1024)
|
|
||||||
// continue;
|
|
||||||
CudaPtr wsData = context->getWorkspace(wsSize);
|
CudaPtr wsData = context->getWorkspace(wsSize);
|
||||||
float alpha = 1.f, beta = 0.f;
|
float alpha = 1.f, beta = 0.f;
|
||||||
|
|
||||||
|
@ -172,10 +187,12 @@ class convCudnn : public Kernel {
|
||||||
|
|
||||||
// Destroys in CUDA do not require sync. But cuDNN does not state
|
// Destroys in CUDA do not require sync. But cuDNN does not state
|
||||||
// whether sync is required before destroys.
|
// whether sync is required before destroys.
|
||||||
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
|
|
||||||
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
||||||
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
|
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
|
||||||
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
|
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
|
||||||
|
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
|
||||||
|
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
|
||||||
|
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,9 +203,63 @@ class convCudnn : public Kernel {
|
||||||
|
|
||||||
PerfRecord tune(const Operator &_op,
|
PerfRecord tune(const Operator &_op,
|
||||||
const RuntimeObj *_context) const override {
|
const RuntimeObj *_context) const override {
|
||||||
// TODO: real tuning
|
ConvCuDnnPerfRecord ret, tmp_ret;
|
||||||
ConvCuDnnPerfRecord ret;
|
ret.time = std::numeric_limits<double>::max();
|
||||||
ret.time = timeit([&]() { compute(_op, _context); });
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
auto op = as<ConvObj>(_op);
|
||||||
|
// Try every possible data input mode of convolution func
|
||||||
|
for (int i = 0; i < N_MODE; i++) {
|
||||||
|
// Try every possible algorithm of convolution func
|
||||||
|
for (int j = 0; j < N_ALGO; j++) {
|
||||||
|
tmp_ret.algo = j;
|
||||||
|
tmp_ret.mode = i;
|
||||||
|
// Check if the kernel supports the op
|
||||||
|
cudnnStatus_t stat;
|
||||||
|
auto [inData, knData, outData, inDesc, knDesc, biasDesc,
|
||||||
|
convDesc, actDesc, outDesc] =
|
||||||
|
cuDNNDescriptorAccess(op, tmp_ret);
|
||||||
|
|
||||||
|
// get workspace
|
||||||
|
size_t wsSize = tmp_ret.workspaceSize;
|
||||||
|
stat = cudnnGetConvolutionForwardWorkspaceSize(
|
||||||
|
context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
|
||||||
|
ALGOS[tmp_ret.algo], &wsSize);
|
||||||
|
if (stat != CUDNN_STATUS_SUCCESS)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
CudaPtr wsData = context->getWorkspace(wsSize);
|
||||||
|
float alpha = 1.f, beta = 0.f;
|
||||||
|
|
||||||
|
stat = cudnnConvolutionForward(
|
||||||
|
context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
|
||||||
|
knData, convDesc, ALGOS[tmp_ret.algo], wsData, wsSize,
|
||||||
|
&beta, outDesc, outData);
|
||||||
|
if (stat != CUDNN_STATUS_SUCCESS)
|
||||||
|
continue;
|
||||||
|
tmp_ret.time = timeit(
|
||||||
|
[&]() {
|
||||||
|
cudnnConvolutionForward(
|
||||||
|
context->cudnnHandle(), &alpha, inDesc, inData,
|
||||||
|
knDesc, knData, convDesc, ALGOS[tmp_ret.algo],
|
||||||
|
wsData, wsSize, &beta, outDesc, outData);
|
||||||
|
},
|
||||||
|
[&]() { context->sync(); });
|
||||||
|
printf("mode:%d algo:%d :%.8lf\n", i, j, tmp_ret.time);
|
||||||
|
// Update the tune result
|
||||||
|
if (ret.time > tmp_ret.time)
|
||||||
|
ret = tmp_ret;
|
||||||
|
|
||||||
|
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
||||||
|
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
|
||||||
|
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
|
||||||
|
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
|
||||||
|
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
|
||||||
|
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Test information output
|
||||||
|
printf("the best algo is %d, the best conv mode is %d\n", ret.algo,
|
||||||
|
ret.mode);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
#include "core/graph.h"
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
#include "core/runtime.h"
|
#include "core/runtime.h"
|
||||||
#include "cuda/cuda_runtime.h"
|
#include "cuda/cuda_runtime.h"
|
||||||
#include "cuda/cuda_utility.h"
|
#include "cuda/cuda_utility.h"
|
||||||
#include "operators/conv.h"
|
#include "operators/conv.h"
|
||||||
|
|
||||||
#include "test.h"
|
#include "test.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
@ -107,4 +109,30 @@ TEST(Conv, cuDNN) {
|
||||||
IncrementalGenerator(),
|
IncrementalGenerator(),
|
||||||
vector<float>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
|
vector<float>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Conv, tune) {
|
||||||
|
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||||
|
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||||
|
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||||
|
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||||
|
// Set input data on CPU in a CPU Graph
|
||||||
|
Tensor i0Cpu = gCpu->addTensor({1, 3, 800, 800}, DataType::Float32);
|
||||||
|
Tensor w0Cpu = gCpu->addTensor({2, 3, 5, 5}, DataType::Float32);
|
||||||
|
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
i0Cpu->setData(IncrementalGenerator());
|
||||||
|
w0Cpu->setData(IncrementalGenerator());
|
||||||
|
|
||||||
|
// Copy input tensors from CPU to CUDA
|
||||||
|
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||||
|
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||||
|
// Build CUDA graph
|
||||||
|
auto conv =
|
||||||
|
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1);
|
||||||
|
// allocate CUDA memory
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
// Execute on CUDA
|
||||||
|
bool tune = true;
|
||||||
|
cuda->run(gCuda, tune);
|
||||||
|
}
|
||||||
} // namespace infini
|
} // namespace infini
|
|
@ -1,2 +1,3 @@
|
||||||
|
. /home/spack/spack/share/spack/setup-env.sh
|
||||||
spack load cuda@11.0.2 cudnn@8.0.3.33-11.0
|
spack load cuda@11.0.2 cudnn@8.0.3.33-11.0
|
||||||
export CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc
|
export CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc
|
||||||
|
|
Loading…
Reference in New Issue