Add layer normalization (#181)

* add layernorm kernel

* success: add layernorm kernel and test

* fix: remove unusable comments

* fix: modify code as the reviewer suggested

* debug: modified .cu and test

* optional bias support

* overload function

* fix bug after merging; remove time constraint in conv test

---------

Co-authored-by: kilinchange <kilinchange@163.com>
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
xgqdut2016 2023-11-24 15:15:14 +08:00 committed by GitHub
parent 6ece3f4a77
commit a7293c12ba
13 changed files with 782 additions and 4 deletions
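
For reference, the kernels added below implement the LayerNormalization definition from the ONNX spec: along the normalization axis of length $d$ ("dimsize" in the code),

$$\mu = \frac{1}{d}\sum_{i=1}^{d} x_i, \qquad \sigma^2 = \frac{1}{d}\sum_{i=1}^{d}\left(x_i - \mu\right)^2, \qquad y_i = \mathrm{scale}_i \,\frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \mathrm{bias}_i,$$

where scale and the optional bias broadcast unidirectionally to the input.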


@@ -30,6 +30,8 @@ class GraphHandlerObj {
Tensor batchNormalization(Tensor input, Tensor output, Tensor mean,
Tensor var, Tensor scale, Tensor bias,
float momentum, float eps, bool training);
Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
Tensor bias, float eps, int axis, int stash_type);
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw, int ceilMode);


@@ -0,0 +1,11 @@
#pragma once
#include "operators/unary.h"
namespace infini {
void LaynormKernel(const float *input, const float *scale, const float eps,
int size, int scaleSize, const int dimsize, const int stride,
float *output, const float *bias, int biasSize);
void LaynormKernel(const float *input, const float *scale, const float eps,
int size, int scaleSize, const int dimsize, const int stride,
float *output);
}; // namespace infini


@@ -0,0 +1,30 @@
#pragma once
#include "core/operator.h"
namespace infini {
class LayerNormObj : public OperatorObj {
float eps;
int axis, stash_type;
public:
LayerNormObj(GraphObj *graph, Tensor input, Tensor scale, Tensor output,
Tensor bias = nullptr, float eps = 1e-5, int axis = -1,
int stash_type = 1);
OP_CLONE(LayerNormObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override;
Tensor getBias() const { return inputs.size() > 2 ? inputs[2] : nullptr; }
int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return outputs.size(); }
float getEps() const { return eps; }
int getAxis() const { return axis; }
int getStashType() const { return stash_type; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
vector<DataType> inferDataType(const TensorVec &inputs) const override;
};
} // namespace infini


@@ -10,6 +10,8 @@ namespace infini {
Shape infer_broadcast(const Shape &A, const Shape &B);
// Resolve the real axis based on rank and current axis
int get_real_axis(const int &axis, const int &rank);
// check if tensor B is unidirectional broadcastable to tensor A
bool is_unidirectional_broadcasting(const Shape &A, const Shape &B);
} // namespace infini
#endif


@@ -238,6 +238,25 @@ class OnnxStub:
eps,
training != 0,
)
elif node.op_type == "LayerNormalization":
(input, scale) = (tensors[node.input[i]] for i in [0, 1])
bias = None if len(node.input) < 3 else tensors[node.input[2]]
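                # per the ONNX spec, LayerNormalization takes inputs X and
                # Scale plus an optional B (bias); output 0 is Y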
output = tensors.get(node.output[0])
attributes = _parse_attribute(
node, {"axis": -1, "epsilon": 1e-05, "stash_type": 1}
)
(axis, eps, stash_type) = (
attributes[name] for name in ["axis", "epsilon", "stash_type"]
)
tensors[node.output[0]] = self.handler.layerNormalization(
input,
scale,
output,
bias,
eps,
axis,
stash_type,
)
elif node.op_type == "MaxPool":
attributes = _parse_attribute(
node,


@@ -9,6 +9,7 @@
#include "operators/element_wise.h"
#include "operators/expand.h"
#include "operators/gather.h"
#include "operators/layer_norm.h"
#include "operators/matmul.h"
#include "operators/pad.h"
#include "operators/pooling.h"
@@ -96,6 +97,23 @@ Tensor GraphHandlerObj::batchNormalization(Tensor input, Tensor output,
}
}
Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
Tensor output, Tensor bias,
float eps, int axis,
int stash_type) {
if (output) {
g->addOpWithOutputs<LayerNormObj>(std::move(input), std::move(scale),
output, std::move(bias), eps, axis,
stash_type);
return output;
} else {
return g
->addOp<LayerNormObj>(std::move(input), std::move(scale), output,
std::move(bias), eps, axis, stash_type)
->getOutput();
}
}
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh, int sw,
int ceilMode) {


@@ -466,6 +466,7 @@ void init_graph_builder(py::module &m) {
.def("convTransposed2d", &Handler::convTransposed2d, policy::move)
.def("matmul", &Handler::matmul, policy::move)
.def("batchNormalization", &Handler::batchNormalization, policy::move)
.def("layerNormalization", &Handler::layerNormalization, policy::move)
.def("maxPool", &Handler::maxPool, policy::move)
.def("avgPool", &Handler::avgPool, policy::move)
.def("add", &Handler::add, policy::move)


@@ -0,0 +1,45 @@
#include "operators/layer_norm.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_layernorm.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class LayerNormCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<LayerNormObj>(_op);
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const scaleData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
const auto &opOutputShape = op->getOutput()->getDims();
float eps = op->getEps();
const int axis = op->getAxis();
const int stride = op->getInputs(0)->getStride().at(axis);
auto dims = op->getInputs(0)->getDims();
int dimsize = dims[op->getAxis()];
int size = op->getOutput(0)->size();
int scaleSize = op->getInputs(1)->size();
if (op->numInputs() == 3) {
void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
int biasSize = op->getInputs(2)->size();
LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
scaleSize, dimsize, stride, (float *)outputData,
(float *)biasData, biasSize);
} else {
LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
scaleSize, dimsize, stride, (float *)outputData);
}
}
};
REGISTER_KERNEL(Device::CUDA, OpType::LayerNormalization, DataType::Float32,
LayerNormCuda, "LayerNorm_CUDA_Float32");
}; // namespace infini


@@ -0,0 +1,421 @@
#include "cuda/cuda_common.h"
#include <cub/cub.cuh>
template <int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__
void blockLaynormKernel(const float *input, const float *scale,
const int dimsize, const int stride, float *output,
const float eps, int scaleSize, const float *bias,
int biasSize) {
    // scale and bias are either scalars or have length dimsize
int tmp = blockIdx.x % stride;
int tid = (blockIdx.x - tmp) * dimsize + tmp;
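    // Each block normalizes one vector along the reduced axis: its dimsize
    // elements start at tid and are spaced `stride` apart, so element k of
    // the vector lives at input[tid + k * stride].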
float muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
}
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float mu;
float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
    if (threadIdx.x == 0) {
        // only thread 0 holds the valid block-reduce result; it publishes
        // the mean to shared memory for the whole block
        mu = muBlock / dimsize;
    }
__syncthreads();
float sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
sigma2Partial +=
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
}
__shared__ float sigma2;
float sigma2Block =
BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
    if (threadIdx.x == 0) {
        // thread 0 publishes the variance to shared memory
        sigma2 = sigma2Block / dimsize;
    }
__syncthreads();
if (biasSize == dimsize) {
if (scaleSize == dimsize) {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) /
sqrt(sigma2 + eps) +
bias[threadIdx.x + ph * BLOCK_DIM];
}
} else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) /
sqrt(sigma2 + eps) +
bias[threadIdx.x + ph * BLOCK_DIM];
}
}
} else {
if (scaleSize == dimsize) {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) /
sqrt(sigma2 + eps) +
bias[0];
}
} else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) /
sqrt(sigma2 + eps) +
bias[0];
}
}
}
}
//-----------------
template <int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__
void blockLaynormKernel(const float *input, const float *scale,
const int dimsize, const int stride, float *output,
const float eps, int scaleSize) {
    // scale is either a scalar or has length dimsize; this overload has no bias
int tmp = blockIdx.x % stride;
int tid = (blockIdx.x - tmp) * dimsize + tmp;
float muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
}
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float mu;
float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
    if (threadIdx.x == 0) {
        // only thread 0 holds the valid block-reduce result; it publishes
        // the mean to shared memory for the whole block
        mu = muBlock / dimsize;
    }
__syncthreads();
float sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
sigma2Partial +=
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
}
__shared__ float sigma2;
float sigma2Block =
BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
    if (threadIdx.x == 0) {
        // thread 0 publishes the variance to shared memory
        sigma2 = sigma2Block / dimsize;
    }
__syncthreads();
if (scaleSize == dimsize) {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) /
sqrt(sigma2 + eps);
}
} else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) /
sqrt(sigma2 + eps);
}
}
}
//-----------------
template <typename T> struct SumOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return a + b;
}
};
template <template <typename> class ReductionOp, typename T,
int thread_group_width>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void warpLaynormKernel(const float *input, const float *scale,
const int dimsize, const int stride,
float *output, const float eps, int scaleSize,
int otherSize, const float *bias,
int biasSize) {
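    // Each (blockIdx.x, threadIdx.y) pair owns one normalized vector;
    // BLOCK_DIM_x lanes of the same warp cooperate on its reductions via
    // the shuffle-based WarpAllReduce above.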
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
if (otherIdx < otherSize) {
__shared__ float muTotal[BLOCK_DIM_y];
__shared__ float sigma2Total[BLOCK_DIM_y];
float muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
}
muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial);
if (threadIdx.x == 0)
muTotal[threadIdx.y] = muPartial / dimsize;
//--------------------------------------------
float sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
sigma2Partial +=
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) *
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]);
}
sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial);
if (threadIdx.x == 0)
sigma2Total[threadIdx.y] = sigma2Partial / dimsize;
//--------------------------------------------
if (biasSize == dimsize) {
if (scaleSize == dimsize) {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize;
ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM_x] *
(input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps) +
bias[threadIdx.x + ph * BLOCK_DIM_x];
}
} else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize;
ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[0] *
(input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps) +
bias[threadIdx.x + ph * BLOCK_DIM_x];
}
}
} else {
if (scaleSize == dimsize) {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize;
ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM_x] *
(input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps) +
bias[0];
}
} else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize;
ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[0] *
(input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps) +
bias[0];
}
}
}
}
}
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void warpLaynormKernel(const float *input, const float *scale,
const int dimsize, const int stride,
float *output, const float eps, int scaleSize,
int otherSize) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
if (otherIdx < otherSize) {
__shared__ float muTotal[BLOCK_DIM_y];
__shared__ float sigma2Total[BLOCK_DIM_y];
float muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
}
muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial);
if (threadIdx.x == 0)
muTotal[threadIdx.y] = muPartial / dimsize;
//--------------------------------------------
float sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
sigma2Partial +=
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) *
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]);
}
sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial);
if (threadIdx.x == 0)
sigma2Total[threadIdx.y] = sigma2Partial / dimsize;
//--------------------------------------------
if (scaleSize == dimsize) {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM_x] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps);
}
} else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps);
}
}
}
}
namespace infini {
void LaynormKernel(const float *input, const float *scale, const float eps,
int size, int scaleSize, const int dimsize, const int stride,
float *output, const float *bias, int biasSize) {
int num_block = size / dimsize;
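    // Dispatch heuristic: very long vectors (> 1024 elements) use one full
    // thread block per vector with a cub block reduce; shorter vectors are
    // packed several per block and reduced by warp-level thread groups
    // whose width shrinks with dimsize (32/16/8/4 lanes).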
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockLaynormKernel<1024>
<<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
eps, scaleSize, bias, biasSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
}
}
void LaynormKernel(const float *input, const float *scale, const float eps,
int size, int scaleSize, const int dimsize, const int stride,
float *output) {
int num_block = size / dimsize;
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockLaynormKernel<1024><<<num_block, BLOCK_DIM>>>(
input, scale, dimsize, stride, output, eps, scaleSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
}
}
} // namespace infini


@@ -0,0 +1,64 @@
#include "operators/layer_norm.h"
#include "utils/operator_utils.h"
namespace infini {
LayerNormObj::LayerNormObj(GraphObj *graph, Tensor input, Tensor scale,
                           Tensor output, Tensor bias, float eps, int axis_,
                           int stash_type)
: OperatorObj(OpType::LayerNormalization,
bias ? TensorVec{input, scale, bias}
: TensorVec{input, scale},
{output}),
eps(eps), stash_type(stash_type) {
const auto size = input->getRank();
axis = get_real_axis(axis_, size);
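    // (get_real_axis resolves a possibly negative axis, e.g. -1, to its
    // non-negative index in [0, rank))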
IT_ASSERT(
is_unidirectional_broadcasting(input->getDims(), scale->getDims()));
if (bias) {
IT_ASSERT(
is_unidirectional_broadcasting(input->getDims(), bias->getDims()));
}
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> LayerNormObj::inferShape(const TensorVec &inputs) {
return {{inputs[0]->getDims()}};
}
vector<DataType> LayerNormObj::inferDataType(const TensorVec &inputs) const {
IT_ASSERT(inputs.size() == 2 || inputs.size() == 3);
IT_ASSERT(inputs[1]->getDType() == DataType::Float32);
if (inputs.size() == 3) {
IT_ASSERT(inputs[2]->getDType() == DataType::Float32);
}
return {inputs[0]->getDType()};
}
std::string LayerNormObj::toString() const {
std::ostringstream os;
os << "layerNormalization[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "axis=" << axis << ",";
os << "eps=" << eps << ",";
os << "stash_type=" << stash_type << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "scale=" << inputs[1]->getGuid() << ",";
// os << "bias=" << inputs[2]->getGuid() << ",";
os << "output=";
for (auto output : outputs)
os << output->getGuid() << ",";
return os.str();
}
vector<int> LayerNormObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> LayerNormObj::getOpAttrVector() const {
return {type.underlying(), axis, stash_type};
}
} // namespace infini


@@ -41,4 +41,27 @@ int get_real_axis(const int &axis, const int &rank) {
}
return newAxis;
}
bool is_unidirectional_broadcasting(const Shape &A, const Shape &B) {
// check if tensor B is unidirectional broadcastable to tensor A
auto B_ = B;
int rankA = A.size();
int rankB = B.size();
if (rankA < rankB) {
return false;
}
if (rankA > rankB) {
for (auto i = 0; i < rankA - rankB; ++i) {
B_.insert(B_.begin(), 1);
}
}
for (auto i = 0; i < rankA; ++i) {
if (A[i] == B_[i] || B_[i] == 1) {
continue;
} else {
return false;
}
}
return true;
}
} // namespace infini
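
A minimal standalone sketch of the semantics (hypothetical, not part of the commit, assuming Shape is a vector of ints as it is used elsewhere in this diff): B is unidirectionally broadcastable to A when B's rank does not exceed A's and, right-aligned, every dimension of B either matches A's or equals 1.

#include "utils/operator_utils.h"
#include <cassert>
using namespace infini;

int main() {
    assert(is_unidirectional_broadcasting({2, 3, 4, 5}, {5}));     // pads to {1, 1, 1, 5}
    assert(is_unidirectional_broadcasting({2, 3, 4, 5}, {4, 1}));  // dims match or are 1
    assert(!is_unidirectional_broadcasting({2, 3, 4, 5}, {3, 5})); // 4 vs 3 mismatch
    assert(!is_unidirectional_broadcasting({5}, {2, 5}));          // rank(B) > rank(A)
}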


@@ -0,0 +1,146 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/layer_norm.h"
#include "test.h"
namespace infini {
void test_layernorm(
const Shape &inputShape, const vector<float> &inputData,
const Shape &scaleShape, const vector<float> &scaleData, float eps,
int axis, int stash_type, const vector<float> &ExpectData,
const std::optional<Shape> &bShape = std::nullopt,
const std::optional<std::vector<float>> &biasData = std::nullopt) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
if (bShape.has_value() && biasData.has_value()) {
Shape biasShape = *bShape;
auto bias = gCpu->addTensor(biasShape, DataType::Float32);
auto input = gCpu->addTensor(inputShape, DataType::Float32);
auto scale = gCpu->addTensor(scaleShape, DataType::Float32);
gCpu->dataMalloc();
        bias->copyin(*biasData);
        input->copyin(inputData);
        scale->copyin(scaleData);
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto biasGpu = gCuda->cloneTensor(bias);
auto inputGpu = gCuda->cloneTensor(input);
auto scaleGpu = gCuda->cloneTensor(scale);
auto op =
gCuda->addOp<LayerNormObj>(inputGpu, scaleGpu, nullptr, biasGpu,
eps, axis, stash_type); // LayernormObj
gCuda->dataMalloc();
biasGpu->copyin(*biasData);
inputGpu->copyin(inputData);
scaleGpu->copyin(scaleData);
cudaRuntime->run(gCuda);
        auto oCpu =
            gCpu->cloneTensor(op->getOutput()); // copy output back to CPU
        oCpu->printData();
        EXPECT_TRUE(oCpu->equalData(ExpectData));
} else {
auto input = gCpu->addTensor(inputShape, DataType::Float32);
auto scale = gCpu->addTensor(scaleShape, DataType::Float32);
gCpu->dataMalloc();
input->copyin(inputData);
        scale->copyin(scaleData);
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto inputGpu = gCuda->cloneTensor(input);
auto scaleGpu = gCuda->cloneTensor(scale);
auto op =
gCuda->addOp<LayerNormObj>(inputGpu, scaleGpu, nullptr, nullptr,
eps, axis, stash_type); // LayernormObj
gCuda->dataMalloc();
inputGpu->copyin(inputData);
scaleGpu->copyin(scaleData);
cudaRuntime->run(gCuda);
        auto oCpu =
            gCpu->cloneTensor(op->getOutput()); // copy output back to CPU
        oCpu->printData();
        EXPECT_TRUE(oCpu->equalData(ExpectData));
}
}
TEST(CUDA_Layernorm, run) {
test_layernorm(
Shape{2, 3, 2, 3},
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
9., 10., 11., 12., 13., 14., 15., 16., 17.,
18., 19., 20., 21., 22., 23., 24., 25., 26.,
27., 28., 29., 30., 31., 32., 33., 34., 35.},
Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1e-5, 3, 1,
vector<float>{
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678},
Shape{3}, vector<float>{0, 0, 0});
test_layernorm(
Shape{2, 3, 2, 3},
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
9., 10., 11., 12., 13., 14., 15., 16., 17.,
18., 19., 20., 21., 22., 23., 24., 25., 26.,
27., 28., 29., 30., 31., 32., 33., 34., 35.},
Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1e-5, 3, 1,
vector<float>{
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679},
Shape{3}, vector<float>{0.3, 0.2, 0.5});
test_layernorm(
Shape{2, 3, 2, 3},
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
9., 10., 11., 12., 13., 14., 15., 16., 17.,
18., 19., 20., 21., 22., 23., 24., 25., 26.,
27., 28., 29., 30., 31., 32., 33., 34., 35.},
Shape{1}, vector<float>{0.3}, 1e-5, 3, 1,
vector<float>{
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207},
Shape{3}, vector<float>{0.3, 0.2, 0.5});
test_layernorm(
Shape{2, 3, 2, 3},
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
9., 10., 11., 12., 13., 14., 15., 16., 17.,
18., 19., 20., 21., 22., 23., 24., 25., 26.,
27., 28., 29., 30., 31., 32., 33., 34., 35.},
Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1e-5, 3, 1,
vector<float>{-0.3674207, 0.0000000, 0.6123678, -0.3674207,
0.0000000, 0.6123678, -0.3674207, 0.0000000,
0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207,
0.0000000, 0.6123678, -0.3674207, 0.0000000,
0.6123678, -0.3674207, 0.0000000, 0.6123678,
-0.3674207, 0.0000000, 0.6123678, -0.3674207,
0.0000000, 0.6123678, -0.3674207, 0.0000000,
0.6123678, -0.3674207, 0.0000000, 0.6123678});
} // expected outputs computed with a Python reference
} // namespace infini


@@ -53,10 +53,6 @@ TEST(Conv, NaiveCPU) {
i0->setData(IncrementalGenerator());
w0->setData(IncrementalGenerator());
runtime->run(g, true, true);
double perfTime = runtime->getPerfTime(g);
// The example Conv takes 0.015ms with one core
EXPECT_GT(perfTime, 0);
EXPECT_LT(perfTime, 5); // FIXME: why may it cost 4.8 ms sometimes
// check answer
auto ans =
make_ref<TensorObj>(Shape{1, 2, 2, 2}, DataType::UInt32, runtime);