diff --git a/include/cuda/cuda_layernorm.h b/include/cuda/cuda_layernorm.h
index 997c8a06..b6829d09 100644
--- a/include/cuda/cuda_layernorm.h
+++ b/include/cuda/cuda_layernorm.h
@@ -8,4 +8,10 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
 void LaynormKernel(const float *input, const float *scale, const float eps,
                    int size, int scaleSize, const int dimsize, const int stride,
                    float *output);
+void LaynormKernel(const half *input, const half *scale, const half eps,
+                   int size, int scaleSize, const int dimsize, const int stride,
+                   half *output, const half *bias, int biasSize);
+void LaynormKernel(const half *input, const half *scale, const half eps,
+                   int size, int scaleSize, const int dimsize, const int stride,
+                   half *output);
 }; // namespace infini
diff --git a/src/kernels/cuda/layer_norm.cc b/src/kernels/cuda/layer_norm.cc
index a42ac36a..5c1c76c3 100644
--- a/src/kernels/cuda/layer_norm.cc
+++ b/src/kernels/cuda/layer_norm.cc
@@ -24,17 +24,34 @@ class LayerNormCuda : public CudaKernelWithoutConfig {
         int dimsize = dims[op->getAxis()];
         int size = op->getOutput(0)->size();
         int scaleSize = op->getInputs(1)->size();
-        if (op->numInputs() == 3) {
-            void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
-            int biasSize = op->getInputs(2)->size();
-            // printf("kernel bias:true:%d\n", 1);
-            LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
-                          scaleSize, dimsize, stride, (float *)outputData,
-                          (float *)biasData, biasSize);
-        } else {
-            // printf("kernel bias:false:%d\n", 0);
-            LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
-                          scaleSize, dimsize, stride, (float *)outputData);
+        if (op->getDType() == DataType::Float32) {
+            if (op->numInputs() == 3) {
+                void *const biasData =
+                    (op->getInputs(2)->getRawDataPtr<void *>());
+                int biasSize = op->getInputs(2)->size();
+                // printf("kernel bias:true:%d\n", 1);
+                LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
+                              scaleSize, dimsize, stride, (float *)outputData,
+                              (float *)biasData, biasSize);
+            } else {
+                // printf("kernel bias:false:%d\n", 0);
+                LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
+                              scaleSize, dimsize, stride, (float *)outputData);
+            }
+        } else if (op->getDType() == DataType::Float16) {
+            if (op->numInputs() == 3) {
+                void *const biasData =
+                    (op->getInputs(2)->getRawDataPtr<void *>());
+                int biasSize = op->getInputs(2)->size();
+                // printf("kernel bias:true:%d\n", 1);
+                LaynormKernel((half *)inputData, (half *)scaleData, eps, size,
+                              scaleSize, dimsize, stride, (half *)outputData,
+                              (half *)biasData, biasSize);
+            } else {
+                // printf("kernel bias:false:%d\n", 0);
+                LaynormKernel((half *)inputData, (half *)scaleData, eps, size,
+                              scaleSize, dimsize, stride, (half *)outputData);
+            }
         }
     }
 };
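The header above now exposes half-precision entry points. What follows is a minimal host-side sketch, not part of the patch, of how the new bias-free half overload could be driven directly; the helper name runHalfLayernormExample, the [rows, cols] shape, and the reading of size/dimsize/stride (total element count, length of the normalized axis, and that axis's stride, mirroring the float launcher) are illustrative assumptions rather than anything the PR defines.

// Sketch only: exercises the half overload of LaynormKernel declared in
// include/cuda/cuda_layernorm.h for a row-wise LayerNorm over the last axis.
#include "cuda/cuda_layernorm.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <vector>

void runHalfLayernormExample() {
    const int rows = 4, cols = 8;
    const int size = rows * cols; // total number of elements
    const int dimsize = cols;     // length of the normalized (last) axis
    const int stride = 1;         // contiguous last axis
    const int scaleSize = cols;

    std::vector<half> hostIn(size), hostScale(cols), hostOut(size);
    for (int i = 0; i < size; ++i)
        hostIn[i] = __float2half(float(i % cols));
    for (int i = 0; i < cols; ++i)
        hostScale[i] = __float2half(1.0f);

    half *dIn, *dScale, *dOut;
    cudaMalloc(&dIn, size * sizeof(half));
    cudaMalloc(&dScale, cols * sizeof(half));
    cudaMalloc(&dOut, size * sizeof(half));
    cudaMemcpy(dIn, hostIn.data(), size * sizeof(half), cudaMemcpyHostToDevice);
    cudaMemcpy(dScale, hostScale.data(), cols * sizeof(half),
               cudaMemcpyHostToDevice);

    // Bias-free half overload added by this patch.
    infini::LaynormKernel(dIn, dScale, __float2half(1e-5f), size, scaleSize,
                          dimsize, stride, dOut);

    cudaMemcpy(hostOut.data(), dOut, size * sizeof(half),
               cudaMemcpyDeviceToHost);
    cudaFree(dIn);
    cudaFree(dScale);
    cudaFree(dOut);
}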
diff --git a/src/kernels/cuda/layer_norm.cu b/src/kernels/cuda/layer_norm.cu
index c5e6e492..26f06e28 100644
--- a/src/kernels/cuda/layer_norm.cu
+++ b/src/kernels/cuda/layer_norm.cu
@@ -1,43 +1,41 @@
 #include "cuda/cuda_common.h"
 #include <cub/cub.cuh>
-template <int BLOCK_DIM>
+template <typename T, int BLOCK_DIM>
 __launch_bounds__(BLOCK_DIM) __global__
-    void blockLaynormKernel(const float *input, const float *scale,
-                            const int dimsize, const int stride, float *output,
-                            const float eps, int scaleSize, const float *bias,
-                            int biasSize) {
+    void blockLaynormKernel(const T *input, const T *scale, const int dimsize,
+                            const int stride, T *output, const T eps,
+                            int scaleSize, const T *bias, int biasSize) {
     // len(scale) = len(bias) = dimsize
     int tmp = blockIdx.x % stride;
     int tid = (blockIdx.x - tmp) * dimsize + tmp;
-    float muPartial = 0.0f;
+    T muPartial = 0.0f;
     for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
         muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
     }
-    typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
+    typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
     __shared__ typename BlockReduce::TempStorage temp_storage;
-    __shared__ float mu;
-    float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
+    __shared__ T mu;
+    T muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
     if (threadIdx.x ==
         0) { // must set threadIdx.x = 0 write the output to memory
-        mu = muBlock / dimsize;
+        mu = muBlock * static_cast<T>(__fdividef(1.0F, dimsize));
     }
     __syncthreads();
-    float sigma2Partial = 0.0f;
+    T sigma2Partial = 0.0f;
     for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
         sigma2Partial +=
             (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
             (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
     }
-    typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
+    typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
 
-    __shared__ float sigma2;
-    float sigma2Block =
-        BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
+    __shared__ T sigma2;
+    T sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
     if (threadIdx.x ==
         0) { // must set threadIdx.x = 0 write the output to memory
-        sigma2 = sigma2Block / dimsize;
+        sigma2 = sigma2Block * static_cast<T>(__fdividef(1.0F, dimsize));
     }
     __syncthreads();
     if (biasSize == dimsize) {
@@ -47,8 +45,9 @@ __launch_bounds__(BLOCK_DIM) __global__
             output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
                 scale[threadIdx.x + ph * BLOCK_DIM] *
                     (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
-                     mu) /
-                    sqrt(sigma2 + eps) +
+                     mu) *
+                    static_cast<T>(__fdividef(
+                        1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
                 bias[threadIdx.x + ph * BLOCK_DIM];
         }
     } else {
@@ -57,8 +56,9 @@ __launch_bounds__(BLOCK_DIM) __global__
             output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
                 scale[0] *
                     (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
-                     mu) /
-                    sqrt(sigma2 + eps) +
+                     mu) *
+                    static_cast<T>(__fdividef(
+                        1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
                 bias[threadIdx.x + ph * BLOCK_DIM];
         }
     }
@@ -69,8 +69,9 @@ __launch_bounds__(BLOCK_DIM) __global__
             output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
                 scale[threadIdx.x + ph * BLOCK_DIM] *
                     (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
-                     mu) /
-                    sqrt(sigma2 + eps) +
+                     mu) *
+                    static_cast<T>(__fdividef(
+                        1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
                 bias[0];
         }
     } else {
@@ -79,50 +80,50 @@ __launch_bounds__(BLOCK_DIM) __global__
             output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
                 scale[0] *
                     (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
-                     mu) /
-                    sqrt(sigma2 + eps) +
+                     mu) *
+                    static_cast<T>(__fdividef(
+                        1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
                 bias[0];
             }
         }
     }
 }
 //-----------------
-template <int BLOCK_DIM>
+template <typename T, int BLOCK_DIM>
 __launch_bounds__(BLOCK_DIM) __global__
-    void blockLaynormKernel(const float *input, const float *scale,
-                            const int dimsize, const int stride, float *output,
-                            const float eps, int scaleSize) {
+    void blockLaynormKernel(const T *input, const T *scale, const int dimsize,
+                            const int stride, T *output, const T eps,
+                            int scaleSize) {
     // len(scale) = len(bias) = dimsize
     int tmp = blockIdx.x % stride;
     int tid = (blockIdx.x - tmp) * dimsize + tmp;
-    float muPartial = 0.0f;
+    T muPartial = 0.0f;
     for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
         muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
     }
-    typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
+    typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
     __shared__ typename BlockReduce::TempStorage temp_storage;
-    __shared__ float mu;
-    float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
+    __shared__ T mu;
+    T muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
     if (threadIdx.x ==
         0) { // must set threadIdx.x = 0 write the output to memory
-        mu = muBlock / dimsize;
+        mu = muBlock * static_cast<T>(__fdividef(1.0F, dimsize));
     }
     __syncthreads();
-    float sigma2Partial = 0.0f;
+    T sigma2Partial = 0.0f;
     for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
         sigma2Partial +=
             (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
             (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
     }
-    typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
+    typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
 
-    __shared__ float sigma2;
-    float sigma2Block =
-        BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
+    __shared__ T sigma2;
+    T sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
     if (threadIdx.x ==
         0) { // must set threadIdx.x = 0 write the output to memory
-        sigma2 = sigma2Block / dimsize;
+        sigma2 = sigma2Block * static_cast<T>(__fdividef(1.0F, dimsize));
     }
     __syncthreads();
     if (scaleSize == dimsize) {
@@ -130,16 +131,18 @@ __launch_bounds__(BLOCK_DIM) __global__
 
             output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
                 scale[threadIdx.x + ph * BLOCK_DIM] *
-                (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) /
-                sqrt(sigma2 + eps);
+                (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
+                static_cast<T>(
+                    __fdividef(1.0F, sqrt(static_cast<float>(sigma2 + eps))));
         }
     } else {
         for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
 
             output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
                 scale[0] *
-                (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) /
-                sqrt(sigma2 + eps);
+                (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
+                static_cast<T>(
+                    __fdividef(1.0F, sqrt(static_cast<float>(sigma2 + eps))));
         }
     }
 }
@@ -158,33 +161,33 @@ __inline__ __device__ T WarpAllReduce(T val) {
     }
     return val;
 }
-template <int BLOCK_DIM_x, int BLOCK_DIM_y>
-__global__ void warpLaynormKernel(const float *input, const float *scale,
+template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
+__global__ void warpLaynormKernel(const T *input, const T *scale,
                                   const int dimsize, const int stride,
-                                  float *output, const float eps, int scaleSize,
-                                  int otherSize, const float *bias,
-                                  int biasSize) {
+                                  T *output, const T eps, int scaleSize,
+                                  int otherSize, const T *bias, int biasSize) {
     int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
     int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
     if (otherIdx < otherSize) {
-        __shared__ float muTotal[BLOCK_DIM_y];
-        __shared__ float sigma2Total[BLOCK_DIM_y];
+        __shared__ T muTotal[BLOCK_DIM_y];
+        __shared__ T sigma2Total[BLOCK_DIM_y];
 
-        float muPartial = 0.0f;
+        T muPartial = 0.0f;
         for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
             muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
         }
 
-        muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial);
+        muPartial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(muPartial);
 
         if (threadIdx.x == 0)
-            muTotal[threadIdx.y] = muPartial / dimsize;
+            muTotal[threadIdx.y] =
+                muPartial * static_cast<T>(__fdividef(1.0F, dimsize));
 
         //--------------------------------------------
-        float sigma2Partial = 0.0f;
+        T sigma2Partial = 0.0f;
         for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
             sigma2Partial +=
@@ -194,10 +197,11 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                 muTotal[threadIdx.y]);
         }
 
-        sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial);
+        sigma2Partial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(sigma2Partial);
 
         if (threadIdx.x == 0)
-            sigma2Total[threadIdx.y] = sigma2Partial / dimsize;
+            sigma2Total[threadIdx.y] =
+                sigma2Partial * static_cast<T>(__fdividef(1.0F, dimsize));
 
         //--------------------------------------------
         if (biasSize == dimsize) {
@@ -209,8 +213,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                     scale[threadIdx.x + ph * BLOCK_DIM_x] *
                         (input[tid +
                                (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
-                         muTotal[threadIdx.y]) /
-                        sqrt(sigma2Total[threadIdx.y] + eps) +
+                         muTotal[threadIdx.y]) *
+                        static_cast<T>(__fdividef(
+                            1.0F, sqrt(static_cast<float>(
+                                      sigma2Total[threadIdx.y] + eps)))) +
                     bias[threadIdx.x + ph * BLOCK_DIM_x];
             }
         } else {
@@ -221,8 +227,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                     scale[0] *
                         (input[tid +
                                (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
-                         muTotal[threadIdx.y]) /
-                        sqrt(sigma2Total[threadIdx.y] + eps) +
+                         muTotal[threadIdx.y]) *
+                        static_cast<T>(__fdividef(
+                            1.0F, sqrt(static_cast<float>(
+                                      sigma2Total[threadIdx.y] + eps)))) +
                     bias[threadIdx.x + ph * BLOCK_DIM_x];
             }
         }
@@ -235,8 +243,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                     scale[threadIdx.x + ph * BLOCK_DIM_x] *
                        (input[tid +
                               (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
-                         muTotal[threadIdx.y]) /
-                        sqrt(sigma2Total[threadIdx.y] + eps) +
+                         muTotal[threadIdx.y]) *
+                        static_cast<T>(__fdividef(
+                            1.0F, sqrt(static_cast<float>(
+                                      sigma2Total[threadIdx.y] + eps)))) +
                     bias[0];
             }
         } else {
@@ -247,40 +257,43 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                     scale[0] *
                         (input[tid +
                                (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
-                         muTotal[threadIdx.y]) /
-                        sqrt(sigma2Total[threadIdx.y] + eps) +
+                         muTotal[threadIdx.y]) *
+                        static_cast<T>(__fdividef(
+                            1.0F, sqrt(static_cast<float>(
+                                      sigma2Total[threadIdx.y] + eps)))) +
                     bias[0];
             }
         }
     }
 }
-template <int BLOCK_DIM_x, int BLOCK_DIM_y>
-__global__ void warpLaynormKernel(const float *input, const float *scale,
+template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
+__global__ void warpLaynormKernel(const T *input, const T *scale,
                                   const int dimsize, const int stride,
-                                  float *output, const float eps, int scaleSize,
+                                  T *output, const T eps, int scaleSize,
                                   int otherSize) {
     int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
     int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
     if (otherIdx < otherSize) {
-        __shared__ float muTotal[BLOCK_DIM_y];
-        __shared__ float sigma2Total[BLOCK_DIM_y];
+        __shared__ T muTotal[BLOCK_DIM_y];
+        __shared__ T sigma2Total[BLOCK_DIM_y];
 
-        float muPartial = 0.0f;
+        T muPartial = 0.0f;
         for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
             muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
         }
 
-        muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial);
+        muPartial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(muPartial);
 
         if (threadIdx.x == 0)
-            muTotal[threadIdx.y] = muPartial / dimsize;
+            muTotal[threadIdx.y] =
+                muPartial * static_cast<T>(__fdividef(1.0F, dimsize));
 
         //--------------------------------------------
-        float sigma2Partial = 0.0f;
+        T sigma2Partial = 0.0f;
         for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
             sigma2Partial +=
@@ -290,10 +303,11 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                 muTotal[threadIdx.y]);
         }
 
-        sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial);
+        sigma2Partial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(sigma2Partial);
 
         if (threadIdx.x == 0)
-            sigma2Total[threadIdx.y] = sigma2Partial / dimsize;
+            sigma2Total[threadIdx.y] =
+                sigma2Partial * static_cast<T>(__fdividef(1.0F, dimsize));
 
         //--------------------------------------------
         if (scaleSize == dimsize) {
@@ -302,8 +316,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                 output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
                     scale[threadIdx.x + ph * BLOCK_DIM_x] *
                     (input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
-                     muTotal[threadIdx.y]) /
-                    sqrt(sigma2Total[threadIdx.y] + eps);
+                     muTotal[threadIdx.y]) *
+                    static_cast<T>(
+                        __fdividef(1.0F, sqrt(static_cast<float>(
+                                             sigma2Total[threadIdx.y] + eps))));
             }
         } else {
             for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
@@ -311,8 +327,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                 output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
                     scale[0] *
                     (input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
-                     muTotal[threadIdx.y]) /
-                    sqrt(sigma2Total[threadIdx.y] + eps);
+                     muTotal[threadIdx.y]) *
+                    static_cast<T>(
+                        __fdividef(1.0F, sqrt(static_cast<float>(
+                                             sigma2Total[threadIdx.y] + eps))));
             }
         }
     }
@@ -325,7 +343,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
     if (dimsize > 1024) {
         int BLOCK_DIM = 1024;
 
-        blockLaynormKernel<1024>
+        blockLaynormKernel<float, 1024>
             <<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
                                        eps, scaleSize, bias, biasSize);
     } else if (dimsize > 31) {
@@ -335,7 +353,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);
 
-        warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block,
             bias, biasSize);
     } else if (dimsize > 15) {
@@ -345,7 +363,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);
 
-        warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block,
             bias, biasSize);
     } else if (dimsize > 7) {
@@ -355,7 +373,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);
 
-        warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block,
             bias, biasSize);
     } else {
@@ -365,7 +383,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);
 
-        warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block,
             bias, biasSize);
     }
@@ -378,7 +396,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
     if (dimsize > 1024) {
         int BLOCK_DIM = 1024;
 
-        blockLaynormKernel<1024><<<num_block, BLOCK_DIM>>>(
+        blockLaynormKernel<float, 1024><<<num_block, BLOCK_DIM>>>(
             input, scale, dimsize, stride, output, eps, scaleSize);
     } else if (dimsize > 31) {
         int BLOCK_DIM_x = 32;
@@ -387,7 +405,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);
 
-        warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     } else if (dimsize > 15) {
         int BLOCK_DIM_x = 16;
@@ -396,7 +414,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);
 
-        warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     } else if (dimsize > 7) {
         int BLOCK_DIM_x = 8;
@@ -405,7 +423,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);
 
-        warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     } else {
         int BLOCK_DIM_x = 4;
@@ -414,7 +432,108 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);
 
-        warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block);
+    }
+}
+//-----------------
+void LaynormKernel(const half *input, const half *scale, const half eps,
+                   int size, int scaleSize, const int dimsize, const int stride,
+                   half *output, const half *bias, int biasSize) {
+    int num_block = size / dimsize;
+    if (dimsize > 1024) {
+        int BLOCK_DIM = 1024;
+
+        blockLaynormKernel<half, 1024>
+            <<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
+                                       eps, scaleSize, bias, biasSize);
+    } else if (dimsize > 31) {
+        int BLOCK_DIM_x = 32;
+        int BLOCK_DIM_y = 32;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block,
+            bias, biasSize);
+    } else if (dimsize > 15) {
+        int BLOCK_DIM_x = 16;
+        int BLOCK_DIM_y = 64;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block,
+            bias, biasSize);
+    } else if (dimsize > 7) {
+        int BLOCK_DIM_x = 8;
+        int BLOCK_DIM_y = 128;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block,
+            bias, biasSize);
+    } else {
+        int BLOCK_DIM_x = 4;
+        int BLOCK_DIM_y = 256;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block,
+            bias, biasSize);
+    }
+}
+
+void LaynormKernel(const half *input, const half *scale, const half eps,
+                   int size, int scaleSize, const int dimsize, const int stride,
+                   half *output) {
+    int num_block = size / dimsize;
+    if (dimsize > 1024) {
+        int BLOCK_DIM = 1024;
+
+        blockLaynormKernel<half, 1024><<<num_block, BLOCK_DIM>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize);
+    } else if (dimsize > 31) {
+        int BLOCK_DIM_x = 32;
+        int BLOCK_DIM_y = 32;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block);
+    } else if (dimsize > 15) {
+        int BLOCK_DIM_x = 16;
+        int BLOCK_DIM_y = 64;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block);
+    } else if (dimsize > 7) {
+        int BLOCK_DIM_x = 8;
+        int BLOCK_DIM_y = 128;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block);
+    } else {
+        int BLOCK_DIM_x = 4;
+        int BLOCK_DIM_y = 256;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     }
 }
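For reference only, not part of the patch: per normalized slice the templated kernels above compute mu = sum(x)/dimsize, sigma2 = sum((x - mu)^2)/dimsize, and y = scale * (x - mu) / sqrt(sigma2 + eps) (+ bias), with the partial sums accumulated in T (so in half precision on the new path) and the reciprocals taken in float via __fdividef before being cast back. The scalar CPU sketch below restates that formula with float accumulation, which can be handy when cross-checking the half outputs; it assumes a contiguous innermost axis and per-element scale/bias (the kernels also accept broadcast scale[0]/bias[0]), and the name referenceLayernorm is hypothetical.

// Reference-only CPU LayerNorm over the innermost axis, float accumulation.
#include <cmath>
#include <vector>

std::vector<float> referenceLayernorm(const std::vector<float> &input,
                                      const std::vector<float> &scale,
                                      const std::vector<float> &bias,
                                      int dimsize, float eps) {
    std::vector<float> output(input.size());
    const int rows = static_cast<int>(input.size()) / dimsize;
    for (int r = 0; r < rows; ++r) {
        const float *x = input.data() + r * dimsize;
        float *y = output.data() + r * dimsize;
        float mu = 0.0f;
        for (int i = 0; i < dimsize; ++i)
            mu += x[i];
        mu /= dimsize;
        float sigma2 = 0.0f;
        for (int i = 0; i < dimsize; ++i)
            sigma2 += (x[i] - mu) * (x[i] - mu);
        sigma2 /= dimsize;
        const float invStd = 1.0f / std::sqrt(sigma2 + eps);
        for (int i = 0; i < dimsize; ++i)
            y[i] = scale[i] * (x[i] - mu) * invStd +
                   (bias.empty() ? 0.0f : bias[i]);
    }
    return output;
}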
diff --git a/src/operators/layer_norm.cc b/src/operators/layer_norm.cc
index 68649215..5109c79b 100644
--- a/src/operators/layer_norm.cc
+++ b/src/operators/layer_norm.cc
@@ -27,10 +27,7 @@ optional<vector<Shape>> LayerNormObj::inferShape(const TensorVec &inputs) {
 
 vector<DataType> LayerNormObj::inferDataType(const TensorVec &inputs) const {
     IT_ASSERT(inputs.size() == 2 || inputs.size() == 3);
-    IT_ASSERT(inputs[1]->getDType() == DataType::Float32);
-    if (inputs.size() == 3) {
-        IT_ASSERT(inputs[2]->getDType() == DataType::Float32);
-    }
+
     return {inputs[0]->getDType()};
 }
diff --git a/test/kernels/cuda/test_cuda_layernorm.cc b/test/kernels/cuda/test_cuda_layernorm.cc
index 18b8c4df..e2af489e 100644
--- a/test/kernels/cuda/test_cuda_layernorm.cc
+++ b/test/kernels/cuda/test_cuda_layernorm.cc
@@ -8,7 +8,7 @@
 
 namespace infini {
 
-void test_layernorm(
+void test_layernormFp32(
     const Shape &inputShape, const vector<float> &inputData,
     const Shape &scaleShape, const vector<float> &scaleData, float eps,
     int axis, int stash_type, const vector<float> &ExpectData,
@@ -77,9 +77,78 @@
         EXPECT_TRUE(oCpu->equalData(ExpectData));
     }
 }
+void test_layernormFp16(
+    const Shape &inputShape,
+    const std::function<void(void *, size_t, DataType)> &generator,
+    const Shape &scaleShape, float eps, int axis, int stash_type,
+    const vector<float> &ExpectData,
+    const std::optional<Shape> &bShape = std::nullopt) {
 
-TEST(CUDA_Layernorm, run) {
-    test_layernorm(
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    if (bShape.has_value()) {
+        Shape biasShape = *bShape;
+
+        auto bias = gCpu->addTensor(biasShape, DataType::Float16);
+        auto input = gCpu->addTensor(inputShape, DataType::Float16);
+        auto scale = gCpu->addTensor(scaleShape, DataType::Float16);
+        gCpu->dataMalloc();
+        bias->setData(generator);
+        // bias->printData();
+        input->setData(generator);
+        scale->setData(generator);
+        auto cudaRuntime = make_ref<CudaRuntimeObj>();
+        Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+        auto biasGpu = gCuda->cloneTensor(bias);
+        auto inputGpu = gCuda->cloneTensor(input);
+        auto scaleGpu = gCuda->cloneTensor(scale);
+        // gCpu->cloneTensor(biasGpu)->printData();
+        auto op =
+            gCuda->addOp<LayerNormObj>(inputGpu, scaleGpu, nullptr, biasGpu,
+                                       eps, axis, stash_type); // LayernormObj
+        gCuda->dataMalloc();
+        biasGpu->setData(generator);
+        // gCpu->cloneTensor(biasGpu)->printData();
+        inputGpu->setData(generator);
+        scaleGpu->setData(generator);
+        cudaRuntime->run(gCuda);
+
+        auto oCpu =
+            gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
+        oCpu->printData();                      //->printData
+        EXPECT_TRUE(oCpu->equalData(ExpectData));
+    } else {
+
+        auto input = gCpu->addTensor(inputShape, DataType::Float16);
+        auto scale = gCpu->addTensor(scaleShape, DataType::Float16);
+        gCpu->dataMalloc();
+
+        input->setData(generator);
+        scale->setData(generator);
+        auto cudaRuntime = make_ref<CudaRuntimeObj>();
+        Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+
+        auto inputGpu = gCuda->cloneTensor(input);
+        auto scaleGpu = gCuda->cloneTensor(scale);
+        auto op =
+            gCuda->addOp<LayerNormObj>(inputGpu, scaleGpu, nullptr, nullptr,
+                                       eps, axis, stash_type); // LayernormObj
+        gCuda->dataMalloc();
+
+        inputGpu->setData(generator);
+        scaleGpu->setData(generator);
+        cudaRuntime->run(gCuda);
+
+        auto oCpu =
+            gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
+        oCpu->printData();                      //->printData
+        EXPECT_TRUE(oCpu->equalData(ExpectData));
+    }
+}
+
+TEST(CUDA_LayernormFp32, run) {
+    test_layernormFp32(
         Shape{2, 3, 2, 3},
         vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.,
                       12., 13., 14., 15., 16., 17.,
@@ -94,7 +163,7 @@
             -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
            -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678},
         Shape{3}, vector<float>{0, 0, 0});
-    test_layernorm(
+    test_layernormFp32(
         Shape{2, 3, 2, 3},
         vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.,
                       12., 13., 14., 15., 16., 17.,
@@ -109,7 +178,7 @@
             -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
            -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679},
         Shape{3}, vector<float>{0.3, 0.2, 0.5});
-    test_layernorm(
+    test_layernormFp32(
         Shape{2, 3, 2, 3},
         vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.,
                       12., 13., 14., 15., 16., 17.,
@@ -124,7 +193,7 @@
             -0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
            -0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207},
         Shape{3}, vector<float>{0.3, 0.2, 0.5});
-    test_layernorm(
+    test_layernormFp32(
         Shape{2, 3, 2, 3},
         vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.,
                       12., 13., 14., 15., 16., 17.,
@@ -141,6 +210,15 @@
             0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, -0.3674207,
             0.0000000, 0.6123678});
 
+} // python output
+TEST(CUDA_LayernormFp16, run) {
+    test_layernormFp16(Shape{2, 3, 2, 3}, ValGenerator<2>(), Shape{3}, 1e-5, 3,
+                       1, vector<float>{2., 2., 2., 2., 2., 2., 2., 2., 2.,
+                                        2., 2., 2., 2., 2., 2., 2., 2., 2.,
+                                        2., 2., 2., 2., 2., 2., 2., 2., 2.,
+                                        2., 2., 2., 2., 2., 2., 2., 2., 2.},
+                       Shape{3});
+
 } // python output
 } // namespace infini
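A quick worked check of the FP16 expectation, assuming ValGenerator<2>() fills every element with 2 as its name suggests: each normalized slice is then constant, so mu = 2 and sigma2 = 0, and the output is scale * (2 - 2) / sqrt(0 + 1e-5) + bias = 0 + 2 = 2, which is why ExpectData is all 2s. The standalone snippet below (not from the PR) just re-derives that value.

// Re-derives the expected value of the FP16 test for a constant input slice.
#include <cassert>
#include <cmath>

int main() {
    const float x = 2.0f, gamma = 2.0f, beta = 2.0f, eps = 1e-5f;
    const float mu = x;        // mean of a constant slice
    const float sigma2 = 0.0f; // variance of a constant slice
    const float y = gamma * (x - mu) / std::sqrt(sigma2 + eps) + beta;
    assert(std::fabs(y - 2.0f) < 1e-6f);
    return 0;
}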