diff --git a/include/cuda/cuda_softmax.h b/include/cuda/cuda_softmax.h
new file mode 100644
index 00000000..671f46f8
--- /dev/null
+++ b/include/cuda/cuda_softmax.h
@@ -0,0 +1,6 @@
+#pragma once
+#include "utils/small_array.h"
+namespace infini {
+void softmax_kernel(int num_blocks, float *input, float *output, int size,
+                    int dimsize, int stride);
+}
diff --git a/include/cuda/softmax.h b/include/cuda/softmax.h
deleted file mode 100644
index 5c0eccf9..00000000
--- a/include/cuda/softmax.h
+++ /dev/null
@@ -1,6 +0,0 @@
-#pragma once
-
-namespace infini {
-void softmax_kernel(int max_threadblock_size, int batch_size, float *x,
-                    float *y, int dim, int stride);
-}
diff --git a/src/kernels/cpu/unary.cc b/src/kernels/cpu/unary.cc
index 8975d7cd..3ea61b41 100644
--- a/src/kernels/cpu/unary.cc
+++ b/src/kernels/cpu/unary.cc
@@ -1,6 +1,7 @@
 #include "operators/unary.h"
 #include "core/constants.h"
 #include "core/kernel.h"
+#include "operators/softmax.h"
 
 namespace infini {
 template <typename T> class NativeUnary : public CpuKernelWithoutConfig {
@@ -22,7 +23,7 @@ template <typename T> class NativeUnary : public CpuKernelWithoutConfig {
 template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *context) const override {
-        auto op = as<UnaryObj>(_op);
+        auto op = as<SoftmaxObj>(_op);
         T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
         T *outptr = op->getOutput()->getRawDataPtr<T *>();
diff --git a/src/kernels/cuda/softmax.cc b/src/kernels/cuda/softmax.cc
index 437ed849..024288c2 100644
--- a/src/kernels/cuda/softmax.cc
+++ b/src/kernels/cuda/softmax.cc
@@ -1,30 +1,30 @@
 #include "operators/softmax.h"
 #include "cuda/cuda_kernel_wihtout_config.h"
 #include "cuda/cuda_runtime.h"
-#include "cuda/softmax.h"
+#include "cuda/cuda_softmax.h"
 
 namespace infini {
 
-class SoftmaxCudnn : public CudaKernelWithoutConfig {
+class SoftmaxCuda : public CudaKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         auto op = as<SoftmaxObj>(_op);
-        auto x = op->getInputs(0)->getRawDataPtr<float *>();
-        auto y = op->getOutput(0)->getRawDataPtr<float *>();
+        auto input = op->getInputs(0)->getRawDataPtr<void *>();
+        auto output = op->getOutput(0)->getRawDataPtr<void *>();
+        const auto &inShape = op->getInputs(0)->getDims(); // input shape
         auto dims = op->getInputs(0)->getDims();
-        int batch_size = 1;
-        for (size_t i = 0; i < dims.size(); ++i)
-            batch_size *= dims[i];
-        int dim = dims[op->getAxis()];
+        int size; // total element count I*J*K*S
+        size = op->getOutput(0)->size();
+        int dimsize = dims[op->getAxis()];
+        int stride = op->getInputs(0)->getStride().at(op->getAxis());
 
-        int block_num = batch_size / dim;
-        int max_threadblock_size = batch_size / block_num;
-        softmax_kernel(max_threadblock_size, block_num, x, y, dim,
-                       op->getInputs(0)->getStride().at(op->getAxis()));
+        int num_blocks = size / dimsize;
+        softmax_kernel(num_blocks, (float *)input, (float *)output, size,
+                       dimsize, stride);
     }
 };
 
-REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, SoftmaxCudnn,
+REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, SoftmaxCuda,
                 "Softmax_CUDA_Float32");
 } // namespace infini
diff --git a/src/kernels/cuda/softmax.cu b/src/kernels/cuda/softmax.cu
index 1f7f39e6..7e85ec43 100644
--- a/src/kernels/cuda/softmax.cu
+++ b/src/kernels/cuda/softmax.cu
@@ -1,77 +1,183 @@
 #include "cuda/cuda_common.h"
-#include "cuda/softmax.h"
 #include <cub/cub.cuh>
 
-struct __align__(8) MD {
-    float data;
-    float d;
+struct __align__(8) DataMaxSum { // running max and sum for the online
+                                 // softmax, kept in max_tmp and sum_tmp
+    float max_tmp; // store max
+    float sum_tmp; // store sum
+};
+__device__ __forceinline__ DataMaxSum reduce_dms_op(DataMaxSum a,
+                                                    DataMaxSum b) {
+    bool a_bigger = (a.max_tmp > b.max_tmp);
+    DataMaxSum bigger = a_bigger ? a : b;
+    DataMaxSum smaller = a_bigger ? b : a;
+    bigger.sum_tmp = bigger.sum_tmp +
+                     smaller.sum_tmp * __expf(smaller.max_tmp - bigger.max_tmp);
+
+    return bigger;
+}
+template <int BLOCK_DIM>
+__launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel(
+    float *__restrict input, float *__restrict output, int size, int dimsize,
+    int stride) { // e.g. with axis = 1 and inputShape = [I,J,K,S]:
+    // a flat element index decomposes as i(JKS) + j(KS) + k(S) + s
+
+    // gridDim.x = size/dimsize = IKS
+    // blockIdx.x = i(KS) + k(S) + s, blockIdx.x % stride = k(S) + s
+
+    int tid =
+        blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) *
+                                  dimsize; // now, tid = i(JKS) + k(S) + s
+
+    DataMaxSum dms_partial;
+    dms_partial.max_tmp = -__FLT_MAX__;
+    dms_partial.sum_tmp = 0.0f;
+    DataMaxSum dms_input;
+    for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
+
+        dms_input.max_tmp =
+            input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
+
+        dms_input.sum_tmp = 1.0f;
+        dms_partial = reduce_dms_op(dms_partial,
+                                    dms_input); // fold into this thread's partial
+    }
+    typedef cub::BlockReduce<DataMaxSum, BLOCK_DIM> BlockReduce;
+    __shared__ typename BlockReduce::TempStorage temp_storage;
+    __shared__ DataMaxSum dms_total;
+    DataMaxSum dms_block =
+        BlockReduce(temp_storage).Reduce(dms_partial, reduce_dms_op);
+    if (threadIdx.x ==
+        0) { // only thread 0 writes the block result to shared memory
+        dms_total = dms_block;
+    }
+    __syncthreads();
+    //-----------------
+
+    for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
+        output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
+            __expf(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
+                   dms_total.max_tmp) *
+            __fdividef(1.0F, dms_total.sum_tmp);
+    }
+}
+
+template <typename T> struct SumOp {
+    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+        return a + b;
+    }
 };
-__device__ __forceinline__ MD reduce_md_op(MD a, MD b) {
-    bool a_bigger = (a.data > b.data);
-    MD bigger_m = a_bigger ? a : b;
-    MD smaller_m = a_bigger ? b : a;
-    MD res;
-    res.d = bigger_m.d + smaller_m.d * __expf(smaller_m.data - bigger_m.data);
-    res.data = bigger_m.data;
-    return res;
-}
-
-template <int THREADBLOCK_SIZE>
-__launch_bounds__(THREADBLOCK_SIZE) __global__
-    void online_softmax(const float *__restrict in, float *__restrict out,
-                        int dimSize, int stride) {
-
-    // reposition in and out to data for the current vector
-    int blockOffset = blockIdx.x;
-    if (blockIdx.x >= stride) {
-        int tmp = blockIdx.x % stride;
-        blockOffset = tmp + (blockIdx.x - tmp) * dimSize;
+template <typename T> struct MaxOp {
+    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+        return max(a, b);
     }
-    in += blockOffset;
-    out += blockOffset;
-
-    MD md_partial;
-    md_partial.data = -FLT_MAX;
-    md_partial.d = 0.0F;
-
-    for (int elem_id = threadIdx.x; elem_id < dimSize;
-         elem_id += THREADBLOCK_SIZE) {
-        MD new_elem;
-        new_elem.data = in[elem_id * stride];
-        new_elem.d = 1.0F;
-        md_partial = reduce_md_op(md_partial, new_elem);
+};
+template
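
A note on the reduction above: reduce_dms_op implements the "online softmax" merge. Each partial carries a running (max, sum-of-exponentials) pair, and combining two partials keeps the larger max while rescaling the other partial's sum by __expf(smaller_max - bigger_max), so the result matches a two-pass max-then-sum reduction. Below is a minimal host-side check of that merge rule; it is illustrative only (merge() and the test values are not part of this patch):

    #include <cassert>
    #include <cmath>
    #include <vector>

    struct DataMaxSum {
        float max_tmp; // running max
        float sum_tmp; // running sum of exp(x - max_tmp)
    };

    // Same rule as the device-side reduce_dms_op.
    static DataMaxSum merge(DataMaxSum a, DataMaxSum b) {
        bool a_bigger = (a.max_tmp > b.max_tmp);
        DataMaxSum bigger = a_bigger ? a : b;
        DataMaxSum smaller = a_bigger ? b : a;
        bigger.sum_tmp +=
            smaller.sum_tmp * std::exp(smaller.max_tmp - bigger.max_tmp);
        return bigger;
    }

    int main() {
        std::vector<float> x = {1.0f, -2.5f, 3.0f, 0.5f};

        // Single pass: fold each element in as the pair (value, 1).
        DataMaxSum acc = {-HUGE_VALF, 0.0f};
        for (float v : x)
            acc = merge(acc, DataMaxSum{v, 1.0f});

        // Two-pass reference: global max first, then the shifted sum.
        float m = -HUGE_VALF;
        for (float v : x)
            m = std::fmax(m, v);
        float s = 0.0f;
        for (float v : x)
            s += std::exp(v - m);

        assert(std::fabs(acc.max_tmp - m) < 1e-6f);
        assert(std::fabs(acc.sum_tmp - s) < 1e-5f);
        return 0;
    }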
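The index arithmetic in _blockSoftmaxKernel is worth spelling out. One thread block handles one softmax vector; blocks enumerate the size/dimsize = IKS vectors, and the kernel converts the block index b into the flat offset of that vector's first element via tid = b % stride + (b - b % stride) * dimsize, after which element j of the vector sits at tid + j * stride. A small host-side check of this mapping for axis = 1 (again illustrative only; the shape constants are arbitrary):

    #include <cassert>

    int main() {
        // inputShape = [I, J, K, S], softmax over axis 1:
        // dimsize = J, stride = K*S.
        const int I = 2, J = 5, K = 3, S = 4;
        const int dimsize = J, stride = K * S;

        for (int i = 0; i < I; ++i)
            for (int k = 0; k < K; ++k)
                for (int s = 0; s < S; ++s) {
                    int b = i * K * S + k * S + s; // block index: i(KS) + k(S) + s
                    int tid = b % stride + (b - b % stride) * dimsize;
                    // Expected row-major offset of element (i, j=0, k, s).
                    assert(tid == i * J * K * S + k * S + s);
                }
        return 0;
    }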
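The remainder of softmax.cu is truncated above, so the definition matching the new softmax_kernel declaration is not shown. A plausible shape for it, placed after _blockSoftmaxKernel in the same file, is a thin launcher assigning one BLOCK_DIM-thread block per vector; this is a guess at the missing code, not the patch's actual implementation, which may instead dispatch to a warp-level kernel (note the SumOp/MaxOp helpers it introduces) when dimsize is small:

    namespace infini {
    void softmax_kernel(int num_blocks, float *input, float *output, int size,
                        int dimsize, int stride) {
        // num_blocks = size / dimsize: one thread block per softmax vector.
        _blockSoftmaxKernel<1024>
            <<<num_blocks, 1024>>>(input, output, size, dimsize, stride);
    }
    } // namespace infini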