diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 8c4f59bc..76f6e0c2 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -30,6 +30,8 @@ class GraphHandlerObj { Tensor batchNormalization(Tensor input, Tensor output, Tensor mean, Tensor var, Tensor scale, Tensor bias, float momentum, float eps, bool training); + Tensor layerNormalization(Tensor input, Tensor scale, Tensor output, + Tensor bias, float eps, int axis, int stash_type); Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw, int ceilMode); diff --git a/include/cuda/cuda_layernorm.h b/include/cuda/cuda_layernorm.h new file mode 100644 index 00000000..997c8a06 --- /dev/null +++ b/include/cuda/cuda_layernorm.h @@ -0,0 +1,11 @@ +#pragma once +#include "operators/unary.h" + +namespace infini { +void LaynormKernel(const float *input, const float *scale, const float eps, + int size, int scaleSize, const int dimsize, const int stride, + float *output, const float *bias, int biasSize); +void LaynormKernel(const float *input, const float *scale, const float eps, + int size, int scaleSize, const int dimsize, const int stride, + float *output); +}; // namespace infini diff --git a/include/operators/layer_norm.h b/include/operators/layer_norm.h new file mode 100644 index 00000000..8534648f --- /dev/null +++ b/include/operators/layer_norm.h @@ -0,0 +1,30 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +class LayerNormObj : public OperatorObj { + float eps; + int axis, stash_type; + + public: + LayerNormObj(GraphObj *graph, Tensor input, Tensor scale, Tensor output, + Tensor bias = nullptr, float eps = 1e-5, int axis = -1, + int stash_type = 1); + OP_CLONE(LayerNormObj); + optional> inferShape(const TensorVec &inputs) override; + std::string toString() const override; + + Tensor getBias() const { return inputs.size() > 2 ? inputs[2] : nullptr; } + int numInputs() const override { return inputs.size(); } + int numOutputs() const override { return outputs.size(); } + float getEps() const { return eps; } + int getAxis() const { return axis; } + int getStashType() const { return stash_type; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; + + vector inferDataType(const TensorVec &inputs) const override; +}; +} // namespace infini diff --git a/include/utils/operator_utils.h b/include/utils/operator_utils.h index 01703252..4f6a6985 100644 --- a/include/utils/operator_utils.h +++ b/include/utils/operator_utils.h @@ -10,6 +10,8 @@ namespace infini { Shape infer_broadcast(const Shape &A, const Shape &B); // Launch the real axis based on rank and current axis int get_real_axis(const int &axis, const int &rank); +// check if tensor B is unidirectional broadcastable to tensor A +bool is_unidirectional_broadcasting(const Shape &A, const Shape &B); } // namespace infini #endif diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index ad842d5b..f0326d88 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -238,6 +238,25 @@ class OnnxStub: eps, training != 0, ) + elif node.op_type == "LayerNormalization": + (input, scale) = (tensors[node.input[i]] for i in [0, 1]) + bias = None if len(node.input) < 3 else tensors[node.input[2]] + output = tensors.get(node.output[0]) + attributes = _parse_attribute( + node, {"axis": -1, "epsilon": 1e-05, "stash_type": 1} + ) + (axis, eps, stash_type) = ( + attributes[name] for name in ["axis", "epsilon", "stash_type"] + ) + tensors[node.output[0]] = self.handler.layerNormalization( + input, + scale, + output, + bias, + eps, + axis, + stash_type, + ) elif node.op_type == "MaxPool": attributes = _parse_attribute( node, diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index fdceba62..de156c43 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -9,6 +9,7 @@ #include "operators/element_wise.h" #include "operators/expand.h" #include "operators/gather.h" +#include "operators/layer_norm.h" #include "operators/matmul.h" #include "operators/pad.h" #include "operators/pooling.h" @@ -96,6 +97,23 @@ Tensor GraphHandlerObj::batchNormalization(Tensor input, Tensor output, } } +Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale, + Tensor output, Tensor bias, + float eps, int axis, + int stash_type) { + if (output) { + g->addOpWithOutputs(std::move(input), std::move(scale), + output, std::move(bias), eps, axis, + stash_type); + return output; + } else { + return g + ->addOp(std::move(input), std::move(scale), output, + std::move(bias), eps, axis, stash_type) + ->getOutput(); + } +} + Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw, int ceilMode) { diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 0bdfdcf9..408d3514 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -466,6 +466,7 @@ void init_graph_builder(py::module &m) { .def("convTransposed2d", &Handler::convTransposed2d, policy::move) .def("matmul", &Handler::matmul, policy::move) .def("batchNormalization", &Handler::batchNormalization, policy::move) + .def("layerNormalization", &Handler::layerNormalization, policy::move) .def("maxPool", &Handler::maxPool, policy::move) .def("avgPool", &Handler::avgPool, policy::move) .def("add", &Handler::add, policy::move) diff --git a/src/kernels/cuda/layer_norm.cc b/src/kernels/cuda/layer_norm.cc new file mode 100644 index 00000000..a301eb0b --- /dev/null +++ b/src/kernels/cuda/layer_norm.cc @@ -0,0 +1,45 @@ +#include "operators/layer_norm.h" +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_layernorm.h" +#include "cuda/cuda_runtime.h" + +namespace infini { + +class LayerNormCuda : public CudaKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + + void *const inputData = (op->getInputs(0)->getRawDataPtr()); + void *const scaleData = (op->getInputs(1)->getRawDataPtr()); + + void *const outputData = (op->getOutput()->getRawDataPtr()); + const auto &opOutputShape = op->getOutput()->getDims(); + + float eps = op->getEps(); + const int axis = op->getAxis(); + const int stride = op->getInputs(0)->getStride().at(axis); + + auto dims = op->getInputs(0)->getDims(); + int dimsize = dims[op->getAxis()]; + int size = op->getOutput(0)->size(); + int scaleSize = op->getInputs(1)->size(); + if (op->numInputs() == 3) { + void *const biasData = (op->getInputs(2)->getRawDataPtr()); + int biasSize = op->getInputs(2)->size(); + // printf("kernel bias:true:%d\n", 1); + LaynormKernel((float *)inputData, (float *)scaleData, eps, size, + scaleSize, dimsize, stride, (float *)outputData, + (float *)biasData, biasSize); + } else { + // printf("kernel bias:false:%d\n", 0); + LaynormKernel((float *)inputData, (float *)scaleData, eps, size, + scaleSize, dimsize, stride, (float *)outputData); + } + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::LayerNormalization, DataType::Float32, + LayerNormCuda, "LayerNorm_CUDA_Float32"); + +}; // namespace infini diff --git a/src/kernels/cuda/layer_norm.cu b/src/kernels/cuda/layer_norm.cu new file mode 100644 index 00000000..c5e6e492 --- /dev/null +++ b/src/kernels/cuda/layer_norm.cu @@ -0,0 +1,421 @@ +#include "cuda/cuda_common.h" +#include + +template +__launch_bounds__(BLOCK_DIM) __global__ + void blockLaynormKernel(const float *input, const float *scale, + const int dimsize, const int stride, float *output, + const float eps, int scaleSize, const float *bias, + int biasSize) { + // len(scale) = len(bias) = dimsize + int tmp = blockIdx.x % stride; + int tid = (blockIdx.x - tmp) * dimsize + tmp; + float muPartial = 0.0f; + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { + muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride]; + } + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ float mu; + float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum()); + if (threadIdx.x == + 0) { // must set threadIdx.x = 0 write the output to memory + mu = muBlock / dimsize; + } + __syncthreads(); + + float sigma2Partial = 0.0f; + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { + sigma2Partial += + (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) * + (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu); + } + typedef cub::BlockReduce BlockReduce; + + __shared__ float sigma2; + float sigma2Block = + BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum()); + if (threadIdx.x == + 0) { // must set threadIdx.x = 0 write the output to memory + sigma2 = sigma2Block / dimsize; + } + __syncthreads(); + if (biasSize == dimsize) { + if (scaleSize == dimsize) { + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { + + output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = + scale[threadIdx.x + ph * BLOCK_DIM] * + (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - + mu) / + sqrt(sigma2 + eps) + + bias[threadIdx.x + ph * BLOCK_DIM]; + } + } else { + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { + + output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = + scale[0] * + (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - + mu) / + sqrt(sigma2 + eps) + + bias[threadIdx.x + ph * BLOCK_DIM]; + } + } + } else { + if (scaleSize == dimsize) { + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { + + output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = + scale[threadIdx.x + ph * BLOCK_DIM] * + (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - + mu) / + sqrt(sigma2 + eps) + + bias[0]; + } + } else { + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { + + output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = + scale[0] * + (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - + mu) / + sqrt(sigma2 + eps) + + bias[0]; + } + } + } +} +//----------------- +template +__launch_bounds__(BLOCK_DIM) __global__ + void blockLaynormKernel(const float *input, const float *scale, + const int dimsize, const int stride, float *output, + const float eps, int scaleSize) { + // len(scale) = len(bias) = dimsize + int tmp = blockIdx.x % stride; + int tid = (blockIdx.x - tmp) * dimsize + tmp; + float muPartial = 0.0f; + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { + muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride]; + } + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ float mu; + float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum()); + if (threadIdx.x == + 0) { // must set threadIdx.x = 0 write the output to memory + mu = muBlock / dimsize; + } + __syncthreads(); + + float sigma2Partial = 0.0f; + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { + sigma2Partial += + (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) * + (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu); + } + typedef cub::BlockReduce BlockReduce; + + __shared__ float sigma2; + float sigma2Block = + BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum()); + if (threadIdx.x == + 0) { // must set threadIdx.x = 0 write the output to memory + sigma2 = sigma2Block / dimsize; + } + __syncthreads(); + if (scaleSize == dimsize) { + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { + + output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = + scale[threadIdx.x + ph * BLOCK_DIM] * + (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) / + sqrt(sigma2 + eps); + } + } else { + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { + + output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = + scale[0] * + (input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) / + sqrt(sigma2 + eps); + } + } +} +//----------------- +template struct SumOp { + __device__ __forceinline__ T operator()(const T &a, const T &b) const { + return a + b; + } +}; + +template