Add layer normalization (#181)

* add layernorm kernel

* success: add layernorm kernel and test

* fix: remove unusable comments

* fix: modify code as reviewer suggested

* debug: modified .cu and test

* optional bias support

* overload kernel function for optional bias

* fix bug after merging; remove time constraint in conv test

---------

Co-authored-by: kilinchange <kilinchange@163.com>
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
xgqdut2016 2023-11-24 15:15:14 +08:00 committed by GitHub
parent 6ece3f4a77
commit a7293c12ba
13 changed files with 782 additions and 4 deletions

View File

@@ -30,6 +30,8 @@ class GraphHandlerObj {
Tensor batchNormalization(Tensor input, Tensor output, Tensor mean,
Tensor var, Tensor scale, Tensor bias,
float momentum, float eps, bool training);
Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
Tensor bias, float eps, int axis, int stash_type);
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw, int ceilMode);

View File

@@ -0,0 +1,11 @@
#pragma once
#include "operators/unary.h"
namespace infini {
void LaynormKernel(const float *input, const float *scale, const float eps,
int size, int scaleSize, const int dimsize, const int stride,
float *output, const float *bias, int biasSize);
void LaynormKernel(const float *input, const float *scale, const float eps,
int size, int scaleSize, const int dimsize, const int stride,
float *output);
}; // namespace infini
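Both overloads compute the standard LayerNormalization over one row of length dimsize (d below); gamma is the scale and beta the optional bias, each broadcast from length d or 1. For reference:

\mu = \frac{1}{d}\sum_{i=1}^{d} x_i,\qquad
\sigma^2 = \frac{1}{d}\sum_{i=1}^{d}(x_i-\mu)^2,\qquad
y_i = \gamma_i\,\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}} + \beta_i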

View File

@@ -0,0 +1,30 @@
#pragma once
#include "core/operator.h"
namespace infini {
class LayerNormObj : public OperatorObj {
float eps;
int axis, stash_type;
public:
LayerNormObj(GraphObj *graph, Tensor input, Tensor scale, Tensor output,
Tensor bias = nullptr, float eps = 1e-5, int axis = -1,
int stash_type = 1);
OP_CLONE(LayerNormObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override;
Tensor getBias() const { return inputs.size() > 2 ? inputs[2] : nullptr; }
int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return outputs.size(); }
float getEps() const { return eps; }
int getAxis() const { return axis; }
int getStashType() const { return stash_type; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
vector<DataType> inferDataType(const TensorVec &inputs) const override;
};
} // namespace infini

View File

@@ -10,6 +10,8 @@ namespace infini {
Shape infer_broadcast(const Shape &A, const Shape &B);
// Launch the real axis based on rank and current axis
int get_real_axis(const int &axis, const int &rank);
// check whether tensor B is unidirectionally broadcastable to tensor A
bool is_unidirectional_broadcasting(const Shape &A, const Shape &B);
} // namespace infini
#endif

View File

@@ -238,6 +238,25 @@ class OnnxStub:
eps,
training != 0,
)
elif node.op_type == "LayerNormalization":
(input, scale) = (tensors[node.input[i]] for i in [0, 1])
bias = None if len(node.input) < 3 else tensors[node.input[2]]
output = tensors.get(node.output[0])
attributes = _parse_attribute(
node, {"axis": -1, "epsilon": 1e-05, "stash_type": 1}
)
(axis, eps, stash_type) = (
attributes[name] for name in ["axis", "epsilon", "stash_type"]
)
tensors[node.output[0]] = self.handler.layerNormalization(
input,
scale,
output,
bias,
eps,
axis,
stash_type,
)
elif node.op_type == "MaxPool":
attributes = _parse_attribute(
node,
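To exercise this import path end to end, a one-node ONNX graph is enough. The onnx.helper calls below are standard ONNX API; the hand-off at the end is only sketched from this repository's Python frontend and the import path is an assumption:

from onnx import TensorProto, helper

node = helper.make_node(
    "LayerNormalization", ["X", "scale", "bias"], ["Y"], axis=-1, epsilon=1e-5
)
graph = helper.make_graph(
    [node], "layernorm_smoke_test",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 2, 3])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 2, 3])],
    initializer=[
        helper.make_tensor("scale", TensorProto.FLOAT, [3], [0.3, 0.2, 0.5]),
        helper.make_tensor("bias", TensorProto.FLOAT, [3], [0.0, 0.0, 0.0]),
    ],
)
# LayerNormalization is an opset-17 op
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])

# hypothetical hand-off to this frontend (names assumed from the repository layout):
# from pyinfinitensor.onnx import OnnxStub, backend
# stub = OnnxStub(model, backend.cuda_runtime())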

View File

@@ -9,6 +9,7 @@
#include "operators/element_wise.h"
#include "operators/expand.h"
#include "operators/gather.h"
#include "operators/layer_norm.h"
#include "operators/matmul.h" #include "operators/matmul.h"
#include "operators/pad.h" #include "operators/pad.h"
#include "operators/pooling.h" #include "operators/pooling.h"
@ -96,6 +97,23 @@ Tensor GraphHandlerObj::batchNormalization(Tensor input, Tensor output,
} }
} }
Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
Tensor output, Tensor bias,
float eps, int axis,
int stash_type) {
if (output) {
g->addOpWithOutputs<LayerNormObj>(std::move(input), std::move(scale),
output, std::move(bias), eps, axis,
stash_type);
return output;
} else {
return g
->addOp<LayerNormObj>(std::move(input), std::move(scale), output,
std::move(bias), eps, axis, stash_type)
->getOutput();
}
}
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh, int sw,
int ceilMode) {

View File

@@ -466,6 +466,7 @@ void init_graph_builder(py::module &m) {
.def("convTransposed2d", &Handler::convTransposed2d, policy::move)
.def("matmul", &Handler::matmul, policy::move)
.def("batchNormalization", &Handler::batchNormalization, policy::move)
.def("layerNormalization", &Handler::layerNormalization, policy::move)
.def("maxPool", &Handler::maxPool, policy::move)
.def("avgPool", &Handler::avgPool, policy::move)
.def("add", &Handler::add, policy::move)

View File

@@ -0,0 +1,45 @@
#include "operators/layer_norm.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_layernorm.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class LayerNormCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<LayerNormObj>(_op);
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const scaleData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
const auto &opOutputShape = op->getOutput()->getDims();
float eps = op->getEps();
const int axis = op->getAxis();
const int stride = op->getInputs(0)->getStride().at(axis);
auto dims = op->getInputs(0)->getDims();
int dimsize = dims[op->getAxis()];
int size = op->getOutput(0)->size();
int scaleSize = op->getInputs(1)->size();
if (op->numInputs() == 3) {
void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
int biasSize = op->getInputs(2)->size();
// printf("kernel bias:true:%d\n", 1);
LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
scaleSize, dimsize, stride, (float *)outputData,
(float *)biasData, biasSize);
} else {
// printf("kernel bias:false:%d\n", 0);
LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
scaleSize, dimsize, stride, (float *)outputData);
}
}
};
REGISTER_KERNEL(Device::CUDA, OpType::LayerNormalization, DataType::Float32,
LayerNormCuda, "LayerNorm_CUDA_Float32");
}; // namespace infini
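As a sanity check on how these arguments are derived for the CUDA test below (contiguous input of shape {2, 3, 2, 3}, normalization over axis 3), a small NumPy sketch:

import numpy as np

shape = (2, 3, 2, 3)
axis = 3
dimsize = shape[axis]                      # 3, length of the normalized axis
size = int(np.prod(shape))                 # 36, total number of elements
stride = int(np.prod(shape[axis + 1:]))    # 1, element stride of the normalized axis
num_block = size // dimsize                # 12, number of rows to normalize
print(dimsize, size, stride, num_block)    # 3 36 1 12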

View File

@@ -0,0 +1,421 @@
#include "cuda/cuda_common.h"
#include <cub/cub.cuh>
template <int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__
void blockLaynormKernel(const float *input, const float *scale,
const int dimsize, const int stride, float *output,
const float eps, int scaleSize, const float *bias,
int biasSize) {
// scale and bias each have length dimsize or 1
int tmp = blockIdx.x % stride;
int tid = (blockIdx.x - tmp) * dimsize + tmp;
float muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
}
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float mu;
float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
if (threadIdx.x == 0) { // only thread 0 writes the block mean to shared memory
mu = muBlock / dimsize;
}
__syncthreads();
float sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
sigma2Partial +=
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
}
// reuse temp_storage for the variance reduction; the __syncthreads() above separates the two uses
__shared__ float sigma2;
float sigma2Block =
BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
if (threadIdx.x == 0) { // only thread 0 writes the block variance to shared memory
sigma2 = sigma2Block / dimsize;
}
__syncthreads();
if (biasSize == dimsize) {
if (scaleSize == dimsize) {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) /
sqrt(sigma2 + eps) +
bias[threadIdx.x + ph * BLOCK_DIM];
}
} else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) /
sqrt(sigma2 + eps) +
bias[threadIdx.x + ph * BLOCK_DIM];
}
}
} else {
if (scaleSize == dimsize) {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) /
sqrt(sigma2 + eps) +
bias[0];
}
} else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) /
sqrt(sigma2 + eps) +
bias[0];
}
}
}
}
//-----------------
template <int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__
void blockLaynormKernel(const float *input, const float *scale,
const int dimsize, const int stride, float *output,
const float eps, int scaleSize) {
// len(scale) = dimsize or 1; this overload has no bias
int tmp = blockIdx.x % stride;
int tid = (blockIdx.x - tmp) * dimsize + tmp;
float muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
}
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float mu;
float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
if (threadIdx.x == 0) { // only thread 0 writes the block mean to shared memory
mu = muBlock / dimsize;
}
__syncthreads();
float sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
sigma2Partial +=
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
}
// reuse temp_storage for the variance reduction; the __syncthreads() above separates the two uses
__shared__ float sigma2;
float sigma2Block =
BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
if (threadIdx.x == 0) { // only thread 0 writes the block variance to shared memory
sigma2 = sigma2Block / dimsize;
}
__syncthreads();
if (scaleSize == dimsize) {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) /
sqrt(sigma2 + eps);
}
} else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) /
sqrt(sigma2 + eps);
}
}
}
//-----------------
template <typename T> struct SumOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return a + b;
}
};
template <template <typename> class ReductionOp, typename T,
int thread_group_width>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void warpLaynormKernel(const float *input, const float *scale,
const int dimsize, const int stride,
float *output, const float eps, int scaleSize,
int otherSize, const float *bias,
int biasSize) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
if (otherIdx < otherSize) {
__shared__ float muTotal[BLOCK_DIM_y];
__shared__ float sigma2Total[BLOCK_DIM_y];
float muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
}
muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial);
if (threadIdx.x == 0)
muTotal[threadIdx.y] = muPartial / dimsize;
//--------------------------------------------
float sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
sigma2Partial +=
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) *
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]);
}
sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial);
if (threadIdx.x == 0)
sigma2Total[threadIdx.y] = sigma2Partial / dimsize;
//--------------------------------------------
if (biasSize == dimsize) {
if (scaleSize == dimsize) {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize;
ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM_x] *
(input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps) +
bias[threadIdx.x + ph * BLOCK_DIM_x];
}
} else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize;
ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[0] *
(input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps) +
bias[threadIdx.x + ph * BLOCK_DIM_x];
}
}
} else {
if (scaleSize == dimsize) {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize;
ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM_x] *
(input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps) +
bias[0];
}
} else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize;
ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[0] *
(input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps) +
bias[0];
}
}
}
}
}
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void warpLaynormKernel(const float *input, const float *scale,
const int dimsize, const int stride,
float *output, const float eps, int scaleSize,
int otherSize) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
if (otherIdx < otherSize) {
__shared__ float muTotal[BLOCK_DIM_y];
__shared__ float sigma2Total[BLOCK_DIM_y];
float muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
}
muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial);
if (threadIdx.x == 0)
muTotal[threadIdx.y] = muPartial / dimsize;
//--------------------------------------------
float sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
sigma2Partial +=
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) *
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]);
}
sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial);
if (threadIdx.x == 0)
sigma2Total[threadIdx.y] = sigma2Partial / dimsize;
//--------------------------------------------
if (scaleSize == dimsize) {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM_x] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps);
}
} else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps);
}
}
}
}
namespace infini {
void LaynormKernel(const float *input, const float *scale, const float eps,
int size, int scaleSize, const int dimsize, const int stride,
float *output, const float *bias, int biasSize) {
int num_block = size / dimsize;
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockLaynormKernel<1024>
<<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
eps, scaleSize, bias, biasSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
}
}
void LaynormKernel(const float *input, const float *scale, const float eps,
int size, int scaleSize, const int dimsize, const int stride,
float *output) {
int num_block = size / dimsize;
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockLaynormKernel<1024><<<num_block, BLOCK_DIM>>>(
input, scale, dimsize, stride, output, eps, scaleSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
}
}
} // namespace infini
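The launch configuration depends only on dimsize; a short sketch of the dispatch logic above (illustrative only, not part of the library):

def pick_launch_config(dimsize):
    # one block of 1024 threads per row when the row is long
    if dimsize > 1024:
        return ("blockLaynormKernel", 1024)
    # otherwise a warp-style kernel: BLOCK_DIM_x covers the row,
    # BLOCK_DIM_y packs several rows into one thread block (x * y = 1024)
    if dimsize > 31:
        return ("warpLaynormKernel", (32, 32))
    if dimsize > 15:
        return ("warpLaynormKernel", (16, 64))
    if dimsize > 7:
        return ("warpLaynormKernel", (8, 128))
    return ("warpLaynormKernel", (4, 256))

print(pick_launch_config(3))   # ('warpLaynormKernel', (4, 256)) -- the test below has dimsize = 3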

View File

@@ -0,0 +1,64 @@
#include "operators/layer_norm.h"
#include "utils/operator_utils.h"
namespace infini {
LayerNormObj::LayerNormObj(GraphObj *graph, Tensor input, Tensor scale,
Tensor output, [[maybe_unused]] Tensor bias,
float eps, int axis_, int stash_type)
: OperatorObj(OpType::LayerNormalization,
bias ? TensorVec{input, scale, bias}
: TensorVec{input, scale},
{output}),
eps(eps), stash_type(stash_type) {
const auto size = input->getRank();
axis = get_real_axis(axis_, size);
IT_ASSERT(
is_unidirectional_broadcasting(input->getDims(), scale->getDims()));
if (bias) {
IT_ASSERT(
is_unidirectional_broadcasting(input->getDims(), bias->getDims()));
}
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> LayerNormObj::inferShape(const TensorVec &inputs) {
return {{inputs[0]->getDims()}};
}
vector<DataType> LayerNormObj::inferDataType(const TensorVec &inputs) const {
IT_ASSERT(inputs.size() == 2 || inputs.size() == 3);
IT_ASSERT(inputs[1]->getDType() == DataType::Float32);
if (inputs.size() == 3) {
IT_ASSERT(inputs[2]->getDType() == DataType::Float32);
}
return {inputs[0]->getDType()};
}
std::string LayerNormObj::toString() const {
std::ostringstream os;
os << "layerNormalization[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "axis=" << axis << ",";
os << "eps=" << eps << ",";
os << "stash_type=" << stash_type << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "scale=" << inputs[1]->getGuid() << ",";
// os << "bias=" << inputs[2]->getGuid() << ",";
os << "output=";
for (auto output : outputs)
os << output->getGuid() << ",";
return os.str();
}
vector<int> LayerNormObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> LayerNormObj::getOpAttrVector() const {
return {type.underlying(), axis, stash_type};
}
} // namespace infini

View File

@@ -41,4 +41,27 @@ int get_real_axis(const int &axis, const int &rank) {
}
return newAxis;
}
bool is_unidirectional_broadcasting(const Shape &A, const Shape &B) {
// check whether tensor B is unidirectionally broadcastable to tensor A
auto B_ = B;
int rankA = A.size();
int rankB = B.size();
if (rankA < rankB) {
return false;
}
if (rankA > rankB) {
for (auto i = 0; i < rankA - rankB; ++i) {
B_.insert(B_.begin(), 1);
}
}
for (auto i = 0; i < rankA; ++i) {
if (A[i] == B_[i] || B_[i] == 1) {
continue;
} else {
return false;
}
}
return true;
}
} // namespace infini
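The same check restated as a short Python sketch, with the shapes used in the layer-norm tests (illustrative only):

def is_unidirectional_broadcasting(a, b):
    # B is right-aligned against A; missing leading dims count as 1,
    # and every dim of B must equal the matching dim of A or be 1
    if len(a) < len(b):
        return False
    b = (1,) * (len(a) - len(b)) + tuple(b)
    return all(x == y or y == 1 for x, y in zip(a, b))

print(is_unidirectional_broadcasting((2, 3, 2, 3), (3,)))   # True: scale of shape {3}
print(is_unidirectional_broadcasting((2, 3, 2, 3), (1,)))   # True: scalar scale
print(is_unidirectional_broadcasting((2, 3, 2, 3), (2,)))   # False: 2 != 3 and != 1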

View File

@@ -0,0 +1,146 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/layer_norm.h"
#include "test.h"
namespace infini {
void test_layernorm(
const Shape &inputShape, const vector<float> &inputData,
const Shape &scaleShape, const vector<float> &scaleData, float eps,
int axis, int stash_type, const vector<float> &ExpectData,
const std::optional<Shape> &bShape = std::nullopt,
const std::optional<std::vector<float>> &biasData = std::nullopt) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
if (bShape.has_value() && biasData.has_value()) {
Shape biasShape = *bShape;
auto bias = gCpu->addTensor(biasShape, DataType::Float32);
auto input = gCpu->addTensor(inputShape, DataType::Float32);
auto scale = gCpu->addTensor(scaleShape, DataType::Float32);
gCpu->dataMalloc();
bias->copyin(*biasData);
// bias->printData();
input->copyin(inputData);
scale->copyin(scaleData);
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto biasGpu = gCuda->cloneTensor(bias);
auto inputGpu = gCuda->cloneTensor(input);
auto scaleGpu = gCuda->cloneTensor(scale);
// gCpu->cloneTensor(biasGpu)->printData();
auto op =
gCuda->addOp<LayerNormObj>(inputGpu, scaleGpu, nullptr, biasGpu,
eps, axis, stash_type); // LayernormObj
gCuda->dataMalloc();
biasGpu->copyin(*biasData);
// gCpu->cloneTensor(biasGpu)->printData();
inputGpu->copyin(inputData);
scaleGpu->copyin(scaleData);
cudaRuntime->run(gCuda);
auto oCpu =
gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
oCpu->printData();
EXPECT_TRUE(oCpu->equalData(ExpectData));
} else {
auto input = gCpu->addTensor(inputShape, DataType::Float32);
auto scale = gCpu->addTensor(scaleShape, DataType::Float32);
gCpu->dataMalloc();
input->copyin(inputData);
scale->copyin(scaleData);
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto inputGpu = gCuda->cloneTensor(input);
auto scaleGpu = gCuda->cloneTensor(scale);
auto op =
gCuda->addOp<LayerNormObj>(inputGpu, scaleGpu, nullptr, nullptr,
eps, axis, stash_type); // LayernormObj
gCuda->dataMalloc();
inputGpu->copyin(inputData);
scaleGpu->copyin(scaleData);
cudaRuntime->run(gCuda);
auto oCpu =
gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
oCpu->printData();
EXPECT_TRUE(oCpu->equalData(ExpectData));
}
}
TEST(CUDA_Layernorm, run) {
test_layernorm(
Shape{2, 3, 2, 3},
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
9., 10., 11., 12., 13., 14., 15., 16., 17.,
18., 19., 20., 21., 22., 23., 24., 25., 26.,
27., 28., 29., 30., 31., 32., 33., 34., 35.},
Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1e-5, 3, 1,
vector<float>{
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678},
Shape{3}, vector<float>{0, 0, 0});
test_layernorm(
Shape{2, 3, 2, 3},
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
9., 10., 11., 12., 13., 14., 15., 16., 17.,
18., 19., 20., 21., 22., 23., 24., 25., 26.,
27., 28., 29., 30., 31., 32., 33., 34., 35.},
Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1e-5, 3, 1,
vector<float>{
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679},
Shape{3}, vector<float>{0.3, 0.2, 0.5});
test_layernorm(
Shape{2, 3, 2, 3},
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
9., 10., 11., 12., 13., 14., 15., 16., 17.,
18., 19., 20., 21., 22., 23., 24., 25., 26.,
27., 28., 29., 30., 31., 32., 33., 34., 35.},
Shape{1}, vector<float>{0.3}, 1e-5, 3, 1,
vector<float>{
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207},
Shape{3}, vector<float>{0.3, 0.2, 0.5});
test_layernorm(
Shape{2, 3, 2, 3},
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
9., 10., 11., 12., 13., 14., 15., 16., 17.,
18., 19., 20., 21., 22., 23., 24., 25., 26.,
27., 28., 29., 30., 31., 32., 33., 34., 35.},
Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1e-5, 3, 1,
vector<float>{-0.3674207, 0.0000000, 0.6123678, -0.3674207,
0.0000000, 0.6123678, -0.3674207, 0.0000000,
0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207,
0.0000000, 0.6123678, -0.3674207, 0.0000000,
0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207,
0.0000000, 0.6123678, -0.3674207, 0.0000000,
0.6123678, -0.3674207, 0.0000000, 0.6123678});
} // expected values generated with a Python/NumPy reference (see the sketch below)
} // namespace infini
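The expected values can be reproduced with a few lines of NumPy; this reference sketch matches the last, bias-free case (add the bias vector to cover the other cases):

import numpy as np

x = np.arange(36, dtype=np.float32).reshape(2, 3, 2, 3)
scale = np.array([0.3, 0.2, 0.5], dtype=np.float32)
eps = 1e-5

mu = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
y = scale * (x - mu) / np.sqrt(var + eps)   # bias-free case
print(y.reshape(-1)[:3])                    # ~[-0.3674207  0.         0.6123678]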

View File

@@ -53,10 +53,6 @@ TEST(Conv, NaiveCPU) {
i0->setData(IncrementalGenerator());
w0->setData(IncrementalGenerator());
runtime->run(g, true, true);
double perfTime = runtime->getPerfTime(g);
// The example Conv takes 0.015ms with one core
EXPECT_GT(perfTime, 0);
EXPECT_LT(perfTime, 5); // FIXME: why may it cost 4.8 ms sometimes
// check answer
auto ans =
make_ref<TensorObj>(Shape{1, 2, 2, 2}, DataType::UInt32, runtime);