add layernorm fp16

This commit is contained in:
xgqdut2016 2023-12-11 15:05:34 +08:00
parent 8b2e3b8e19
commit fda0a5f982
5 changed files with 327 additions and 110 deletions

View File

@ -8,4 +8,10 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
void LaynormKernel(const float *input, const float *scale, const float eps, void LaynormKernel(const float *input, const float *scale, const float eps,
int size, int scaleSize, const int dimsize, const int stride, int size, int scaleSize, const int dimsize, const int stride,
float *output); float *output);
void LaynormKernel(const half *input, const half *scale, const half eps,
int size, int scaleSize, const int dimsize, const int stride,
half *output, const half *bias, int biasSize);
void LaynormKernel(const half *input, const half *scale, const half eps,
int size, int scaleSize, const int dimsize, const int stride,
half *output);
}; // namespace infini }; // namespace infini

View File

@ -24,8 +24,10 @@ class LayerNormCuda : public CudaKernelWithoutConfig {
int dimsize = dims[op->getAxis()]; int dimsize = dims[op->getAxis()];
int size = op->getOutput(0)->size(); int size = op->getOutput(0)->size();
int scaleSize = op->getInputs(1)->size(); int scaleSize = op->getInputs(1)->size();
if (op->getDType() == DataType::Float32) {
if (op->numInputs() == 3) { if (op->numInputs() == 3) {
void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>()); void *const biasData =
(op->getInputs(2)->getRawDataPtr<void *>());
int biasSize = op->getInputs(2)->size(); int biasSize = op->getInputs(2)->size();
// printf("kernel bias:true:%d\n", 1); // printf("kernel bias:true:%d\n", 1);
LaynormKernel((float *)inputData, (float *)scaleData, eps, size, LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
@ -36,6 +38,21 @@ class LayerNormCuda : public CudaKernelWithoutConfig {
LaynormKernel((float *)inputData, (float *)scaleData, eps, size, LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
scaleSize, dimsize, stride, (float *)outputData); scaleSize, dimsize, stride, (float *)outputData);
} }
} else if (op->getDType() == DataType::Float16) {
if (op->numInputs() == 3) {
void *const biasData =
(op->getInputs(2)->getRawDataPtr<void *>());
int biasSize = op->getInputs(2)->size();
// printf("kernel bias:true:%d\n", 1);
LaynormKernel((half *)inputData, (half *)scaleData, eps, size,
scaleSize, dimsize, stride, (half *)outputData,
(half *)biasData, biasSize);
} else {
// printf("kernel bias:false:%d\n", 0);
LaynormKernel((half *)inputData, (half *)scaleData, eps, size,
scaleSize, dimsize, stride, (half *)outputData);
}
}
} }
}; };

View File

@ -1,43 +1,41 @@
#include "cuda/cuda_common.h" #include "cuda/cuda_common.h"
#include <cub/cub.cuh> #include <cub/cub.cuh>
template <int BLOCK_DIM> template <typename T, int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__ __launch_bounds__(BLOCK_DIM) __global__
void blockLaynormKernel(const float *input, const float *scale, void blockLaynormKernel(const T *input, const T *scale, const int dimsize,
const int dimsize, const int stride, float *output, const int stride, T *output, const T eps,
const float eps, int scaleSize, const float *bias, int scaleSize, const T *bias, int biasSize) {
int biasSize) {
// len(scale) = len(bias) = dimsize // len(scale) = len(bias) = dimsize
int tmp = blockIdx.x % stride; int tmp = blockIdx.x % stride;
int tid = (blockIdx.x - tmp) * dimsize + tmp; int tid = (blockIdx.x - tmp) * dimsize + tmp;
float muPartial = 0.0f; T muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride]; muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
} }
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce; typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float mu; __shared__ T mu;
float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum()); T muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
if (threadIdx.x == if (threadIdx.x ==
0) { // must set threadIdx.x = 0 write the output to memory 0) { // must set threadIdx.x = 0 write the output to memory
mu = muBlock / dimsize; mu = muBlock * static_cast<T>(__fdividef(1.0F, dimsize));
} }
__syncthreads(); __syncthreads();
float sigma2Partial = 0.0f; T sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
sigma2Partial += sigma2Partial +=
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) * (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu); (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
} }
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce; typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
__shared__ float sigma2; __shared__ T sigma2;
float sigma2Block = T sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
if (threadIdx.x == if (threadIdx.x ==
0) { // must set threadIdx.x = 0 write the output to memory 0) { // must set threadIdx.x = 0 write the output to memory
sigma2 = sigma2Block / dimsize; sigma2 = sigma2Block * static_cast<T>(__fdividef(1.0F, dimsize));
} }
__syncthreads(); __syncthreads();
if (biasSize == dimsize) { if (biasSize == dimsize) {
@ -47,8 +45,9 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] * scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) / mu) *
sqrt(sigma2 + eps) + static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
bias[threadIdx.x + ph * BLOCK_DIM]; bias[threadIdx.x + ph * BLOCK_DIM];
} }
} else { } else {
@ -57,8 +56,9 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] * scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) / mu) *
sqrt(sigma2 + eps) + static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
bias[threadIdx.x + ph * BLOCK_DIM]; bias[threadIdx.x + ph * BLOCK_DIM];
} }
} }
@ -69,8 +69,9 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] * scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) / mu) *
sqrt(sigma2 + eps) + static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
bias[0]; bias[0];
} }
} else { } else {
@ -79,50 +80,50 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] * scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) / mu) *
sqrt(sigma2 + eps) + static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
bias[0]; bias[0];
} }
} }
} }
} }
//----------------- //-----------------
template <int BLOCK_DIM> template <typename T, int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__ __launch_bounds__(BLOCK_DIM) __global__
void blockLaynormKernel(const float *input, const float *scale, void blockLaynormKernel(const T *input, const T *scale, const int dimsize,
const int dimsize, const int stride, float *output, const int stride, T *output, const T eps,
const float eps, int scaleSize) { int scaleSize) {
// len(scale) = len(bias) = dimsize // len(scale) = len(bias) = dimsize
int tmp = blockIdx.x % stride; int tmp = blockIdx.x % stride;
int tid = (blockIdx.x - tmp) * dimsize + tmp; int tid = (blockIdx.x - tmp) * dimsize + tmp;
float muPartial = 0.0f; T muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride]; muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
} }
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce; typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float mu; __shared__ T mu;
float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum()); T muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
if (threadIdx.x == if (threadIdx.x ==
0) { // must set threadIdx.x = 0 write the output to memory 0) { // must set threadIdx.x = 0 write the output to memory
mu = muBlock / dimsize; mu = muBlock * static_cast<T>(__fdividef(1.0F, dimsize));
} }
__syncthreads(); __syncthreads();
float sigma2Partial = 0.0f; T sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
sigma2Partial += sigma2Partial +=
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) * (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu); (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
} }
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce; typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
__shared__ float sigma2; __shared__ T sigma2;
float sigma2Block = T sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
if (threadIdx.x == if (threadIdx.x ==
0) { // must set threadIdx.x = 0 write the output to memory 0) { // must set threadIdx.x = 0 write the output to memory
sigma2 = sigma2Block / dimsize; sigma2 = sigma2Block * static_cast<T>(__fdividef(1.0F, dimsize));
} }
__syncthreads(); __syncthreads();
if (scaleSize == dimsize) { if (scaleSize == dimsize) {
@ -130,16 +131,18 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] * scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) / (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
sqrt(sigma2 + eps); static_cast<T>(
__fdividef(1.0F, sqrt(static_cast<float>(sigma2 + eps))));
} }
} else { } else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] * scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) / (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
sqrt(sigma2 + eps); static_cast<T>(
__fdividef(1.0F, sqrt(static_cast<float>(sigma2 + eps))));
} }
} }
} }
@ -158,33 +161,33 @@ __inline__ __device__ T WarpAllReduce(T val) {
} }
return val; return val;
} }
template <int BLOCK_DIM_x, int BLOCK_DIM_y> template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void warpLaynormKernel(const float *input, const float *scale, __global__ void warpLaynormKernel(const T *input, const T *scale,
const int dimsize, const int stride, const int dimsize, const int stride,
float *output, const float eps, int scaleSize, T *output, const T eps, int scaleSize,
int otherSize, const float *bias, int otherSize, const T *bias, int biasSize) {
int biasSize) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y; int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize; int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
if (otherIdx < otherSize) { if (otherIdx < otherSize) {
__shared__ float muTotal[BLOCK_DIM_y]; __shared__ T muTotal[BLOCK_DIM_y];
__shared__ float sigma2Total[BLOCK_DIM_y]; __shared__ T sigma2Total[BLOCK_DIM_y];
float muPartial = 0.0f; T muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride]; muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
} }
muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial); muPartial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(muPartial);
if (threadIdx.x == 0) if (threadIdx.x == 0)
muTotal[threadIdx.y] = muPartial / dimsize; muTotal[threadIdx.y] =
muPartial * static_cast<T>(__fdividef(1.0F, dimsize));
//-------------------------------------------- //--------------------------------------------
float sigma2Partial = 0.0f; T sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
sigma2Partial += sigma2Partial +=
@ -194,10 +197,11 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
muTotal[threadIdx.y]); muTotal[threadIdx.y]);
} }
sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial); sigma2Partial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(sigma2Partial);
if (threadIdx.x == 0) if (threadIdx.x == 0)
sigma2Total[threadIdx.y] = sigma2Partial / dimsize; sigma2Total[threadIdx.y] =
sigma2Partial * static_cast<T>(__fdividef(1.0F, dimsize));
//-------------------------------------------- //--------------------------------------------
if (biasSize == dimsize) { if (biasSize == dimsize) {
@ -209,8 +213,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
scale[threadIdx.x + ph * BLOCK_DIM_x] * scale[threadIdx.x + ph * BLOCK_DIM_x] *
(input[tid + (input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] - (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) / muTotal[threadIdx.y]) *
sqrt(sigma2Total[threadIdx.y] + eps) + static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(
sigma2Total[threadIdx.y] + eps)))) +
bias[threadIdx.x + ph * BLOCK_DIM_x]; bias[threadIdx.x + ph * BLOCK_DIM_x];
} }
} else { } else {
@ -221,8 +227,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
scale[0] * scale[0] *
(input[tid + (input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] - (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) / muTotal[threadIdx.y]) *
sqrt(sigma2Total[threadIdx.y] + eps) + static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(
sigma2Total[threadIdx.y] + eps)))) +
bias[threadIdx.x + ph * BLOCK_DIM_x]; bias[threadIdx.x + ph * BLOCK_DIM_x];
} }
} }
@ -235,8 +243,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
scale[threadIdx.x + ph * BLOCK_DIM_x] * scale[threadIdx.x + ph * BLOCK_DIM_x] *
(input[tid + (input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] - (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) / muTotal[threadIdx.y]) *
sqrt(sigma2Total[threadIdx.y] + eps) + static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(
sigma2Total[threadIdx.y] + eps)))) +
bias[0]; bias[0];
} }
} else { } else {
@ -247,40 +257,43 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
scale[0] * scale[0] *
(input[tid + (input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] - (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) / muTotal[threadIdx.y]) *
sqrt(sigma2Total[threadIdx.y] + eps) + static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(
sigma2Total[threadIdx.y] + eps)))) +
bias[0]; bias[0];
} }
} }
} }
} }
} }
template <int BLOCK_DIM_x, int BLOCK_DIM_y> template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void warpLaynormKernel(const float *input, const float *scale, __global__ void warpLaynormKernel(const T *input, const T *scale,
const int dimsize, const int stride, const int dimsize, const int stride,
float *output, const float eps, int scaleSize, T *output, const T eps, int scaleSize,
int otherSize) { int otherSize) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y; int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize; int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
if (otherIdx < otherSize) { if (otherIdx < otherSize) {
__shared__ float muTotal[BLOCK_DIM_y]; __shared__ T muTotal[BLOCK_DIM_y];
__shared__ float sigma2Total[BLOCK_DIM_y]; __shared__ T sigma2Total[BLOCK_DIM_y];
float muPartial = 0.0f; T muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride]; muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
} }
muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial); muPartial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(muPartial);
if (threadIdx.x == 0) if (threadIdx.x == 0)
muTotal[threadIdx.y] = muPartial / dimsize; muTotal[threadIdx.y] =
muPartial * static_cast<T>(__fdividef(1.0F, dimsize));
//-------------------------------------------- //--------------------------------------------
float sigma2Partial = 0.0f; T sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
sigma2Partial += sigma2Partial +=
@ -290,10 +303,11 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
muTotal[threadIdx.y]); muTotal[threadIdx.y]);
} }
sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial); sigma2Partial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(sigma2Partial);
if (threadIdx.x == 0) if (threadIdx.x == 0)
sigma2Total[threadIdx.y] = sigma2Partial / dimsize; sigma2Total[threadIdx.y] =
sigma2Partial * static_cast<T>(__fdividef(1.0F, dimsize));
//-------------------------------------------- //--------------------------------------------
if (scaleSize == dimsize) { if (scaleSize == dimsize) {
@ -302,8 +316,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] = output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM_x] * scale[threadIdx.x + ph * BLOCK_DIM_x] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] - (input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) / muTotal[threadIdx.y]) *
sqrt(sigma2Total[threadIdx.y] + eps); static_cast<T>(
__fdividef(1.0F, sqrt(static_cast<float>(
sigma2Total[threadIdx.y] + eps))));
} }
} else { } else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
@ -311,8 +327,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] = output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[0] * scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] - (input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) / muTotal[threadIdx.y]) *
sqrt(sigma2Total[threadIdx.y] + eps); static_cast<T>(
__fdividef(1.0F, sqrt(static_cast<float>(
sigma2Total[threadIdx.y] + eps))));
} }
} }
} }
@ -325,7 +343,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
if (dimsize > 1024) { if (dimsize > 1024) {
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
blockLaynormKernel<1024> blockLaynormKernel<float, 1024>
<<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output, <<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
eps, scaleSize, bias, biasSize); eps, scaleSize, bias, biasSize);
} else if (dimsize > 31) { } else if (dimsize > 31) {
@ -335,7 +353,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block, input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize); bias, biasSize);
} else if (dimsize > 15) { } else if (dimsize > 15) {
@ -345,7 +363,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block, input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize); bias, biasSize);
} else if (dimsize > 7) { } else if (dimsize > 7) {
@ -355,7 +373,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block, input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize); bias, biasSize);
} else { } else {
@ -365,7 +383,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block, input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize); bias, biasSize);
} }
@ -378,7 +396,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
if (dimsize > 1024) { if (dimsize > 1024) {
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
blockLaynormKernel<1024><<<num_block, BLOCK_DIM>>>( blockLaynormKernel<float, 1024><<<num_block, BLOCK_DIM>>>(
input, scale, dimsize, stride, output, eps, scaleSize); input, scale, dimsize, stride, output, eps, scaleSize);
} else if (dimsize > 31) { } else if (dimsize > 31) {
int BLOCK_DIM_x = 32; int BLOCK_DIM_x = 32;
@ -387,7 +405,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block); input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 15) { } else if (dimsize > 15) {
int BLOCK_DIM_x = 16; int BLOCK_DIM_x = 16;
@ -396,7 +414,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block); input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 7) { } else if (dimsize > 7) {
int BLOCK_DIM_x = 8; int BLOCK_DIM_x = 8;
@ -405,7 +423,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block); input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else { } else {
int BLOCK_DIM_x = 4; int BLOCK_DIM_x = 4;
@ -414,7 +432,108 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
}
}
//-----------------
void LaynormKernel(const half *input, const half *scale, const half eps,
int size, int scaleSize, const int dimsize, const int stride,
half *output, const half *bias, int biasSize) {
int num_block = size / dimsize;
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockLaynormKernel<half, 1024>
<<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
eps, scaleSize, bias, biasSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
}
}
void LaynormKernel(const half *input, const half *scale, const half eps,
int size, int scaleSize, const int dimsize, const int stride,
half *output) {
int num_block = size / dimsize;
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockLaynormKernel<half, 1024><<<num_block, BLOCK_DIM>>>(
input, scale, dimsize, stride, output, eps, scaleSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block); input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} }
} }

View File

@ -27,10 +27,7 @@ optional<vector<Shape>> LayerNormObj::inferShape(const TensorVec &inputs) {
vector<DataType> LayerNormObj::inferDataType(const TensorVec &inputs) const { vector<DataType> LayerNormObj::inferDataType(const TensorVec &inputs) const {
IT_ASSERT(inputs.size() == 2 || inputs.size() == 3); IT_ASSERT(inputs.size() == 2 || inputs.size() == 3);
IT_ASSERT(inputs[1]->getDType() == DataType::Float32);
if (inputs.size() == 3) {
IT_ASSERT(inputs[2]->getDType() == DataType::Float32);
}
return {inputs[0]->getDType()}; return {inputs[0]->getDType()};
} }

View File

@ -8,7 +8,7 @@
namespace infini { namespace infini {
void test_layernorm( void test_layernormFp32(
const Shape &inputShape, const vector<float> &inputData, const Shape &inputShape, const vector<float> &inputData,
const Shape &scaleShape, const vector<float> &scaleData, float eps, const Shape &scaleShape, const vector<float> &scaleData, float eps,
int axis, int stash_type, const vector<float> &ExpectData, int axis, int stash_type, const vector<float> &ExpectData,
@ -77,9 +77,78 @@ void test_layernorm(
EXPECT_TRUE(oCpu->equalData(ExpectData)); EXPECT_TRUE(oCpu->equalData(ExpectData));
} }
} }
void test_layernormFp16(
const Shape &inputShape,
const std::function<void(void *, size_t, DataType)> &generator,
const Shape &scaleShape, float eps, int axis, int stash_type,
const vector<float> &ExpectData,
const std::optional<Shape> &bShape = std::nullopt) {
TEST(CUDA_Layernorm, run) { Runtime runtime = NativeCpuRuntimeObj::getInstance();
test_layernorm( Graph gCpu = make_ref<GraphObj>(runtime);
if (bShape.has_value()) {
Shape biasShape = *bShape;
auto bias = gCpu->addTensor(biasShape, DataType::Float16);
auto input = gCpu->addTensor(inputShape, DataType::Float16);
auto scale = gCpu->addTensor(scaleShape, DataType::Float16);
gCpu->dataMalloc();
bias->setData(generator);
// bias->printData();
input->setData(generator);
scale->setData(generator);
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto biasGpu = gCuda->cloneTensor(bias);
auto inputGpu = gCuda->cloneTensor(input);
auto scaleGpu = gCuda->cloneTensor(scale);
// gCpu->cloneTensor(biasGpu)->printData();
auto op =
gCuda->addOp<LayerNormObj>(inputGpu, scaleGpu, nullptr, biasGpu,
eps, axis, stash_type); // LayernormObj
gCuda->dataMalloc();
biasGpu->setData(generator);
// gCpu->cloneTensor(biasGpu)->printData();
inputGpu->setData(generator);
scaleGpu->setData(generator);
cudaRuntime->run(gCuda);
auto oCpu =
gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
oCpu->printData(); //->printData
EXPECT_TRUE(oCpu->equalData(ExpectData));
} else {
auto input = gCpu->addTensor(inputShape, DataType::Float16);
auto scale = gCpu->addTensor(scaleShape, DataType::Float16);
gCpu->dataMalloc();
input->setData(generator);
scale->setData(generator);
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto inputGpu = gCuda->cloneTensor(input);
auto scaleGpu = gCuda->cloneTensor(scale);
auto op =
gCuda->addOp<LayerNormObj>(inputGpu, scaleGpu, nullptr, nullptr,
eps, axis, stash_type); // LayernormObj
gCuda->dataMalloc();
inputGpu->setData(generator);
scaleGpu->setData(generator);
cudaRuntime->run(gCuda);
auto oCpu =
gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
oCpu->printData(); //->printData
EXPECT_TRUE(oCpu->equalData(ExpectData));
}
}
TEST(CUDA_LayernormFp32, run) {
test_layernormFp32(
Shape{2, 3, 2, 3}, Shape{2, 3, 2, 3},
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8., vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
9., 10., 11., 12., 13., 14., 15., 16., 17., 9., 10., 11., 12., 13., 14., 15., 16., 17.,
@ -94,7 +163,7 @@ TEST(CUDA_Layernorm, run) {
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678}, -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678},
Shape{3}, vector<float>{0, 0, 0}); Shape{3}, vector<float>{0, 0, 0});
test_layernorm( test_layernormFp32(
Shape{2, 3, 2, 3}, Shape{2, 3, 2, 3},
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8., vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
9., 10., 11., 12., 13., 14., 15., 16., 17., 9., 10., 11., 12., 13., 14., 15., 16., 17.,
@ -109,7 +178,7 @@ TEST(CUDA_Layernorm, run) {
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679}, -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679},
Shape{3}, vector<float>{0.3, 0.2, 0.5}); Shape{3}, vector<float>{0.3, 0.2, 0.5});
test_layernorm( test_layernormFp32(
Shape{2, 3, 2, 3}, Shape{2, 3, 2, 3},
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8., vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
9., 10., 11., 12., 13., 14., 15., 16., 17., 9., 10., 11., 12., 13., 14., 15., 16., 17.,
@ -124,7 +193,7 @@ TEST(CUDA_Layernorm, run) {
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207}, -0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207},
Shape{3}, vector<float>{0.3, 0.2, 0.5}); Shape{3}, vector<float>{0.3, 0.2, 0.5});
test_layernorm( test_layernormFp32(
Shape{2, 3, 2, 3}, Shape{2, 3, 2, 3},
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8., vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
9., 10., 11., 12., 13., 14., 15., 16., 17., 9., 10., 11., 12., 13., 14., 15., 16., 17.,
@ -141,6 +210,15 @@ TEST(CUDA_Layernorm, run) {
0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.0000000, 0.6123678, -0.3674207, 0.0000000,
0.6123678, -0.3674207, 0.0000000, 0.6123678}); 0.6123678, -0.3674207, 0.0000000, 0.6123678});
} // python output
TEST(CUDA_LayernormFp16, run) {
test_layernormFp16(Shape{2, 3, 2, 3}, ValGenerator<2>(), Shape{3}, 1e-5, 3,
1, vector<float>{2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 2., 2., 2., 2., 2., 2., 2., 2.},
Shape{3});
} // python output } // python output
} // namespace infini } // namespace infini