forked from jiuyuan/InfiniTensor
Fix kernel arguments, add debug mode (#119)
Add debug mode macro in cmakelist.
This commit is contained in:
commit
69fd251e5d
|
@ -15,6 +15,23 @@ cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF)
|
||||||
cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST OFF)
|
cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST OFF)
|
||||||
|
|
||||||
set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
|
set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
|
||||||
|
# Build Type
|
||||||
|
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||||
|
message("Configuring for Debug build.")
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0")
|
||||||
|
add_compile_definitions(DEBUG_MODE)
|
||||||
|
elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
|
||||||
|
message("Configuring for Release build.")
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2")
|
||||||
|
add_compile_definitions(NDEBUG)
|
||||||
|
elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
|
||||||
|
message("Configuring for RelWithDebInfo build.")
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O2")
|
||||||
|
else()
|
||||||
|
message("Build type not specified. Configuring for RelWithDebInfo build.")
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O2")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
|
||||||
if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake)
|
if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake)
|
||||||
message(STATUS "Using config.cmake in CMAKE_CURRENT_BINARY_DIR directory")
|
message(STATUS "Using config.cmake in CMAKE_CURRENT_BINARY_DIR directory")
|
||||||
|
|
|
@ -4,6 +4,19 @@
|
||||||
#include "core/runtime.h"
|
#include "core/runtime.h"
|
||||||
#include "operators/conv.h"
|
#include "operators/conv.h"
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
|
|
||||||
|
#ifdef DEBUG_MODE
|
||||||
|
void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) {
|
||||||
|
cudaError_t kernelError = cudaGetLastError();
|
||||||
|
if (kernelError != cudaSuccess) {
|
||||||
|
std::cerr << "CUDA kernel error: " << cudaGetErrorString(kernelError)
|
||||||
|
<< std::endl
|
||||||
|
<< "Failed Operator: " << op->toString() << std::endl;
|
||||||
|
exit(EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
|
void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
|
||||||
|
@ -22,6 +35,10 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
|
||||||
} else {
|
} else {
|
||||||
kernel->compute(op, this);
|
kernel->compute(op, this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef DEBUG_MODE
|
||||||
|
CHECK_CUDA_KERNEL_ERROR(op);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,6 +74,10 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
|
||||||
opTime[op->getOpType()] += t;
|
opTime[op->getOpType()] += t;
|
||||||
opCnt[op->getOpType()]++;
|
opCnt[op->getOpType()]++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef DEBUG_MODE
|
||||||
|
CHECK_CUDA_KERNEL_ERROR(op);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ void clip_kernel(float *input, float *output, int num, float minValue,
|
||||||
float maxValue) {
|
float maxValue) {
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
_clip_kernel<<<blocksize, gridsize>>>(input, output, num, minValue,
|
_clip_kernel<<<gridsize, blocksize>>>(input, output, num, minValue,
|
||||||
maxValue);
|
maxValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,8 +5,8 @@ constexpr unsigned int num_threads() { return 32 * 4; }
|
||||||
constexpr int thread_work_size() { return 4; }
|
constexpr int thread_work_size() { return 4; }
|
||||||
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
||||||
|
|
||||||
__global__ void _div_kernel(float *x, float *y, float *z, int a0, int a1, int a2, int a3,
|
__global__ void _div_kernel(float *x, float *y, float *z, int a0, int a1,
|
||||||
int b0, int b1, int b2, int b3,
|
int a2, int a3, int b0, int b1, int b2, int b3,
|
||||||
int c0, int c1, int c2, int c3) {
|
int c0, int c1, int c2, int c3) {
|
||||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
int stride = blockDim.x * gridDim.x;
|
int stride = blockDim.x * gridDim.x;
|
||||||
|
@ -27,12 +27,15 @@ __global__ void _div_kernel(float *x, float *y, float *z, int a0, int a1, int a2
|
||||||
int b1_index = c1_index % b1;
|
int b1_index = c1_index % b1;
|
||||||
int b2_index = c2_index % b2;
|
int b2_index = c2_index % b2;
|
||||||
int b3_index = c3_index % b3;
|
int b3_index = c3_index % b3;
|
||||||
z[i] = x[a0_index*a1*a2*a3 + a1_index*a2*a3 + a2_index*a3 + a3_index] / y[b0_index*b1*b2*b3 + b1_index*b2*b3 + b2_index*b3 + b3_index];
|
z[i] = x[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + a2_index * a3 +
|
||||||
|
a3_index] /
|
||||||
|
y[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + b2_index * b3 +
|
||||||
|
b3_index];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void _pow_kernel(float *x, float *y, float *z, int a0, int a1, int a2, int a3,
|
__global__ void _pow_kernel(float *x, float *y, float *z, int a0, int a1,
|
||||||
int b0, int b1, int b2, int b3,
|
int a2, int a3, int b0, int b1, int b2, int b3,
|
||||||
int c0, int c1, int c2, int c3) {
|
int c0, int c1, int c2, int c3) {
|
||||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
int stride = blockDim.x * gridDim.x;
|
int stride = blockDim.x * gridDim.x;
|
||||||
|
@ -53,27 +56,32 @@ __global__ void _pow_kernel(float *x, float *y, float *z, int a0, int a1, int a2
|
||||||
int b1_index = c1_index % b1;
|
int b1_index = c1_index % b1;
|
||||||
int b2_index = c2_index % b2;
|
int b2_index = c2_index % b2;
|
||||||
int b3_index = c3_index % b3;
|
int b3_index = c3_index % b3;
|
||||||
z[i] = pow(x[a0_index*a1*a2*a3 + a1_index*a2*a3 + a2_index*a3 + a3_index], y[b0_index*b1*b2*b3 + b1_index*b2*b3 + b2_index*b3 + b3_index]);
|
z[i] = pow(x[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
|
||||||
|
a2_index * a3 + a3_index],
|
||||||
|
y[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
|
||||||
|
b2_index * b3 + b3_index]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
void div_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
|
void div_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
|
||||||
int b0, int b1, int b2, int b3,
|
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||||
int c0, int c1, int c2, int c3) {
|
int c3) {
|
||||||
|
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
int num = c0 * c1 * c2 * c3;
|
int num = c0 * c1 * c2 * c3;
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
_div_kernel<<<blocksize, gridsize>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
|
_div_kernel<<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2,
|
||||||
|
b3, c0, c1, c2, c3);
|
||||||
}
|
}
|
||||||
void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
|
void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
|
||||||
int b0, int b1, int b2, int b3,
|
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||||
int c0, int c1, int c2, int c3) {
|
int c3) {
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
int num = c0 * c1 * c2 * c3;
|
int num = c0 * c1 * c2 * c3;
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
_pow_kernel<<<blocksize, gridsize>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
|
_pow_kernel<<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2,
|
||||||
|
b3, c0, c1, c2, c3);
|
||||||
}
|
}
|
||||||
|
|
||||||
}; // namespace infini
|
}; // namespace infini
|
||||||
|
|
|
@ -19,7 +19,7 @@ void extend_kernel(float *in, float *out, int blockSize, int blockSizeOuter,
|
||||||
int oSize) {
|
int oSize) {
|
||||||
int blocksize = 32 * 16;
|
int blocksize = 32 * 16;
|
||||||
int gridsize = (oSize + blocksize - 1) / blocksize;
|
int gridsize = (oSize + blocksize - 1) / blocksize;
|
||||||
_extend_kernel<<<blocksize, gridsize>>>(in, out, blockSize, blockSizeOuter,
|
_extend_kernel<<<gridsize, blocksize>>>(in, out, blockSize, blockSizeOuter,
|
||||||
oSize);
|
oSize);
|
||||||
}
|
}
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -72,36 +72,36 @@ void softmax_kernel(float *input, float *output, int num) {
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
_softmax_kernel1<<<1, 1>>>(input, output, num);
|
_softmax_kernel1<<<1, 1>>>(input, output, num);
|
||||||
_softmax_kernel2<<<blocksize, gridsize>>>(input, output, num);
|
_softmax_kernel2<<<gridsize, blocksize>>>(input, output, num);
|
||||||
}
|
}
|
||||||
void relu_kernel(float *input, float *output, int num) {
|
void relu_kernel(float *input, float *output, int num) {
|
||||||
|
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
_relu_kernel<<<blocksize, gridsize>>>(input, output, num);
|
_relu_kernel<<<gridsize, blocksize>>>(input, output, num);
|
||||||
}
|
}
|
||||||
void sigmoid_kernel(float *input, float *output, int num) {
|
void sigmoid_kernel(float *input, float *output, int num) {
|
||||||
|
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
_sigmoid_kernel<<<blocksize, gridsize>>>(input, output, num);
|
_sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
|
||||||
}
|
}
|
||||||
void tanh_kernel(float *input, float *output, int num) {
|
void tanh_kernel(float *input, float *output, int num) {
|
||||||
|
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
_tanh_kernel<<<blocksize, gridsize>>>(input, output, num);
|
_tanh_kernel<<<gridsize, blocksize>>>(input, output, num);
|
||||||
}
|
}
|
||||||
void abs_kernel(float *input, float *output, int num) {
|
void abs_kernel(float *input, float *output, int num) {
|
||||||
|
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
_abs_kernel<<<blocksize, gridsize>>>(input, output, num);
|
_abs_kernel<<<gridsize, blocksize>>>(input, output, num);
|
||||||
}
|
}
|
||||||
void sqrt_kernel(float *input, float *output, int num) {
|
void sqrt_kernel(float *input, float *output, int num) {
|
||||||
|
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
_sqrt_kernel<<<blocksize, gridsize>>>(input, output, num);
|
_sqrt_kernel<<<gridsize, blocksize>>>(input, output, num);
|
||||||
}
|
}
|
||||||
}; // namespace infini
|
}; // namespace infini
|
||||||
|
|
Loading…
Reference in New Issue