forked from jiuyuan/InfiniTensor
Cuda softmax (#129)
* "add softmax.cu,.cc,.h" * Modify cuda softmax * "modified the introduction of softmax.cu" * "add format of cuda_softmax.h" * "modified where.cc(.cu,.h) and softmax.cu" * "modified format" * Fix cpu softmax kernel * "modified the // introduction of softmax.cu" * "modified softmax.cu and use 1D block" * "modified softmax.cu,format, and use 1D block" * "introduce share mem to speed softmax" * "reduce the input of function" * modified the format * remodify 2D block softmax * remodify 1D block softmax * modified the share memory * add warp reduce * conflict solve two * remove extra space line * solve comment --------- Co-authored-by: Haojie Wang <haojie0429@gmail.com> Co-authored-by: panzezhong <panzezhong@qiyuanlab.com>
This commit is contained in:
parent
1a6fccccbe
commit
d3e7543291
|
@ -0,0 +1,6 @@
|
||||||
|
#pragma once
|
||||||
|
#include "utils/small_array.h"
|
||||||
|
namespace infini {
|
||||||
|
void softmax_kernel(int num_blocks, float *input, float *output, int size,
|
||||||
|
int dimsize, int stride);
|
||||||
|
}
|
|
@ -1,6 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
namespace infini {
|
|
||||||
void softmax_kernel(int max_threadblock_size, int batch_size, float *x,
|
|
||||||
float *y, int dim, int stride);
|
|
||||||
}
|
|
|
@ -1,6 +1,7 @@
|
||||||
#include "operators/unary.h"
|
#include "operators/unary.h"
|
||||||
#include "core/constants.h"
|
#include "core/constants.h"
|
||||||
#include "core/kernel.h"
|
#include "core/kernel.h"
|
||||||
|
#include "operators/softmax.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
template <typename T> class NativeUnary : public CpuKernelWithoutConfig {
|
template <typename T> class NativeUnary : public CpuKernelWithoutConfig {
|
||||||
|
@ -22,7 +23,7 @@ template <typename T> class NativeUnary : public CpuKernelWithoutConfig {
|
||||||
template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig {
|
template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig {
|
||||||
void compute(const Operator &_op,
|
void compute(const Operator &_op,
|
||||||
const RuntimeObj *context) const override {
|
const RuntimeObj *context) const override {
|
||||||
auto op = as<UnaryObj>(_op);
|
auto op = as<SoftmaxObj>(_op);
|
||||||
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
|
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||||
T *outptr = op->getOutput()->getRawDataPtr<T *>();
|
T *outptr = op->getOutput()->getRawDataPtr<T *>();
|
||||||
|
|
||||||
|
|
|
@ -1,30 +1,30 @@
|
||||||
#include "operators/softmax.h"
|
#include "operators/softmax.h"
|
||||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||||
#include "cuda/cuda_runtime.h"
|
#include "cuda/cuda_runtime.h"
|
||||||
#include "cuda/softmax.h"
|
#include "cuda/cuda_softmax.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
class SoftmaxCudnn : public CudaKernelWithoutConfig {
|
class SoftmaxCuda : public CudaKernelWithoutConfig {
|
||||||
|
|
||||||
void compute(const Operator &_op,
|
void compute(const Operator &_op,
|
||||||
const RuntimeObj *_context) const override {
|
const RuntimeObj *_context) const override {
|
||||||
auto op = as<SoftmaxObj>(_op);
|
auto op = as<SoftmaxObj>(_op);
|
||||||
auto x = op->getInputs(0)->getRawDataPtr<float *>();
|
auto input = op->getInputs(0)->getRawDataPtr<float *>();
|
||||||
auto y = op->getOutput(0)->getRawDataPtr<float *>();
|
auto output = op->getOutput(0)->getRawDataPtr<float *>();
|
||||||
|
const auto &inShape = op->getInputs(0)->getDims(); // input shape
|
||||||
auto dims = op->getInputs(0)->getDims();
|
auto dims = op->getInputs(0)->getDims();
|
||||||
|
|
||||||
int batch_size = 1;
|
int size; // size = i(JKS) + j(KS) + k(S) + s
|
||||||
for (size_t i = 0; i < dims.size(); ++i)
|
size = op->getOutput(0)->size();
|
||||||
batch_size *= dims[i];
|
int dimsize = dims[op->getAxis()];
|
||||||
int dim = dims[op->getAxis()];
|
int stride = op->getInputs(0)->getStride().at(op->getAxis());
|
||||||
|
|
||||||
int block_num = batch_size / dim;
|
int num_blocks = size / dimsize;
|
||||||
int max_threadblock_size = batch_size / block_num;
|
softmax_kernel(num_blocks, (float *)input, (float *)output, size,
|
||||||
softmax_kernel(max_threadblock_size, block_num, x, y, dim,
|
dimsize, stride);
|
||||||
op->getInputs(0)->getStride().at(op->getAxis()));
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, SoftmaxCudnn,
|
REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, SoftmaxCuda,
|
||||||
"Softmax_CUDA_Float32");
|
"Softmax_CUDA_Float32");
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -1,77 +1,183 @@
|
||||||
#include "cuda/cuda_common.h"
|
#include "cuda/cuda_common.h"
|
||||||
#include "cuda/softmax.h"
|
|
||||||
#include <cub/cub.cuh>
|
#include <cub/cub.cuh>
|
||||||
|
|
||||||
struct __align__(8) MD {
|
struct __align__(8) DataMaxSum { // update the global max and sum, store the
|
||||||
float data;
|
// output at max_tmp and sum_tmp
|
||||||
float d;
|
float max_tmp; // store max
|
||||||
|
float sum_tmp; // store sum
|
||||||
|
};
|
||||||
|
__device__ __forceinline__ DataMaxSum reduce_dms_op(DataMaxSum a,
|
||||||
|
DataMaxSum b) {
|
||||||
|
bool a_bigger = (a.max_tmp > b.max_tmp);
|
||||||
|
DataMaxSum bigger = a_bigger ? a : b;
|
||||||
|
DataMaxSum smaller = a_bigger ? b : a;
|
||||||
|
bigger.sum_tmp = bigger.sum_tmp +
|
||||||
|
smaller.sum_tmp * __expf(smaller.max_tmp - bigger.max_tmp);
|
||||||
|
|
||||||
|
return bigger;
|
||||||
|
}
|
||||||
|
template <int BLOCK_DIM>
|
||||||
|
__launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel(
|
||||||
|
float *__restrict input, float *__restrict output, int size, int dimsize,
|
||||||
|
int stride) { // if set axis = 1, inputShape=[I,J,K,S]
|
||||||
|
// tid = i(JKS) + j(KS) + k(S) + s
|
||||||
|
|
||||||
|
// blockDim.x = size/dimsize = IKS
|
||||||
|
// blockIdx.x = i(KS) + k(S) + s,blockIdx.x%stride = k(S) + s
|
||||||
|
|
||||||
|
int tid =
|
||||||
|
blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) *
|
||||||
|
dimsize; // now, tid = i(JKS) + k(S) + s;
|
||||||
|
|
||||||
|
DataMaxSum dms_partial;
|
||||||
|
dms_partial.max_tmp = -__FLT_MAX__;
|
||||||
|
dms_partial.sum_tmp = 0.0f;
|
||||||
|
DataMaxSum dms_input;
|
||||||
|
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||||
|
|
||||||
|
dms_input.max_tmp =
|
||||||
|
input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
|
||||||
|
|
||||||
|
dms_input.sum_tmp = 1.0f;
|
||||||
|
dms_partial = reduce_dms_op(dms_partial,
|
||||||
|
dms_input); // reduce the data to one block
|
||||||
|
}
|
||||||
|
typedef cub::BlockReduce<DataMaxSum, BLOCK_DIM> BlockReduce;
|
||||||
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||||
|
__shared__ DataMaxSum dms_total;
|
||||||
|
DataMaxSum dms_block =
|
||||||
|
BlockReduce(temp_storage).Reduce(dms_partial, reduce_dms_op);
|
||||||
|
if (threadIdx.x ==
|
||||||
|
0) { // must set threadIdx.x = 0 write the output to memory
|
||||||
|
dms_total = dms_block;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
//-----------------
|
||||||
|
|
||||||
|
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||||
|
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
|
||||||
|
__expf(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
|
||||||
|
dms_total.max_tmp) *
|
||||||
|
__fdividef(1.0F, dms_total.sum_tmp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T> struct SumOp {
|
||||||
|
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
__device__ __forceinline__ MD reduce_md_op(MD a, MD b) {
|
template <typename T> struct MaxOp {
|
||||||
bool a_bigger = (a.data > b.data);
|
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
|
||||||
MD bigger_m = a_bigger ? a : b;
|
return max(a, b);
|
||||||
MD smaller_m = a_bigger ? b : a;
|
}
|
||||||
MD res;
|
};
|
||||||
res.d = bigger_m.d + smaller_m.d * __expf(smaller_m.data - bigger_m.data);
|
template <template <typename> class ReductionOp, typename T,
|
||||||
res.data = bigger_m.data;
|
int thread_group_width>
|
||||||
return res;
|
__inline__ __device__ T WarpAllReduce(T val) {
|
||||||
|
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
|
||||||
|
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
|
||||||
|
}
|
||||||
|
return val;
|
||||||
}
|
}
|
||||||
|
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
|
||||||
|
__global__ void _warpSoftmaxKernel(float *__restrict input,
|
||||||
|
float *__restrict output, int size,
|
||||||
|
int dimsize, int stride) {
|
||||||
|
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||||
|
int otherSize = size / dimsize;
|
||||||
|
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
|
||||||
|
|
||||||
template <int THREADBLOCK_SIZE>
|
if (otherIdx < otherSize) {
|
||||||
__launch_bounds__(THREADBLOCK_SIZE) __global__
|
|
||||||
void online_softmax(const float *__restrict in, float *__restrict out,
|
|
||||||
int dimSize, int stride) {
|
|
||||||
|
|
||||||
// reposition in and out to data for the current vector
|
__shared__ float max_total[BLOCK_DIM_y];
|
||||||
int blockOffset = blockIdx.x;
|
__shared__ float sum_total[BLOCK_DIM_y];
|
||||||
if (blockIdx.x >= stride) {
|
float max_data = -__FLT_MAX__;
|
||||||
int tmp = blockIdx.x % stride;
|
|
||||||
blockOffset = tmp + (blockIdx.x - tmp) * dimSize;
|
|
||||||
}
|
|
||||||
in += blockOffset;
|
|
||||||
out += blockOffset;
|
|
||||||
|
|
||||||
MD md_partial;
|
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||||
md_partial.data = -FLT_MAX;
|
max_data =
|
||||||
md_partial.d = 0.0F;
|
max(max_data,
|
||||||
|
input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride]);
|
||||||
for (int elem_id = threadIdx.x; elem_id < dimSize;
|
|
||||||
elem_id += THREADBLOCK_SIZE) {
|
|
||||||
MD new_elem;
|
|
||||||
new_elem.data = in[elem_id * stride];
|
|
||||||
new_elem.d = 1.0F;
|
|
||||||
md_partial = reduce_md_op(md_partial, new_elem);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// blockreduce for THREADBLOCK_SIZE threads.
|
max_data = WarpAllReduce<MaxOp, float, BLOCK_DIM_x>(max_data);
|
||||||
// The actrual threads num used in the block is "dimsSize"
|
|
||||||
typedef cub::BlockReduce<MD, THREADBLOCK_SIZE> BlockReduce;
|
|
||||||
|
|
||||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
|
||||||
__shared__ MD md_total;
|
|
||||||
|
|
||||||
MD md = BlockReduce(temp_storage).Reduce(md_partial, reduce_md_op);
|
|
||||||
if (threadIdx.x == 0)
|
if (threadIdx.x == 0)
|
||||||
md_total = md;
|
max_total[threadIdx.y] = max_data;
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
float d_total_inverse = __fdividef(1.0F, md_total.d);
|
//--------------------------------------------
|
||||||
for (int elem_id = threadIdx.x; elem_id < dimSize;
|
float sum_data = 0.0f;
|
||||||
elem_id += THREADBLOCK_SIZE)
|
|
||||||
out[elem_id * stride] =
|
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||||
__expf(in[elem_id * stride] - md_total.data) * d_total_inverse;
|
sum_data +=
|
||||||
|
__expf(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||||
|
max_total[threadIdx.y]);
|
||||||
|
}
|
||||||
|
|
||||||
|
sum_data = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sum_data);
|
||||||
|
|
||||||
|
if (threadIdx.x == 0)
|
||||||
|
sum_total[threadIdx.y] = sum_data;
|
||||||
|
|
||||||
|
//--------------------------------------------
|
||||||
|
|
||||||
|
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||||
|
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
|
||||||
|
__expf(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||||
|
max_total[threadIdx.y]) *
|
||||||
|
__fdividef(1.0F, sum_total[threadIdx.y]);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
//-----------------
|
||||||
|
|
||||||
|
//-----------------
|
||||||
namespace infini {
|
namespace infini {
|
||||||
void softmax_kernel(int max_threadblock_size, int blockNum, float *in,
|
void softmax_kernel(int num_blocks, float *input, float *output, int size,
|
||||||
float *out, int dimSize, int stride) {
|
int dimsize, int stride) {
|
||||||
if (max_threadblock_size >= 255)
|
|
||||||
online_softmax<256><<<blockNum, 256>>>(in, out, dimSize, stride);
|
if (dimsize > 1024) {
|
||||||
else if (max_threadblock_size >= 128)
|
|
||||||
online_softmax<128><<<blockNum, 128>>>(in, out, dimSize, stride);
|
int BLOCK_DIM = 1024;
|
||||||
else if (max_threadblock_size >= 64)
|
_blockSoftmaxKernel<1024>
|
||||||
online_softmax<64><<<blockNum, 64>>>(in, out, dimSize, stride);
|
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
|
||||||
else
|
} else if (dimsize > 31) {
|
||||||
online_softmax<32><<<blockNum, 32>>>(in, out, dimSize, stride);
|
int BLOCK_DIM_x = 32;
|
||||||
|
int BLOCK_DIM_y = 32;
|
||||||
|
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
_warpSoftmaxKernel<32, 32>
|
||||||
|
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
|
||||||
|
} else if (dimsize > 15) {
|
||||||
|
int BLOCK_DIM_x = 16;
|
||||||
|
int BLOCK_DIM_y = 64;
|
||||||
|
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
_warpSoftmaxKernel<16, 64>
|
||||||
|
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
|
||||||
|
} else if (dimsize > 7) {
|
||||||
|
int BLOCK_DIM_x = 8;
|
||||||
|
int BLOCK_DIM_y = 128;
|
||||||
|
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
_warpSoftmaxKernel<8, 128>
|
||||||
|
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
|
||||||
|
} else {
|
||||||
|
int BLOCK_DIM_x = 4;
|
||||||
|
int BLOCK_DIM_y = 256;
|
||||||
|
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||||
|
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||||
|
dim3 grid_dim(num_block_x, 1, 1);
|
||||||
|
|
||||||
|
_warpSoftmaxKernel<4, 256>
|
||||||
|
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -8,38 +8,38 @@
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
void test_where(const Shape &inputxshape, const vector<float> &inputxdata,
|
void test_where(const Shape &inputXShape, const vector<float> &inputXData,
|
||||||
const Shape &inputyshape, const vector<float> &inputydata,
|
const Shape &inputYShape, const vector<float> &inputYData,
|
||||||
const Shape &conditionshape,
|
const Shape &conditionShape,
|
||||||
const vector<uint8_t> &conditiondata,
|
const vector<uint8_t> &conditionData,
|
||||||
const vector<float> &ExpectData) {
|
const vector<float> &ExpectData) {
|
||||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||||
auto condition = gCpu->addTensor(conditionshape, DataType::UInt8);
|
auto condition = gCpu->addTensor(conditionShape, DataType::UInt8);
|
||||||
auto inputx = gCpu->addTensor(inputxshape, DataType::Float32);
|
auto inputX = gCpu->addTensor(inputXShape, DataType::Float32);
|
||||||
auto inputy = gCpu->addTensor(inputyshape, DataType::Float32);
|
auto inputY = gCpu->addTensor(inputYShape, DataType::Float32);
|
||||||
|
|
||||||
gCpu->dataMalloc();
|
gCpu->dataMalloc();
|
||||||
condition->copyin(conditiondata); //
|
condition->copyin(conditionData); //
|
||||||
inputx->copyin(inputxdata);
|
inputX->copyin(inputXData);
|
||||||
inputy->copyin(inputydata); //
|
inputY->copyin(inputYData); //
|
||||||
|
|
||||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
|
||||||
auto conditionGpu = gCuda->cloneTensor(condition);
|
auto conditionGpu = gCuda->cloneTensor(condition);
|
||||||
auto inputxGpu = gCuda->cloneTensor(inputx);
|
auto inputXGpu = gCuda->cloneTensor(inputX);
|
||||||
auto inputyGpu = gCuda->cloneTensor(inputy);
|
auto inputYGpu = gCuda->cloneTensor(inputY);
|
||||||
|
|
||||||
auto op = gCuda->addOp<WhereObj>(inputxGpu, inputyGpu, conditionGpu,
|
auto op = gCuda->addOp<WhereObj>(inputXGpu, inputYGpu, conditionGpu,
|
||||||
nullptr); // WhereObj
|
nullptr); // WhereObj
|
||||||
gCuda->dataMalloc();
|
gCuda->dataMalloc();
|
||||||
conditionGpu->copyin(conditiondata);
|
conditionGpu->copyin(conditionData);
|
||||||
inputxGpu->copyin(inputxdata);
|
inputXGpu->copyin(inputXData);
|
||||||
inputyGpu->copyin(inputydata);
|
inputYGpu->copyin(inputYData);
|
||||||
cudaRuntime->run(gCuda);
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
auto oCpu = gCpu->cloneTensor(op->getOutput()); // move data from gpu to cpu
|
auto oCpu = gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
|
||||||
oCpu->printData(); //->printData
|
oCpu->printData(); //->printData
|
||||||
EXPECT_TRUE(oCpu->equalData(ExpectData));
|
EXPECT_TRUE(oCpu->equalData(ExpectData));
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue