forked from jiuyuan/InfiniTensor
add layernorm fp16
This commit is contained in:
parent 8b2e3b8e19
commit fda0a5f982
@@ -8,4 +8,10 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
 void LaynormKernel(const float *input, const float *scale, const float eps,
                    int size, int scaleSize, const int dimsize, const int stride,
                    float *output);
+void LaynormKernel(const half *input, const half *scale, const half eps,
+                   int size, int scaleSize, const int dimsize, const int stride,
+                   half *output, const half *bias, int biasSize);
+void LaynormKernel(const half *input, const half *scale, const half eps,
+                   int size, int scaleSize, const int dimsize, const int stride,
+                   half *output);
 }; // namespace infini
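For reference, the kernels declared here compute, for each row along the normalized axis, the mean mu and the population variance sigma2, and then scale * (x - mu) / sqrt(sigma2 + eps) + bias, with scale and bias broadcast from element 0 whenever scaleSize or biasSize does not equal dimsize. The sketch below is a plain C++ single-row reference (an illustration, not code from this commit):

// Plain C++ reference for one normalized row (illustration only). mu is the
// mean and sigma2 the population variance over the normalized axis, matching
// the CUDA kernels in this commit.
#include <cmath>
#include <vector>

std::vector<float> layernormRowRef(const std::vector<float> &x,
                                   const std::vector<float> &scale,
                                   const std::vector<float> &bias, float eps) {
    const int n = static_cast<int>(x.size());
    float mu = 0.f, sigma2 = 0.f;
    for (float v : x)
        mu += v;
    mu /= n;
    for (float v : x)
        sigma2 += (v - mu) * (v - mu);
    sigma2 /= n;
    std::vector<float> y(n);
    for (int i = 0; i < n; ++i) {
        // scale/bias are broadcast from element 0 when their size != n,
        // mirroring the scale[0]/bias[0] branches in the kernels.
        float s = (scale.size() == x.size()) ? scale[i] : scale[0];
        float b = bias.empty()
                      ? 0.f
                      : ((bias.size() == x.size()) ? bias[i] : bias[0]);
        y[i] = s * (x[i] - mu) / std::sqrt(sigma2 + eps) + b;
    }
    return y;
}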
@@ -24,17 +24,34 @@ class LayerNormCuda : public CudaKernelWithoutConfig {
         int dimsize = dims[op->getAxis()];
         int size = op->getOutput(0)->size();
         int scaleSize = op->getInputs(1)->size();
-        if (op->numInputs() == 3) {
-            void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
-            int biasSize = op->getInputs(2)->size();
-            // printf("kernel bias:true:%d\n", 1);
-            LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
-                          scaleSize, dimsize, stride, (float *)outputData,
-                          (float *)biasData, biasSize);
-        } else {
-            // printf("kernel bias:false:%d\n", 0);
-            LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
-                          scaleSize, dimsize, stride, (float *)outputData);
+        if (op->getDType() == DataType::Float32) {
+            if (op->numInputs() == 3) {
+                void *const biasData =
+                    (op->getInputs(2)->getRawDataPtr<void *>());
+                int biasSize = op->getInputs(2)->size();
+                // printf("kernel bias:true:%d\n", 1);
+                LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
+                              scaleSize, dimsize, stride, (float *)outputData,
+                              (float *)biasData, biasSize);
+            } else {
+                // printf("kernel bias:false:%d\n", 0);
+                LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
+                              scaleSize, dimsize, stride, (float *)outputData);
+            }
+        } else if (op->getDType() == DataType::Float16) {
+            if (op->numInputs() == 3) {
+                void *const biasData =
+                    (op->getInputs(2)->getRawDataPtr<void *>());
+                int biasSize = op->getInputs(2)->size();
+                // printf("kernel bias:true:%d\n", 1);
+                LaynormKernel((half *)inputData, (half *)scaleData, eps, size,
+                              scaleSize, dimsize, stride, (half *)outputData,
+                              (half *)biasData, biasSize);
+            } else {
+                // printf("kernel bias:false:%d\n", 0);
+                LaynormKernel((half *)inputData, (half *)scaleData, eps, size,
+                              scaleSize, dimsize, stride, (half *)outputData);
+            }
         }
     }
 };
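A note on the arguments forwarded to LaynormKernel here: size is the total element count of the output, dimsize the length of the normalized axis, and size / dimsize independent rows get normalized. stride is computed earlier in compute() (outside this hunk); judging from the kernels' indexing it is the distance between consecutive elements along the normalized axis, i.e. the product of the dimensions after the axis for a row-major layout (an assumption, not shown in the diff). A small host-side sketch of that mapping, with hypothetical shape values:

// Sketch of the indexing used by the kernels (illustration, not repository
// code). Assumes row-major layout with "stride" = product of dims after axis.
#include <cstdio>

int main() {
    int dimsize = 3, stride = 4;            // e.g. shape {2, 3, 4}, axis = 1
    int size = 2 * dimsize * stride;        // total output elements
    int num_block = size / dimsize;         // one block (or warp row) per row
    for (int b = 0; b < num_block; ++b) {
        int tmp = b % stride;
        int tid = (b - tmp) * dimsize + tmp; // offset of element 0 of this row
        for (int j = 0; j < dimsize; ++j)
            std::printf("row %d, j %d -> offset %d\n", b, j, tid + j * stride);
    }
    return 0;
}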
@@ -1,43 +1,41 @@
 #include "cuda/cuda_common.h"
 #include <cub/cub.cuh>

-template <int BLOCK_DIM>
+template <typename T, int BLOCK_DIM>
 __launch_bounds__(BLOCK_DIM) __global__
-    void blockLaynormKernel(const float *input, const float *scale,
-                            const int dimsize, const int stride, float *output,
-                            const float eps, int scaleSize, const float *bias,
-                            int biasSize) {
+    void blockLaynormKernel(const T *input, const T *scale, const int dimsize,
+                            const int stride, T *output, const T eps,
+                            int scaleSize, const T *bias, int biasSize) {
     // len(scale) = len(bias) = dimsize
     int tmp = blockIdx.x % stride;
     int tid = (blockIdx.x - tmp) * dimsize + tmp;
-    float muPartial = 0.0f;
+    T muPartial = 0.0f;
     for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
         muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
     }
-    typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
+    typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
     __shared__ typename BlockReduce::TempStorage temp_storage;
-    __shared__ float mu;
-    float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
+    __shared__ T mu;
+    T muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
     if (threadIdx.x ==
         0) { // must set threadIdx.x = 0 write the output to memory
-        mu = muBlock / dimsize;
+        mu = muBlock * static_cast<T>(__fdividef(1.0F, dimsize));
     }
     __syncthreads();

-    float sigma2Partial = 0.0f;
+    T sigma2Partial = 0.0f;
     for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
         sigma2Partial +=
             (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
             (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
     }
-    typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
+    typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;

-    __shared__ float sigma2;
-    float sigma2Block =
-        BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
+    __shared__ T sigma2;
+    T sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
     if (threadIdx.x ==
         0) { // must set threadIdx.x = 0 write the output to memory
-        sigma2 = sigma2Block / dimsize;
+        sigma2 = sigma2Block * static_cast<T>(__fdividef(1.0F, dimsize));
     }
     __syncthreads();
     if (biasSize == dimsize) {
@@ -47,8 +45,9 @@ __launch_bounds__(BLOCK_DIM) __global__
             output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
                 scale[threadIdx.x + ph * BLOCK_DIM] *
                     (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
-                     mu) /
-                    sqrt(sigma2 + eps) +
+                     mu) *
+                    static_cast<T>(__fdividef(
+                        1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
                 bias[threadIdx.x + ph * BLOCK_DIM];
         }
     } else {
@@ -57,8 +56,9 @@ __launch_bounds__(BLOCK_DIM) __global__
             output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
                 scale[0] *
                     (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
-                     mu) /
-                    sqrt(sigma2 + eps) +
+                     mu) *
+                    static_cast<T>(__fdividef(
+                        1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
                 bias[threadIdx.x + ph * BLOCK_DIM];
         }
     }
@@ -69,8 +69,9 @@ __launch_bounds__(BLOCK_DIM) __global__
             output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
                 scale[threadIdx.x + ph * BLOCK_DIM] *
                     (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
-                     mu) /
-                    sqrt(sigma2 + eps) +
+                     mu) *
+                    static_cast<T>(__fdividef(
+                        1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
                 bias[0];
         }
     } else {
@@ -79,50 +80,50 @@ __launch_bounds__(BLOCK_DIM) __global__
             output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
                 scale[0] *
                     (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
-                     mu) /
-                    sqrt(sigma2 + eps) +
+                     mu) *
+                    static_cast<T>(__fdividef(
+                        1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
                 bias[0];
         }
     }
     }
 }
 //-----------------
-template <int BLOCK_DIM>
+template <typename T, int BLOCK_DIM>
 __launch_bounds__(BLOCK_DIM) __global__
-    void blockLaynormKernel(const float *input, const float *scale,
-                            const int dimsize, const int stride, float *output,
-                            const float eps, int scaleSize) {
+    void blockLaynormKernel(const T *input, const T *scale, const int dimsize,
+                            const int stride, T *output, const T eps,
+                            int scaleSize) {
     // len(scale) = len(bias) = dimsize
     int tmp = blockIdx.x % stride;
     int tid = (blockIdx.x - tmp) * dimsize + tmp;
-    float muPartial = 0.0f;
+    T muPartial = 0.0f;
     for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
         muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
     }
-    typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
+    typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
     __shared__ typename BlockReduce::TempStorage temp_storage;
-    __shared__ float mu;
-    float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
+    __shared__ T mu;
+    T muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
     if (threadIdx.x ==
         0) { // must set threadIdx.x = 0 write the output to memory
-        mu = muBlock / dimsize;
+        mu = muBlock * static_cast<T>(__fdividef(1.0F, dimsize));
     }
     __syncthreads();

-    float sigma2Partial = 0.0f;
+    T sigma2Partial = 0.0f;
     for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
         sigma2Partial +=
             (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
             (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
     }
-    typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
+    typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;

-    __shared__ float sigma2;
-    float sigma2Block =
-        BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
+    __shared__ T sigma2;
+    T sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
     if (threadIdx.x ==
         0) { // must set threadIdx.x = 0 write the output to memory
-        sigma2 = sigma2Block / dimsize;
+        sigma2 = sigma2Block * static_cast<T>(__fdividef(1.0F, dimsize));
     }
     __syncthreads();
     if (scaleSize == dimsize) {
@@ -130,16 +131,18 @@ __launch_bounds__(BLOCK_DIM) __global__

             output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
                 scale[threadIdx.x + ph * BLOCK_DIM] *
-                (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) /
-                sqrt(sigma2 + eps);
+                (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
+                static_cast<T>(
+                    __fdividef(1.0F, sqrt(static_cast<float>(sigma2 + eps))));
         }
     } else {
         for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {

             output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
                 scale[0] *
-                (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) /
-                sqrt(sigma2 + eps);
+                (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
+                static_cast<T>(
+                    __fdividef(1.0F, sqrt(static_cast<float>(sigma2 + eps))));
         }
     }
 }
@@ -158,33 +161,33 @@ __inline__ __device__ T WarpAllReduce(T val) {
     }
     return val;
 }
-template <int BLOCK_DIM_x, int BLOCK_DIM_y>
-__global__ void warpLaynormKernel(const float *input, const float *scale,
+template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
+__global__ void warpLaynormKernel(const T *input, const T *scale,
                                   const int dimsize, const int stride,
-                                  float *output, const float eps, int scaleSize,
-                                  int otherSize, const float *bias,
-                                  int biasSize) {
+                                  T *output, const T eps, int scaleSize,
+                                  int otherSize, const T *bias, int biasSize) {
     int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;

     int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
     if (otherIdx < otherSize) {

-        __shared__ float muTotal[BLOCK_DIM_y];
-        __shared__ float sigma2Total[BLOCK_DIM_y];
+        __shared__ T muTotal[BLOCK_DIM_y];
+        __shared__ T sigma2Total[BLOCK_DIM_y];

-        float muPartial = 0.0f;
+        T muPartial = 0.0f;

         for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
             muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
         }

-        muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial);
+        muPartial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(muPartial);

         if (threadIdx.x == 0)
-            muTotal[threadIdx.y] = muPartial / dimsize;
+            muTotal[threadIdx.y] =
+                muPartial * static_cast<T>(__fdividef(1.0F, dimsize));

         //--------------------------------------------
-        float sigma2Partial = 0.0f;
+        T sigma2Partial = 0.0f;

         for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
             sigma2Partial +=
@@ -194,10 +197,11 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                 muTotal[threadIdx.y]);
         }

-        sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial);
+        sigma2Partial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(sigma2Partial);

         if (threadIdx.x == 0)
-            sigma2Total[threadIdx.y] = sigma2Partial / dimsize;
+            sigma2Total[threadIdx.y] =
+                sigma2Partial * static_cast<T>(__fdividef(1.0F, dimsize));

         //--------------------------------------------
         if (biasSize == dimsize) {
@@ -209,8 +213,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                         scale[threadIdx.x + ph * BLOCK_DIM_x] *
                             (input[tid +
                                    (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
-                             muTotal[threadIdx.y]) /
-                            sqrt(sigma2Total[threadIdx.y] + eps) +
+                             muTotal[threadIdx.y]) *
+                            static_cast<T>(__fdividef(
+                                1.0F, sqrt(static_cast<float>(
+                                          sigma2Total[threadIdx.y] + eps)))) +
                         bias[threadIdx.x + ph * BLOCK_DIM_x];
                 }
             } else {
@@ -221,8 +227,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                         scale[0] *
                             (input[tid +
                                    (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
-                             muTotal[threadIdx.y]) /
-                            sqrt(sigma2Total[threadIdx.y] + eps) +
+                             muTotal[threadIdx.y]) *
+                            static_cast<T>(__fdividef(
+                                1.0F, sqrt(static_cast<float>(
+                                          sigma2Total[threadIdx.y] + eps)))) +
                         bias[threadIdx.x + ph * BLOCK_DIM_x];
                 }
             }
@@ -235,8 +243,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                         scale[threadIdx.x + ph * BLOCK_DIM_x] *
                             (input[tid +
                                    (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
-                             muTotal[threadIdx.y]) /
-                            sqrt(sigma2Total[threadIdx.y] + eps) +
+                             muTotal[threadIdx.y]) *
+                            static_cast<T>(__fdividef(
+                                1.0F, sqrt(static_cast<float>(
+                                          sigma2Total[threadIdx.y] + eps)))) +
                         bias[0];
                 }
             } else {
@@ -247,40 +257,43 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                         scale[0] *
                             (input[tid +
                                    (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
-                             muTotal[threadIdx.y]) /
-                            sqrt(sigma2Total[threadIdx.y] + eps) +
+                             muTotal[threadIdx.y]) *
+                            static_cast<T>(__fdividef(
+                                1.0F, sqrt(static_cast<float>(
+                                          sigma2Total[threadIdx.y] + eps)))) +
                         bias[0];
                 }
             }
         }
     }
 }
-template <int BLOCK_DIM_x, int BLOCK_DIM_y>
-__global__ void warpLaynormKernel(const float *input, const float *scale,
+template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
+__global__ void warpLaynormKernel(const T *input, const T *scale,
                                   const int dimsize, const int stride,
-                                  float *output, const float eps, int scaleSize,
+                                  T *output, const T eps, int scaleSize,
                                   int otherSize) {
     int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;

     int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
     if (otherIdx < otherSize) {

-        __shared__ float muTotal[BLOCK_DIM_y];
-        __shared__ float sigma2Total[BLOCK_DIM_y];
+        __shared__ T muTotal[BLOCK_DIM_y];
+        __shared__ T sigma2Total[BLOCK_DIM_y];

-        float muPartial = 0.0f;
+        T muPartial = 0.0f;

         for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
             muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
         }

-        muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial);
+        muPartial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(muPartial);

         if (threadIdx.x == 0)
-            muTotal[threadIdx.y] = muPartial / dimsize;
+            muTotal[threadIdx.y] =
+                muPartial * static_cast<T>(__fdividef(1.0F, dimsize));

         //--------------------------------------------
-        float sigma2Partial = 0.0f;
+        T sigma2Partial = 0.0f;

         for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
             sigma2Partial +=
@@ -290,10 +303,11 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                 muTotal[threadIdx.y]);
         }

-        sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial);
+        sigma2Partial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(sigma2Partial);

         if (threadIdx.x == 0)
-            sigma2Total[threadIdx.y] = sigma2Partial / dimsize;
+            sigma2Total[threadIdx.y] =
+                sigma2Partial * static_cast<T>(__fdividef(1.0F, dimsize));

         //--------------------------------------------
         if (scaleSize == dimsize) {
@@ -302,8 +316,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                 output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
                     scale[threadIdx.x + ph * BLOCK_DIM_x] *
                     (input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
-                     muTotal[threadIdx.y]) /
-                    sqrt(sigma2Total[threadIdx.y] + eps);
+                     muTotal[threadIdx.y]) *
+                    static_cast<T>(
+                        __fdividef(1.0F, sqrt(static_cast<float>(
+                                             sigma2Total[threadIdx.y] + eps))));
         }
     } else {
         for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
@@ -311,8 +327,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
                 output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
                     scale[0] *
                     (input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
-                     muTotal[threadIdx.y]) /
-                    sqrt(sigma2Total[threadIdx.y] + eps);
+                     muTotal[threadIdx.y]) *
+                    static_cast<T>(
+                        __fdividef(1.0F, sqrt(static_cast<float>(
+                                             sigma2Total[threadIdx.y] + eps))));
         }
     }
 }
@@ -325,7 +343,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
     if (dimsize > 1024) {
         int BLOCK_DIM = 1024;

-        blockLaynormKernel<1024>
+        blockLaynormKernel<float, 1024>
             <<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
                                        eps, scaleSize, bias, biasSize);
     } else if (dimsize > 31) {
@@ -335,7 +353,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);

-        warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block,
             bias, biasSize);
     } else if (dimsize > 15) {
@@ -345,7 +363,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);

-        warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block,
             bias, biasSize);
     } else if (dimsize > 7) {
@@ -355,7 +373,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);

-        warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block,
             bias, biasSize);
     } else {
@@ -365,7 +383,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);

-        warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block,
             bias, biasSize);
     }
@@ -378,7 +396,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
     if (dimsize > 1024) {
         int BLOCK_DIM = 1024;

-        blockLaynormKernel<1024><<<num_block, BLOCK_DIM>>>(
+        blockLaynormKernel<float, 1024><<<num_block, BLOCK_DIM>>>(
             input, scale, dimsize, stride, output, eps, scaleSize);
     } else if (dimsize > 31) {
         int BLOCK_DIM_x = 32;
@@ -387,7 +405,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);

-        warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     } else if (dimsize > 15) {
         int BLOCK_DIM_x = 16;
@@ -396,7 +414,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);

-        warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     } else if (dimsize > 7) {
         int BLOCK_DIM_x = 8;
@@ -405,7 +423,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);

-        warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     } else {
         int BLOCK_DIM_x = 4;
@@ -414,7 +432,108 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, 1, 1);

-        warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>(
+        warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block);
+    }
+}
+//-----------------
+void LaynormKernel(const half *input, const half *scale, const half eps,
+                   int size, int scaleSize, const int dimsize, const int stride,
+                   half *output, const half *bias, int biasSize) {
+    int num_block = size / dimsize;
+    if (dimsize > 1024) {
+        int BLOCK_DIM = 1024;
+
+        blockLaynormKernel<half, 1024>
+            <<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
+                                       eps, scaleSize, bias, biasSize);
+    } else if (dimsize > 31) {
+        int BLOCK_DIM_x = 32;
+        int BLOCK_DIM_y = 32;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block,
+            bias, biasSize);
+    } else if (dimsize > 15) {
+        int BLOCK_DIM_x = 16;
+        int BLOCK_DIM_y = 64;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block,
+            bias, biasSize);
+    } else if (dimsize > 7) {
+        int BLOCK_DIM_x = 8;
+        int BLOCK_DIM_y = 128;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block,
+            bias, biasSize);
+    } else {
+        int BLOCK_DIM_x = 4;
+        int BLOCK_DIM_y = 256;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block,
+            bias, biasSize);
+    }
+}
+
+void LaynormKernel(const half *input, const half *scale, const half eps,
+                   int size, int scaleSize, const int dimsize, const int stride,
+                   half *output) {
+    int num_block = size / dimsize;
+    if (dimsize > 1024) {
+        int BLOCK_DIM = 1024;
+
+        blockLaynormKernel<half, 1024><<<num_block, BLOCK_DIM>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize);
+    } else if (dimsize > 31) {
+        int BLOCK_DIM_x = 32;
+        int BLOCK_DIM_y = 32;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block);
+    } else if (dimsize > 15) {
+        int BLOCK_DIM_x = 16;
+        int BLOCK_DIM_y = 64;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block);
+    } else if (dimsize > 7) {
+        int BLOCK_DIM_x = 8;
+        int BLOCK_DIM_y = 128;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>(
+            input, scale, dimsize, stride, output, eps, scaleSize, num_block);
+    } else {
+        int BLOCK_DIM_x = 4;
+        int BLOCK_DIM_y = 256;
+        int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
+        dim3 grid_dim(num_block_x, 1, 1);
+
+        warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>(
             input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     }
 }
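The warp kernels rely on SumOp and WarpAllReduce, whose definitions lie outside the hunks above. A typical shuffle-based implementation consistent with the WarpAllReduce<SumOp, T, BLOCK_DIM_x>(val) call sites would look like the following (an assumption about the existing helper, not code from this commit):

// Assumed shape of the helpers used by warpLaynormKernel; the real
// definitions live elsewhere in the .cu file and may differ in detail.
template <typename T> struct SumOp {
    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
        return a + b;
    }
};

template <template <typename> class ReductionOp, typename T,
          int thread_group_width>
__inline__ __device__ T WarpAllReduce(T val) {
    // Butterfly reduction across the threads of one row
    // (thread_group_width == BLOCK_DIM_x); every lane ends up with the sum.
    for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
        val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}

Note also that __fdividef is CUDA's fast single-precision division intrinsic: computing 1 / sqrt(sigma2 + eps) in float and only then casting to T keeps that intermediate in fp32 even when the kernels are instantiated for half.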
@@ -27,10 +27,7 @@ optional<vector<Shape>> LayerNormObj::inferShape(const TensorVec &inputs) {

 vector<DataType> LayerNormObj::inferDataType(const TensorVec &inputs) const {
     IT_ASSERT(inputs.size() == 2 || inputs.size() == 3);
-    IT_ASSERT(inputs[1]->getDType() == DataType::Float32);
-    if (inputs.size() == 3) {
-        IT_ASSERT(inputs[2]->getDType() == DataType::Float32);
-    }
     return {inputs[0]->getDType()};
 }

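With the Float32-only assertions gone, inferDataType now just validates the input count and propagates the first input's data type, so a graph whose input, scale, and optional bias are all Float16 passes type inference and is routed to the half kernel path above.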
@@ -8,7 +8,7 @@

 namespace infini {

-void test_layernorm(
+void test_layernormFp32(
     const Shape &inputShape, const vector<float> &inputData,
     const Shape &scaleShape, const vector<float> &scaleData, float eps,
     int axis, int stash_type, const vector<float> &ExpectData,
@@ -77,9 +77,78 @@ void test_layernorm(
         EXPECT_TRUE(oCpu->equalData(ExpectData));
     }
 }
+void test_layernormFp16(
+    const Shape &inputShape,
+    const std::function<void(void *, size_t, DataType)> &generator,
+    const Shape &scaleShape, float eps, int axis, int stash_type,
+    const vector<float> &ExpectData,
+    const std::optional<Shape> &bShape = std::nullopt) {

-TEST(CUDA_Layernorm, run) {
-    test_layernorm(
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    if (bShape.has_value()) {
+        Shape biasShape = *bShape;
+
+        auto bias = gCpu->addTensor(biasShape, DataType::Float16);
+        auto input = gCpu->addTensor(inputShape, DataType::Float16);
+        auto scale = gCpu->addTensor(scaleShape, DataType::Float16);
+        gCpu->dataMalloc();
+        bias->setData(generator);
+        // bias->printData();
+        input->setData(generator);
+        scale->setData(generator);
+        auto cudaRuntime = make_ref<CudaRuntimeObj>();
+        Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+        auto biasGpu = gCuda->cloneTensor(bias);
+        auto inputGpu = gCuda->cloneTensor(input);
+        auto scaleGpu = gCuda->cloneTensor(scale);
+        // gCpu->cloneTensor(biasGpu)->printData();
+        auto op =
+            gCuda->addOp<LayerNormObj>(inputGpu, scaleGpu, nullptr, biasGpu,
+                                       eps, axis, stash_type); // LayernormObj
+        gCuda->dataMalloc();
+        biasGpu->setData(generator);
+        // gCpu->cloneTensor(biasGpu)->printData();
+        inputGpu->setData(generator);
+        scaleGpu->setData(generator);
+        cudaRuntime->run(gCuda);
+
+        auto oCpu =
+            gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
+        oCpu->printData();                      //->printData
+        EXPECT_TRUE(oCpu->equalData(ExpectData));
+    } else {
+
+        auto input = gCpu->addTensor(inputShape, DataType::Float16);
+        auto scale = gCpu->addTensor(scaleShape, DataType::Float16);
+        gCpu->dataMalloc();
+
+        input->setData(generator);
+        scale->setData(generator);
+        auto cudaRuntime = make_ref<CudaRuntimeObj>();
+        Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+
+        auto inputGpu = gCuda->cloneTensor(input);
+        auto scaleGpu = gCuda->cloneTensor(scale);
+        auto op =
+            gCuda->addOp<LayerNormObj>(inputGpu, scaleGpu, nullptr, nullptr,
+                                       eps, axis, stash_type); // LayernormObj
+        gCuda->dataMalloc();
+
+        inputGpu->setData(generator);
+        scaleGpu->setData(generator);
+        cudaRuntime->run(gCuda);
+
+        auto oCpu =
+            gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
+        oCpu->printData();                      //->printData
+        EXPECT_TRUE(oCpu->equalData(ExpectData));
+    }
+}
+
+TEST(CUDA_LayernormFp32, run) {
+    test_layernormFp32(
         Shape{2, 3, 2, 3},
         vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
                       9., 10., 11., 12., 13., 14., 15., 16., 17.,
@@ -94,7 +163,7 @@ TEST(CUDA_Layernorm, run) {
         -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
         -0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678},
         Shape{3}, vector<float>{0, 0, 0});
-    test_layernorm(
+    test_layernormFp32(
         Shape{2, 3, 2, 3},
         vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
                       9., 10., 11., 12., 13., 14., 15., 16., 17.,
@@ -109,7 +178,7 @@ TEST(CUDA_Layernorm, run) {
         -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
         -0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679},
         Shape{3}, vector<float>{0.3, 0.2, 0.5});
-    test_layernorm(
+    test_layernormFp32(
         Shape{2, 3, 2, 3},
         vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
                       9., 10., 11., 12., 13., 14., 15., 16., 17.,
@@ -124,7 +193,7 @@ TEST(CUDA_Layernorm, run) {
         -0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
         -0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207},
         Shape{3}, vector<float>{0.3, 0.2, 0.5});
-    test_layernorm(
+    test_layernormFp32(
         Shape{2, 3, 2, 3},
         vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
                       9., 10., 11., 12., 13., 14., 15., 16., 17.,
@@ -141,6 +210,15 @@ TEST(CUDA_Layernorm, run) {
         0.0000000, 0.6123678, -0.3674207, 0.0000000,
         0.6123678, -0.3674207, 0.0000000, 0.6123678});

+} // python output
+TEST(CUDA_LayernormFp16, run) {
+    test_layernormFp16(Shape{2, 3, 2, 3}, ValGenerator<2>(), Shape{3}, 1e-5, 3,
+                       1, vector<float>{2., 2., 2., 2., 2., 2., 2., 2., 2.,
+                                        2., 2., 2., 2., 2., 2., 2., 2., 2.,
+                                        2., 2., 2., 2., 2., 2., 2., 2., 2.,
+                                        2., 2., 2., 2., 2., 2., 2., 2., 2.},
+                       Shape{3});
+
 } // python output

 } // namespace infini