forked from jiuyuan/InfiniTensor
Add layer normalization (#181)
* - add layernorm kernel * success:add layernorm kernel and test * fix: remove unusalble comments * fix: modify code as reviewer suggested * debug,modified .cu and test * optional bias support * overloading function * fix bug after merging; remove time constrain in conv test --------- Co-authored-by: kilinchange <kilinchange@163.com> Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
6ece3f4a77
commit
a7293c12ba
|
@ -30,6 +30,8 @@ class GraphHandlerObj {
|
|||
Tensor batchNormalization(Tensor input, Tensor output, Tensor mean,
|
||||
Tensor var, Tensor scale, Tensor bias,
|
||||
float momentum, float eps, bool training);
|
||||
Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
|
||||
Tensor bias, float eps, int axis, int stash_type);
|
||||
|
||||
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
|
||||
int ph, int pw, int sh, int sw, int ceilMode);
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
#pragma once
|
||||
#include "operators/unary.h"
|
||||
|
||||
namespace infini {
|
||||
void LaynormKernel(const float *input, const float *scale, const float eps,
|
||||
int size, int scaleSize, const int dimsize, const int stride,
|
||||
float *output, const float *bias, int biasSize);
|
||||
void LaynormKernel(const float *input, const float *scale, const float eps,
|
||||
int size, int scaleSize, const int dimsize, const int stride,
|
||||
float *output);
|
||||
}; // namespace infini
|
|
@ -0,0 +1,30 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
class LayerNormObj : public OperatorObj {
|
||||
float eps;
|
||||
int axis, stash_type;
|
||||
|
||||
public:
|
||||
LayerNormObj(GraphObj *graph, Tensor input, Tensor scale, Tensor output,
|
||||
Tensor bias = nullptr, float eps = 1e-5, int axis = -1,
|
||||
int stash_type = 1);
|
||||
OP_CLONE(LayerNormObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
|
||||
Tensor getBias() const { return inputs.size() > 2 ? inputs[2] : nullptr; }
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
int numOutputs() const override { return outputs.size(); }
|
||||
float getEps() const { return eps; }
|
||||
int getAxis() const { return axis; }
|
||||
int getStashType() const { return stash_type; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
};
|
||||
} // namespace infini
|
|
@ -10,6 +10,8 @@ namespace infini {
|
|||
Shape infer_broadcast(const Shape &A, const Shape &B);
|
||||
// Launch the real axis based on rank and current axis
|
||||
int get_real_axis(const int &axis, const int &rank);
|
||||
// check if tensor B is unidirectional broadcastable to tensor A
|
||||
bool is_unidirectional_broadcasting(const Shape &A, const Shape &B);
|
||||
} // namespace infini
|
||||
|
||||
#endif
|
||||
|
|
|
@ -238,6 +238,25 @@ class OnnxStub:
|
|||
eps,
|
||||
training != 0,
|
||||
)
|
||||
elif node.op_type == "LayerNormalization":
|
||||
(input, scale) = (tensors[node.input[i]] for i in [0, 1])
|
||||
bias = None if len(node.input) < 3 else tensors[node.input[2]]
|
||||
output = tensors.get(node.output[0])
|
||||
attributes = _parse_attribute(
|
||||
node, {"axis": -1, "epsilon": 1e-05, "stash_type": 1}
|
||||
)
|
||||
(axis, eps, stash_type) = (
|
||||
attributes[name] for name in ["axis", "epsilon", "stash_type"]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.layerNormalization(
|
||||
input,
|
||||
scale,
|
||||
output,
|
||||
bias,
|
||||
eps,
|
||||
axis,
|
||||
stash_type,
|
||||
)
|
||||
elif node.op_type == "MaxPool":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "operators/element_wise.h"
|
||||
#include "operators/expand.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/layer_norm.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/pad.h"
|
||||
#include "operators/pooling.h"
|
||||
|
@ -96,6 +97,23 @@ Tensor GraphHandlerObj::batchNormalization(Tensor input, Tensor output,
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
|
||||
Tensor output, Tensor bias,
|
||||
float eps, int axis,
|
||||
int stash_type) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<LayerNormObj>(std::move(input), std::move(scale),
|
||||
output, std::move(bias), eps, axis,
|
||||
stash_type);
|
||||
return output;
|
||||
} else {
|
||||
return g
|
||||
->addOp<LayerNormObj>(std::move(input), std::move(scale), output,
|
||||
std::move(bias), eps, axis, stash_type)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
|
||||
int dh, int dw, int ph, int pw, int sh, int sw,
|
||||
int ceilMode) {
|
||||
|
|
|
@ -466,6 +466,7 @@ void init_graph_builder(py::module &m) {
|
|||
.def("convTransposed2d", &Handler::convTransposed2d, policy::move)
|
||||
.def("matmul", &Handler::matmul, policy::move)
|
||||
.def("batchNormalization", &Handler::batchNormalization, policy::move)
|
||||
.def("layerNormalization", &Handler::layerNormalization, policy::move)
|
||||
.def("maxPool", &Handler::maxPool, policy::move)
|
||||
.def("avgPool", &Handler::avgPool, policy::move)
|
||||
.def("add", &Handler::add, policy::move)
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
#include "operators/layer_norm.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_layernorm.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class LayerNormCuda : public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<LayerNormObj>(_op);
|
||||
|
||||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const scaleData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
const auto &opOutputShape = op->getOutput()->getDims();
|
||||
|
||||
float eps = op->getEps();
|
||||
const int axis = op->getAxis();
|
||||
const int stride = op->getInputs(0)->getStride().at(axis);
|
||||
|
||||
auto dims = op->getInputs(0)->getDims();
|
||||
int dimsize = dims[op->getAxis()];
|
||||
int size = op->getOutput(0)->size();
|
||||
int scaleSize = op->getInputs(1)->size();
|
||||
if (op->numInputs() == 3) {
|
||||
void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
int biasSize = op->getInputs(2)->size();
|
||||
// printf("kernel bias:true:%d\n", 1);
|
||||
LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
|
||||
scaleSize, dimsize, stride, (float *)outputData,
|
||||
(float *)biasData, biasSize);
|
||||
} else {
|
||||
// printf("kernel bias:false:%d\n", 0);
|
||||
LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
|
||||
scaleSize, dimsize, stride, (float *)outputData);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::LayerNormalization, DataType::Float32,
|
||||
LayerNormCuda, "LayerNorm_CUDA_Float32");
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,421 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
template <int BLOCK_DIM>
|
||||
__launch_bounds__(BLOCK_DIM) __global__
|
||||
void blockLaynormKernel(const float *input, const float *scale,
|
||||
const int dimsize, const int stride, float *output,
|
||||
const float eps, int scaleSize, const float *bias,
|
||||
int biasSize) {
|
||||
// len(scale) = len(bias) = dimsize
|
||||
int tmp = blockIdx.x % stride;
|
||||
int tid = (blockIdx.x - tmp) * dimsize + tmp;
|
||||
float muPartial = 0.0f;
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
|
||||
}
|
||||
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ float mu;
|
||||
float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
|
||||
if (threadIdx.x ==
|
||||
0) { // must set threadIdx.x = 0 write the output to memory
|
||||
mu = muBlock / dimsize;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float sigma2Partial = 0.0f;
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
sigma2Partial +=
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
|
||||
}
|
||||
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
|
||||
|
||||
__shared__ float sigma2;
|
||||
float sigma2Block =
|
||||
BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
|
||||
if (threadIdx.x ==
|
||||
0) { // must set threadIdx.x = 0 write the output to memory
|
||||
sigma2 = sigma2Block / dimsize;
|
||||
}
|
||||
__syncthreads();
|
||||
if (biasSize == dimsize) {
|
||||
if (scaleSize == dimsize) {
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
|
||||
scale[threadIdx.x + ph * BLOCK_DIM] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
|
||||
mu) /
|
||||
sqrt(sigma2 + eps) +
|
||||
bias[threadIdx.x + ph * BLOCK_DIM];
|
||||
}
|
||||
} else {
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
|
||||
scale[0] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
|
||||
mu) /
|
||||
sqrt(sigma2 + eps) +
|
||||
bias[threadIdx.x + ph * BLOCK_DIM];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (scaleSize == dimsize) {
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
|
||||
scale[threadIdx.x + ph * BLOCK_DIM] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
|
||||
mu) /
|
||||
sqrt(sigma2 + eps) +
|
||||
bias[0];
|
||||
}
|
||||
} else {
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
|
||||
scale[0] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
|
||||
mu) /
|
||||
sqrt(sigma2 + eps) +
|
||||
bias[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
//-----------------
|
||||
template <int BLOCK_DIM>
|
||||
__launch_bounds__(BLOCK_DIM) __global__
|
||||
void blockLaynormKernel(const float *input, const float *scale,
|
||||
const int dimsize, const int stride, float *output,
|
||||
const float eps, int scaleSize) {
|
||||
// len(scale) = len(bias) = dimsize
|
||||
int tmp = blockIdx.x % stride;
|
||||
int tid = (blockIdx.x - tmp) * dimsize + tmp;
|
||||
float muPartial = 0.0f;
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
|
||||
}
|
||||
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ float mu;
|
||||
float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
|
||||
if (threadIdx.x ==
|
||||
0) { // must set threadIdx.x = 0 write the output to memory
|
||||
mu = muBlock / dimsize;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float sigma2Partial = 0.0f;
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
sigma2Partial +=
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
|
||||
}
|
||||
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
|
||||
|
||||
__shared__ float sigma2;
|
||||
float sigma2Block =
|
||||
BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
|
||||
if (threadIdx.x ==
|
||||
0) { // must set threadIdx.x = 0 write the output to memory
|
||||
sigma2 = sigma2Block / dimsize;
|
||||
}
|
||||
__syncthreads();
|
||||
if (scaleSize == dimsize) {
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
|
||||
scale[threadIdx.x + ph * BLOCK_DIM] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) /
|
||||
sqrt(sigma2 + eps);
|
||||
}
|
||||
} else {
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
|
||||
scale[0] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) /
|
||||
sqrt(sigma2 + eps);
|
||||
}
|
||||
}
|
||||
}
|
||||
//-----------------
|
||||
template <typename T> struct SumOp {
|
||||
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
template <template <typename> class ReductionOp, typename T,
|
||||
int thread_group_width>
|
||||
__inline__ __device__ T WarpAllReduce(T val) {
|
||||
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
|
||||
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
|
||||
}
|
||||
return val;
|
||||
}
|
||||
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
|
||||
__global__ void warpLaynormKernel(const float *input, const float *scale,
|
||||
const int dimsize, const int stride,
|
||||
float *output, const float eps, int scaleSize,
|
||||
int otherSize, const float *bias,
|
||||
int biasSize) {
|
||||
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
|
||||
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
|
||||
if (otherIdx < otherSize) {
|
||||
|
||||
__shared__ float muTotal[BLOCK_DIM_y];
|
||||
__shared__ float sigma2Total[BLOCK_DIM_y];
|
||||
|
||||
float muPartial = 0.0f;
|
||||
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
|
||||
}
|
||||
|
||||
muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
muTotal[threadIdx.y] = muPartial / dimsize;
|
||||
|
||||
//--------------------------------------------
|
||||
float sigma2Partial = 0.0f;
|
||||
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||
sigma2Partial +=
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]) *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]);
|
||||
}
|
||||
|
||||
sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
sigma2Total[threadIdx.y] = sigma2Partial / dimsize;
|
||||
|
||||
//--------------------------------------------
|
||||
if (biasSize == dimsize) {
|
||||
if (scaleSize == dimsize) {
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize;
|
||||
ph++) {
|
||||
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
|
||||
scale[threadIdx.x + ph * BLOCK_DIM_x] *
|
||||
(input[tid +
|
||||
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]) /
|
||||
sqrt(sigma2Total[threadIdx.y] + eps) +
|
||||
bias[threadIdx.x + ph * BLOCK_DIM_x];
|
||||
}
|
||||
} else {
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize;
|
||||
ph++) {
|
||||
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
|
||||
scale[0] *
|
||||
(input[tid +
|
||||
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]) /
|
||||
sqrt(sigma2Total[threadIdx.y] + eps) +
|
||||
bias[threadIdx.x + ph * BLOCK_DIM_x];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (scaleSize == dimsize) {
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize;
|
||||
ph++) {
|
||||
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
|
||||
scale[threadIdx.x + ph * BLOCK_DIM_x] *
|
||||
(input[tid +
|
||||
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]) /
|
||||
sqrt(sigma2Total[threadIdx.y] + eps) +
|
||||
bias[0];
|
||||
}
|
||||
} else {
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize;
|
||||
ph++) {
|
||||
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
|
||||
scale[0] *
|
||||
(input[tid +
|
||||
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]) /
|
||||
sqrt(sigma2Total[threadIdx.y] + eps) +
|
||||
bias[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
|
||||
__global__ void warpLaynormKernel(const float *input, const float *scale,
|
||||
const int dimsize, const int stride,
|
||||
float *output, const float eps, int scaleSize,
|
||||
int otherSize) {
|
||||
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
|
||||
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
|
||||
if (otherIdx < otherSize) {
|
||||
|
||||
__shared__ float muTotal[BLOCK_DIM_y];
|
||||
__shared__ float sigma2Total[BLOCK_DIM_y];
|
||||
|
||||
float muPartial = 0.0f;
|
||||
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
|
||||
}
|
||||
|
||||
muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
muTotal[threadIdx.y] = muPartial / dimsize;
|
||||
|
||||
//--------------------------------------------
|
||||
float sigma2Partial = 0.0f;
|
||||
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||
sigma2Partial +=
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]) *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]);
|
||||
}
|
||||
|
||||
sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
sigma2Total[threadIdx.y] = sigma2Partial / dimsize;
|
||||
|
||||
//--------------------------------------------
|
||||
if (scaleSize == dimsize) {
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
|
||||
scale[threadIdx.x + ph * BLOCK_DIM_x] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]) /
|
||||
sqrt(sigma2Total[threadIdx.y] + eps);
|
||||
}
|
||||
} else {
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
|
||||
scale[0] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]) /
|
||||
sqrt(sigma2Total[threadIdx.y] + eps);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
namespace infini {
|
||||
void LaynormKernel(const float *input, const float *scale, const float eps,
|
||||
int size, int scaleSize, const int dimsize, const int stride,
|
||||
float *output, const float *bias, int biasSize) {
|
||||
int num_block = size / dimsize;
|
||||
if (dimsize > 1024) {
|
||||
int BLOCK_DIM = 1024;
|
||||
|
||||
blockLaynormKernel<1024>
|
||||
<<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
|
||||
eps, scaleSize, bias, biasSize);
|
||||
} else if (dimsize > 31) {
|
||||
int BLOCK_DIM_x = 32;
|
||||
int BLOCK_DIM_y = 32;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else if (dimsize > 15) {
|
||||
int BLOCK_DIM_x = 16;
|
||||
int BLOCK_DIM_y = 64;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else if (dimsize > 7) {
|
||||
int BLOCK_DIM_x = 8;
|
||||
int BLOCK_DIM_y = 128;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else {
|
||||
int BLOCK_DIM_x = 4;
|
||||
int BLOCK_DIM_y = 256;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
}
|
||||
}
|
||||
|
||||
void LaynormKernel(const float *input, const float *scale, const float eps,
|
||||
int size, int scaleSize, const int dimsize, const int stride,
|
||||
float *output) {
|
||||
int num_block = size / dimsize;
|
||||
if (dimsize > 1024) {
|
||||
int BLOCK_DIM = 1024;
|
||||
|
||||
blockLaynormKernel<1024><<<num_block, BLOCK_DIM>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize);
|
||||
} else if (dimsize > 31) {
|
||||
int BLOCK_DIM_x = 32;
|
||||
int BLOCK_DIM_y = 32;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else if (dimsize > 15) {
|
||||
int BLOCK_DIM_x = 16;
|
||||
int BLOCK_DIM_y = 64;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else if (dimsize > 7) {
|
||||
int BLOCK_DIM_x = 8;
|
||||
int BLOCK_DIM_y = 128;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else {
|
||||
int BLOCK_DIM_x = 4;
|
||||
int BLOCK_DIM_y = 256;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
}
|
||||
}
|
||||
} // namespace infini
|
|
@ -0,0 +1,64 @@
|
|||
#include "operators/layer_norm.h"
|
||||
#include "utils/operator_utils.h"
|
||||
|
||||
namespace infini {
|
||||
LayerNormObj::LayerNormObj(GraphObj *graph, Tensor input, Tensor scale,
|
||||
Tensor output, [[maybe_unused]] Tensor bias,
|
||||
float eps, int axis_, int stash_type)
|
||||
: OperatorObj(OpType::LayerNormalization,
|
||||
bias ? TensorVec{input, scale, bias}
|
||||
: TensorVec{input, scale},
|
||||
{output}),
|
||||
eps(eps), stash_type(stash_type) {
|
||||
const auto size = input->getRank();
|
||||
axis = get_real_axis(axis_, size);
|
||||
IT_ASSERT(
|
||||
is_unidirectional_broadcasting(input->getDims(), scale->getDims()));
|
||||
if (bias) {
|
||||
IT_ASSERT(
|
||||
is_unidirectional_broadcasting(input->getDims(), bias->getDims()));
|
||||
}
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> LayerNormObj::inferShape(const TensorVec &inputs) {
|
||||
return {{inputs[0]->getDims()}};
|
||||
}
|
||||
|
||||
vector<DataType> LayerNormObj::inferDataType(const TensorVec &inputs) const {
|
||||
IT_ASSERT(inputs.size() == 2 || inputs.size() == 3);
|
||||
IT_ASSERT(inputs[1]->getDType() == DataType::Float32);
|
||||
if (inputs.size() == 3) {
|
||||
IT_ASSERT(inputs[2]->getDType() == DataType::Float32);
|
||||
}
|
||||
return {inputs[0]->getDType()};
|
||||
}
|
||||
|
||||
std::string LayerNormObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << "layerNormalization[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
os << "axis=" << axis << ",";
|
||||
os << "eps=" << eps << ",";
|
||||
os << "stash_type=" << stash_type << ",";
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
os << "scale=" << inputs[1]->getGuid() << ",";
|
||||
// os << "bias=" << inputs[2]->getGuid() << ",";
|
||||
os << "output=";
|
||||
for (auto output : outputs)
|
||||
os << output->getGuid() << ",";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> LayerNormObj::getWorkloadVector() const {
|
||||
vector<int> ret = inputs[0]->getDims();
|
||||
ret.emplace(ret.begin(), type.underlying());
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<int> LayerNormObj::getOpAttrVector() const {
|
||||
return {type.underlying(), axis, stash_type};
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -41,4 +41,27 @@ int get_real_axis(const int &axis, const int &rank) {
|
|||
}
|
||||
return newAxis;
|
||||
}
|
||||
|
||||
bool is_unidirectional_broadcasting(const Shape &A, const Shape &B) {
|
||||
// check if tensor B is unidirectional broadcastable to tensor A
|
||||
auto B_ = B;
|
||||
int rankA = A.size();
|
||||
int rankB = B.size();
|
||||
if (rankA < rankB) {
|
||||
return false;
|
||||
}
|
||||
if (rankA > rankB) {
|
||||
for (auto i = 0; i < rankA - rankB; ++i) {
|
||||
B_.insert(B_.begin(), 1);
|
||||
}
|
||||
}
|
||||
for (auto i = 0; i < rankA; ++i) {
|
||||
if (A[i] == B_[i] || B_[i] == 1) {
|
||||
continue;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -0,0 +1,146 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/layer_norm.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void test_layernorm(
|
||||
const Shape &inputShape, const vector<float> &inputData,
|
||||
const Shape &scaleShape, const vector<float> &scaleData, float eps,
|
||||
int axis, int stash_type, const vector<float> &ExpectData,
|
||||
const std::optional<Shape> &bShape = std::nullopt,
|
||||
const std::optional<std::vector<float>> &biasData = std::nullopt) {
|
||||
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
if (bShape.has_value() && biasData.has_value()) {
|
||||
Shape biasShape = *bShape;
|
||||
|
||||
auto bias = gCpu->addTensor(biasShape, DataType::Float32);
|
||||
auto input = gCpu->addTensor(inputShape, DataType::Float32);
|
||||
auto scale = gCpu->addTensor(scaleShape, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
bias->copyin(*biasData); //
|
||||
// bias->printData();
|
||||
input->copyin(inputData);
|
||||
scale->copyin(scaleData); //
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
auto biasGpu = gCuda->cloneTensor(bias);
|
||||
auto inputGpu = gCuda->cloneTensor(input);
|
||||
auto scaleGpu = gCuda->cloneTensor(scale);
|
||||
// gCpu->cloneTensor(biasGpu)->printData();
|
||||
auto op =
|
||||
gCuda->addOp<LayerNormObj>(inputGpu, scaleGpu, nullptr, biasGpu,
|
||||
eps, axis, stash_type); // LayernormObj
|
||||
gCuda->dataMalloc();
|
||||
biasGpu->copyin(*biasData);
|
||||
// gCpu->cloneTensor(biasGpu)->printData();
|
||||
inputGpu->copyin(inputData);
|
||||
scaleGpu->copyin(scaleData);
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
auto oCpu =
|
||||
gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
|
||||
oCpu->printData(); //->printData
|
||||
EXPECT_TRUE(oCpu->equalData(ExpectData));
|
||||
} else {
|
||||
|
||||
auto input = gCpu->addTensor(inputShape, DataType::Float32);
|
||||
auto scale = gCpu->addTensor(scaleShape, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
|
||||
input->copyin(inputData);
|
||||
scale->copyin(scaleData); //
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto inputGpu = gCuda->cloneTensor(input);
|
||||
auto scaleGpu = gCuda->cloneTensor(scale);
|
||||
auto op =
|
||||
gCuda->addOp<LayerNormObj>(inputGpu, scaleGpu, nullptr, nullptr,
|
||||
eps, axis, stash_type); // LayernormObj
|
||||
gCuda->dataMalloc();
|
||||
|
||||
inputGpu->copyin(inputData);
|
||||
scaleGpu->copyin(scaleData);
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
auto oCpu =
|
||||
gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
|
||||
oCpu->printData(); //->printData
|
||||
EXPECT_TRUE(oCpu->equalData(ExpectData));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CUDA_Layernorm, run) {
|
||||
test_layernorm(
|
||||
Shape{2, 3, 2, 3},
|
||||
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
|
||||
9., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||
18., 19., 20., 21., 22., 23., 24., 25., 26.,
|
||||
27., 28., 29., 30., 31., 32., 33., 34., 35.},
|
||||
Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1e-5, 3, 1,
|
||||
vector<float>{
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207, 0.0000000, 0.6123678},
|
||||
Shape{3}, vector<float>{0, 0, 0});
|
||||
test_layernorm(
|
||||
Shape{2, 3, 2, 3},
|
||||
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
|
||||
9., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||
18., 19., 20., 21., 22., 23., 24., 25., 26.,
|
||||
27., 28., 29., 30., 31., 32., 33., 34., 35.},
|
||||
Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1e-5, 3, 1,
|
||||
vector<float>{
|
||||
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
|
||||
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
|
||||
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
|
||||
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
|
||||
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679,
|
||||
-0.0674207, 0.2000000, 1.1123679, -0.0674207, 0.2000000, 1.1123679},
|
||||
Shape{3}, vector<float>{0.3, 0.2, 0.5});
|
||||
test_layernorm(
|
||||
Shape{2, 3, 2, 3},
|
||||
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
|
||||
9., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||
18., 19., 20., 21., 22., 23., 24., 25., 26.,
|
||||
27., 28., 29., 30., 31., 32., 33., 34., 35.},
|
||||
Shape{1}, vector<float>{0.3}, 1e-5, 3, 1,
|
||||
vector<float>{
|
||||
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
|
||||
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
|
||||
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
|
||||
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
|
||||
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207,
|
||||
-0.0674207, 0.2000000, 0.8674207, -0.0674207, 0.2000000, 0.8674207},
|
||||
Shape{3}, vector<float>{0.3, 0.2, 0.5});
|
||||
test_layernorm(
|
||||
Shape{2, 3, 2, 3},
|
||||
vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8.,
|
||||
9., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||
18., 19., 20., 21., 22., 23., 24., 25., 26.,
|
||||
27., 28., 29., 30., 31., 32., 33., 34., 35.},
|
||||
Shape{3}, vector<float>{0.3, 0.2, 0.5}, 1e-5, 3, 1,
|
||||
vector<float>{-0.3674207, 0.0000000, 0.6123678, -0.3674207,
|
||||
0.0000000, 0.6123678, -0.3674207, 0.0000000,
|
||||
0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207,
|
||||
0.0000000, 0.6123678, -0.3674207, 0.0000000,
|
||||
0.6123678, -0.3674207, 0.0000000, 0.6123678,
|
||||
-0.3674207, 0.0000000, 0.6123678, -0.3674207,
|
||||
0.0000000, 0.6123678, -0.3674207, 0.0000000,
|
||||
0.6123678, -0.3674207, 0.0000000, 0.6123678});
|
||||
|
||||
} // python output
|
||||
|
||||
} // namespace infini
|
|
@ -53,10 +53,6 @@ TEST(Conv, NaiveCPU) {
|
|||
i0->setData(IncrementalGenerator());
|
||||
w0->setData(IncrementalGenerator());
|
||||
runtime->run(g, true, true);
|
||||
double perfTime = runtime->getPerfTime(g);
|
||||
// The example Conv takes 0.015ms with one core
|
||||
EXPECT_GT(perfTime, 0);
|
||||
EXPECT_LT(perfTime, 5); // FIXME: why may it cost 4.8 ms sometimes
|
||||
// check answer
|
||||
auto ans =
|
||||
make_ref<TensorObj>(Shape{1, 2, 2, 2}, DataType::UInt32, runtime);
|
||||
|
|
Loading…
Reference in New Issue