Cuda softmax (#129)

* "add softmax.cu,.cc,.h"

* Modify cuda softmax

* "modified the introduction of softmax.cu"

* "add format of cuda_softmax.h"

* "modified where.cc(.cu,.h) and softmax.cu"

* "modified format"

* Fix cpu softmax kernel

* "modified the // introduction of softmax.cu"

* "modified softmax.cu and use 1D block"

* "modified softmax.cu,format, and use 1D block"

* "introduce share mem to speed softmax"

* "reduce the input of function"

* modified the format

* remodify 2D block softmax

* remodify 1D block softmax

* modified the share memory

* add warp reduce

* conflict solve two

* remove extra space line

* solve comment

---------

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
Co-authored-by: panzezhong <panzezhong@qiyuanlab.com>
This commit is contained in:
xgqdut2016 2023-11-06 08:56:23 +08:00 committed by GitHub
parent 1a6fccccbe
commit d3e7543291
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 209 additions and 102 deletions

View File

@ -0,0 +1,6 @@
#pragma once
#include "utils/small_array.h"
namespace infini {
void softmax_kernel(int num_blocks, float *input, float *output, int size,
int dimsize, int stride);
}

View File

@ -1,6 +0,0 @@
#pragma once
namespace infini {
void softmax_kernel(int max_threadblock_size, int batch_size, float *x,
float *y, int dim, int stride);
}

View File

@ -1,6 +1,7 @@
#include "operators/unary.h" #include "operators/unary.h"
#include "core/constants.h" #include "core/constants.h"
#include "core/kernel.h" #include "core/kernel.h"
#include "operators/softmax.h"
namespace infini { namespace infini {
template <typename T> class NativeUnary : public CpuKernelWithoutConfig { template <typename T> class NativeUnary : public CpuKernelWithoutConfig {
@ -22,7 +23,7 @@ template <typename T> class NativeUnary : public CpuKernelWithoutConfig {
template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig { template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *context) const override { const RuntimeObj *context) const override {
auto op = as<UnaryObj>(_op); auto op = as<SoftmaxObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>(); T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>(); T *outptr = op->getOutput()->getRawDataPtr<T *>();

View File

@ -1,30 +1,30 @@
#include "operators/softmax.h" #include "operators/softmax.h"
#include "cuda/cuda_kernel_wihtout_config.h" #include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h" #include "cuda/cuda_runtime.h"
#include "cuda/softmax.h" #include "cuda/cuda_softmax.h"
namespace infini { namespace infini {
class SoftmaxCudnn : public CudaKernelWithoutConfig { class SoftmaxCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<SoftmaxObj>(_op); auto op = as<SoftmaxObj>(_op);
auto x = op->getInputs(0)->getRawDataPtr<float *>(); auto input = op->getInputs(0)->getRawDataPtr<float *>();
auto y = op->getOutput(0)->getRawDataPtr<float *>(); auto output = op->getOutput(0)->getRawDataPtr<float *>();
const auto &inShape = op->getInputs(0)->getDims(); // input shape
auto dims = op->getInputs(0)->getDims(); auto dims = op->getInputs(0)->getDims();
int batch_size = 1; int size; // size = i(JKS) + j(KS) + k(S) + s
for (size_t i = 0; i < dims.size(); ++i) size = op->getOutput(0)->size();
batch_size *= dims[i]; int dimsize = dims[op->getAxis()];
int dim = dims[op->getAxis()]; int stride = op->getInputs(0)->getStride().at(op->getAxis());
int block_num = batch_size / dim; int num_blocks = size / dimsize;
int max_threadblock_size = batch_size / block_num; softmax_kernel(num_blocks, (float *)input, (float *)output, size,
softmax_kernel(max_threadblock_size, block_num, x, y, dim, dimsize, stride);
op->getInputs(0)->getStride().at(op->getAxis()));
} }
}; };
REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, SoftmaxCudnn, REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, SoftmaxCuda,
"Softmax_CUDA_Float32"); "Softmax_CUDA_Float32");
} // namespace infini } // namespace infini

View File

@ -1,77 +1,183 @@
#include "cuda/cuda_common.h" #include "cuda/cuda_common.h"
#include "cuda/softmax.h"
#include <cub/cub.cuh> #include <cub/cub.cuh>
struct __align__(8) MD { struct __align__(8) DataMaxSum { // update the global max and sum, store the
float data; // output at max_tmp and sum_tmp
float d; float max_tmp; // store max
float sum_tmp; // store sum
};
__device__ __forceinline__ DataMaxSum reduce_dms_op(DataMaxSum a,
DataMaxSum b) {
bool a_bigger = (a.max_tmp > b.max_tmp);
DataMaxSum bigger = a_bigger ? a : b;
DataMaxSum smaller = a_bigger ? b : a;
bigger.sum_tmp = bigger.sum_tmp +
smaller.sum_tmp * __expf(smaller.max_tmp - bigger.max_tmp);
return bigger;
}
template <int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel(
float *__restrict input, float *__restrict output, int size, int dimsize,
int stride) { // if set axis = 1, inputShape=[I,J,K,S]
// tid = i(JKS) + j(KS) + k(S) + s
// blockDim.x = size/dimsize = IKS
// blockIdx.x = i(KS) + k(S) + s,blockIdx.x%stride = k(S) + s
int tid =
blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) *
dimsize; // now, tid = i(JKS) + k(S) + s;
DataMaxSum dms_partial;
dms_partial.max_tmp = -__FLT_MAX__;
dms_partial.sum_tmp = 0.0f;
DataMaxSum dms_input;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
dms_input.max_tmp =
input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
dms_input.sum_tmp = 1.0f;
dms_partial = reduce_dms_op(dms_partial,
dms_input); // reduce the data to one block
}
typedef cub::BlockReduce<DataMaxSum, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ DataMaxSum dms_total;
DataMaxSum dms_block =
BlockReduce(temp_storage).Reduce(dms_partial, reduce_dms_op);
if (threadIdx.x ==
0) { // must set threadIdx.x = 0 write the output to memory
dms_total = dms_block;
}
__syncthreads();
//-----------------
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
__expf(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
dms_total.max_tmp) *
__fdividef(1.0F, dms_total.sum_tmp);
}
}
template <typename T> struct SumOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return a + b;
}
}; };
__device__ __forceinline__ MD reduce_md_op(MD a, MD b) { template <typename T> struct MaxOp {
bool a_bigger = (a.data > b.data); __device__ __forceinline__ T operator()(const T &a, const T &b) const {
MD bigger_m = a_bigger ? a : b; return max(a, b);
MD smaller_m = a_bigger ? b : a; }
MD res; };
res.d = bigger_m.d + smaller_m.d * __expf(smaller_m.data - bigger_m.data); template <template <typename> class ReductionOp, typename T,
res.data = bigger_m.data; int thread_group_width>
return res; __inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
} }
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void _warpSoftmaxKernel(float *__restrict input,
float *__restrict output, int size,
int dimsize, int stride) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int otherSize = size / dimsize;
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
template <int THREADBLOCK_SIZE> if (otherIdx < otherSize) {
__launch_bounds__(THREADBLOCK_SIZE) __global__
void online_softmax(const float *__restrict in, float *__restrict out,
int dimSize, int stride) {
// reposition in and out to data for the current vector __shared__ float max_total[BLOCK_DIM_y];
int blockOffset = blockIdx.x; __shared__ float sum_total[BLOCK_DIM_y];
if (blockIdx.x >= stride) { float max_data = -__FLT_MAX__;
int tmp = blockIdx.x % stride;
blockOffset = tmp + (blockIdx.x - tmp) * dimSize;
}
in += blockOffset;
out += blockOffset;
MD md_partial; for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
md_partial.data = -FLT_MAX; max_data =
md_partial.d = 0.0F; max(max_data,
input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride]);
for (int elem_id = threadIdx.x; elem_id < dimSize;
elem_id += THREADBLOCK_SIZE) {
MD new_elem;
new_elem.data = in[elem_id * stride];
new_elem.d = 1.0F;
md_partial = reduce_md_op(md_partial, new_elem);
} }
// blockreduce for THREADBLOCK_SIZE threads. max_data = WarpAllReduce<MaxOp, float, BLOCK_DIM_x>(max_data);
// The actrual threads num used in the block is "dimsSize"
typedef cub::BlockReduce<MD, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ MD md_total;
MD md = BlockReduce(temp_storage).Reduce(md_partial, reduce_md_op);
if (threadIdx.x == 0) if (threadIdx.x == 0)
md_total = md; max_total[threadIdx.y] = max_data;
__syncthreads();
float d_total_inverse = __fdividef(1.0F, md_total.d); //--------------------------------------------
for (int elem_id = threadIdx.x; elem_id < dimSize; float sum_data = 0.0f;
elem_id += THREADBLOCK_SIZE)
out[elem_id * stride] = for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
__expf(in[elem_id * stride] - md_total.data) * d_total_inverse; sum_data +=
__expf(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
max_total[threadIdx.y]);
}
sum_data = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sum_data);
if (threadIdx.x == 0)
sum_total[threadIdx.y] = sum_data;
//--------------------------------------------
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
__expf(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
max_total[threadIdx.y]) *
__fdividef(1.0F, sum_total[threadIdx.y]);
}
}
} }
//-----------------
//-----------------
namespace infini { namespace infini {
void softmax_kernel(int max_threadblock_size, int blockNum, float *in, void softmax_kernel(int num_blocks, float *input, float *output, int size,
float *out, int dimSize, int stride) { int dimsize, int stride) {
if (max_threadblock_size >= 255)
online_softmax<256><<<blockNum, 256>>>(in, out, dimSize, stride); if (dimsize > 1024) {
else if (max_threadblock_size >= 128)
online_softmax<128><<<blockNum, 128>>>(in, out, dimSize, stride); int BLOCK_DIM = 1024;
else if (max_threadblock_size >= 64) _blockSoftmaxKernel<1024>
online_softmax<64><<<blockNum, 64>>>(in, out, dimSize, stride); <<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
else } else if (dimsize > 31) {
online_softmax<32><<<blockNum, 32>>>(in, out, dimSize, stride); int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<32, 32>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<16, 64>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<8, 128>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<4, 256>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
}
} }
} // namespace infini } // namespace infini

View File

@ -8,38 +8,38 @@
namespace infini { namespace infini {
void test_where(const Shape &inputxshape, const vector<float> &inputxdata, void test_where(const Shape &inputXShape, const vector<float> &inputXData,
const Shape &inputyshape, const vector<float> &inputydata, const Shape &inputYShape, const vector<float> &inputYData,
const Shape &conditionshape, const Shape &conditionShape,
const vector<uint8_t> &conditiondata, const vector<uint8_t> &conditionData,
const vector<float> &ExpectData) { const vector<float> &ExpectData) {
Runtime runtime = NativeCpuRuntimeObj::getInstance(); Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime); Graph gCpu = make_ref<GraphObj>(runtime);
auto condition = gCpu->addTensor(conditionshape, DataType::UInt8); auto condition = gCpu->addTensor(conditionShape, DataType::UInt8);
auto inputx = gCpu->addTensor(inputxshape, DataType::Float32); auto inputX = gCpu->addTensor(inputXShape, DataType::Float32);
auto inputy = gCpu->addTensor(inputyshape, DataType::Float32); auto inputY = gCpu->addTensor(inputYShape, DataType::Float32);
gCpu->dataMalloc(); gCpu->dataMalloc();
condition->copyin(conditiondata); // condition->copyin(conditionData); //
inputx->copyin(inputxdata); inputX->copyin(inputXData);
inputy->copyin(inputydata); // inputY->copyin(inputYData); //
auto cudaRuntime = make_ref<CudaRuntimeObj>(); auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime); Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto conditionGpu = gCuda->cloneTensor(condition); auto conditionGpu = gCuda->cloneTensor(condition);
auto inputxGpu = gCuda->cloneTensor(inputx); auto inputXGpu = gCuda->cloneTensor(inputX);
auto inputyGpu = gCuda->cloneTensor(inputy); auto inputYGpu = gCuda->cloneTensor(inputY);
auto op = gCuda->addOp<WhereObj>(inputxGpu, inputyGpu, conditionGpu, auto op = gCuda->addOp<WhereObj>(inputXGpu, inputYGpu, conditionGpu,
nullptr); // WhereObj nullptr); // WhereObj
gCuda->dataMalloc(); gCuda->dataMalloc();
conditionGpu->copyin(conditiondata); conditionGpu->copyin(conditionData);
inputxGpu->copyin(inputxdata); inputXGpu->copyin(inputXData);
inputyGpu->copyin(inputydata); inputYGpu->copyin(inputYData);
cudaRuntime->run(gCuda); cudaRuntime->run(gCuda);
auto oCpu = gCpu->cloneTensor(op->getOutput()); // move data from gpu to cpu auto oCpu = gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
oCpu->printData(); //->printData oCpu->printData(); //->printData
EXPECT_TRUE(oCpu->equalData(ExpectData)); EXPECT_TRUE(oCpu->equalData(ExpectData));
} }