From a4d64265894c9b5682515ca1292484c2dfcf8212 Mon Sep 17 00:00:00 2001 From: wendy12022 Date: Sat, 15 Oct 2022 16:29:28 +0800 Subject: [PATCH] ADD: batch norm operator and cuda kernel. (#44) fix numInputs of batchNorm, add new line in file ending. ADD: batch norm operator and cuda kernel. add training remove comments. fix compile error. add batch norm operator and cuda kernel. --- include/cuda/cuda_common.h | 2 +- include/cuda/cuda_element_wise.h | 2 +- include/cuda/cuda_kernel_wihtout_config.h | 2 +- include/cuda/cuda_pad_slice.h | 2 +- include/cuda/cuda_runtime.h | 2 +- include/cuda/cuda_split_concat.h | 2 +- include/cuda/gather.h | 2 +- include/cuda/operator_timer.h | 2 +- include/operators/G2BMM.h | 2 +- include/operators/GBMM.h | 2 +- include/operators/batch_norm.h | 28 ++++++++ include/operators/concat.h | 2 +- include/operators/element_wise.h | 2 +- include/operators/extend.h | 2 +- include/operators/gather.h | 2 +- include/operators/pad.h | 2 +- include/operators/pooling.h | 2 +- include/utils/data_generator.h | 10 +-- src/kernels/cuda/GBMM.cc | 2 +- src/kernels/cuda/batch_norm.cc | 64 +++++++++++++++++ src/kernels/cuda/conv.cc | 2 +- src/kernels/cuda/conv_transposed.cc | 2 +- src/kernels/cuda/element_wise.cc | 2 +- src/kernels/cuda/element_wise.cu | 2 +- src/kernels/cuda/extend.cu | 2 +- src/kernels/cuda/gather.cc | 46 ++++++------ src/kernels/cuda/gather.cu | 2 +- src/kernels/cuda/gbmm_g2bmm.cu | 1 - src/kernels/cuda/matmul.cc | 2 +- src/kernels/cuda/membound_TVM.cc | 2 +- src/kernels/cuda/pooling.cc | 2 +- src/kernels/cuda/reshape.cc | 2 +- src/kernels/cuda/split_concat.cc | 58 +++++++-------- src/kernels/cuda/split_concat.cu | 2 +- src/operators/G2BMM.cc | 2 +- src/operators/batch_norm.cc | 72 +++++++++++++++++++ src/operators/concat.cc | 2 +- src/operators/conv.cc | 2 +- src/operators/element_wise.cc | 2 +- src/operators/extend.cc | 2 +- src/operators/gather.cc | 2 +- src/operators/matmul.cc | 2 +- src/operators/membound.cc | 2 +- 
src/operators/pooling.cc | 2 +- src/operators/reshape.cc | 2 +- src/operators/split.cc | 2 +- test/kernels/cuda/test_cuda_GBMM.cc | 2 +- test/kernels/cuda/test_cuda_batch_norm.cc | 54 ++++++++++++++ test/kernels/cuda/test_cuda_concat.cc | 2 +- test/kernels/cuda/test_cuda_conv.cc | 2 +- .../cuda/test_cuda_conv_transposed_2d.cc | 2 +- test/kernels/cuda/test_cuda_element_wise.cc | 2 +- test/kernels/cuda/test_cuda_extend.cc | 2 +- test/kernels/cuda/test_cuda_gather.cc | 2 +- test/kernels/cuda/test_cuda_matmul.cc | 2 +- test/kernels/cuda/test_cuda_pad.cc | 2 +- test/kernels/cuda/test_cuda_pooling.cc | 2 +- test/kernels/cuda/test_cuda_reshape.cc | 2 +- test/kernels/cuda/test_cuda_slice.cc | 2 +- test/kernels/cuda/test_cuda_split.cc | 2 +- test/kernels/cuda/test_perfengine.cc | 2 +- test/operators/test_batch_norm.cc | 21 ++++++ test/operators/test_concat.cc | 2 +- test/operators/test_conv.cc | 2 +- test/operators/test_conv_transposed_2d.cc | 2 +- test/operators/test_element_wise.cc | 2 +- test/operators/test_extend.cc | 2 +- test/operators/test_gather.cc | 2 +- test/operators/test_matmul.cc | 2 +- test/operators/test_pad.cc | 2 +- test/operators/test_pooling.cc | 2 +- test/operators/test_reshape.cc | 2 +- test/operators/test_slice.cc | 2 +- test/operators/test_split.cc | 2 +- 74 files changed, 362 insertions(+), 122 deletions(-) create mode 100644 include/operators/batch_norm.h create mode 100644 src/kernels/cuda/batch_norm.cc create mode 100644 src/operators/batch_norm.cc create mode 100644 test/kernels/cuda/test_cuda_batch_norm.cc create mode 100644 test/operators/test_batch_norm.cc diff --git a/include/cuda/cuda_common.h b/include/cuda/cuda_common.h index c24ecdca..7d3bb65d 100644 --- a/include/cuda/cuda_common.h +++ b/include/cuda/cuda_common.h @@ -121,4 +121,4 @@ inline const char *curandGetErrorString(curandStatus_t error) { using CudaPtr = void *; -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/include/cuda/cuda_element_wise.h 
b/include/cuda/cuda_element_wise.h index d51a04cf..39de0ea3 100644 --- a/include/cuda/cuda_element_wise.h +++ b/include/cuda/cuda_element_wise.h @@ -3,4 +3,4 @@ namespace infini { void div_kernel(float *a, float *b, float *c, int num); void pow_kernel(float *a, float *b, float *c, int num); -}; // namespace infini \ No newline at end of file +}; // namespace infini diff --git a/include/cuda/cuda_kernel_wihtout_config.h b/include/cuda/cuda_kernel_wihtout_config.h index ff503f45..6a98ca9e 100644 --- a/include/cuda/cuda_kernel_wihtout_config.h +++ b/include/cuda/cuda_kernel_wihtout_config.h @@ -21,4 +21,4 @@ class CudaKernelWithoutConfig : public Kernel { } }; -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/include/cuda/cuda_pad_slice.h b/include/cuda/cuda_pad_slice.h index 9a452691..db032fa0 100644 --- a/include/cuda/cuda_pad_slice.h +++ b/include/cuda/cuda_pad_slice.h @@ -16,4 +16,4 @@ namespace infini { void pad_slice_kernel(float *partData, float *wholeData, const TransMetaData &metadata, int nDims, int num, bool isPad); -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/include/cuda/cuda_runtime.h b/include/cuda/cuda_runtime.h index efb9a9e2..96e25f9b 100644 --- a/include/cuda/cuda_runtime.h +++ b/include/cuda/cuda_runtime.h @@ -74,4 +74,4 @@ class CudaRuntimeObj : public RuntimeObj { private: void tune(const Graph &graph, bool profiling) const; }; -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/include/cuda/cuda_split_concat.h b/include/cuda/cuda_split_concat.h index 454a8f52..699f15bc 100644 --- a/include/cuda/cuda_split_concat.h +++ b/include/cuda/cuda_split_concat.h @@ -32,4 +32,4 @@ void split_concat_kernel(const ElementTensorMetadata &eleMeta, const ComposedTensorMetadata &compMeta, int dim, int batchSize, int nDims, bool isSplit); -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/include/cuda/gather.h 
b/include/cuda/gather.h index 80910117..0cf45142 100644 --- a/include/cuda/gather.h +++ b/include/cuda/gather.h @@ -14,4 +14,4 @@ typedef struct { namespace infini { void gather_kernel(float *in, float *out, GatherMetaData metaData, int num); -} \ No newline at end of file +} diff --git a/include/cuda/operator_timer.h b/include/cuda/operator_timer.h index b9d49e13..f887770d 100644 --- a/include/cuda/operator_timer.h +++ b/include/cuda/operator_timer.h @@ -13,4 +13,4 @@ double getPerfConvTransposed2dCudnn(int n, int c, int h, int w, int f, int r, double getPerfMatmulCublas(int b, int m, int n, int k, const char *name); } // namespace opTimer -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/include/operators/G2BMM.h b/include/operators/G2BMM.h index d44f0494..984c34eb 100644 --- a/include/operators/G2BMM.h +++ b/include/operators/G2BMM.h @@ -48,4 +48,4 @@ class G2BMMObj : public OperatorObj { vector getOpAttrVector() const override; }; -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/include/operators/GBMM.h b/include/operators/GBMM.h index c623b8b1..836e9ba4 100644 --- a/include/operators/GBMM.h +++ b/include/operators/GBMM.h @@ -46,4 +46,4 @@ class GBMMObj : public OperatorObj { vector getOpAttrVector() const override; }; -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/include/operators/batch_norm.h b/include/operators/batch_norm.h new file mode 100644 index 00000000..ac65c6df --- /dev/null +++ b/include/operators/batch_norm.h @@ -0,0 +1,28 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +class BatchNormObj : public OperatorObj { + float momentum, eps; + bool training; + + public: + BatchNormObj(GraphObj *graph, Tensor input, Tensor output, Tensor mean, + Tensor var, Tensor scale, Tensor bias, float momentum = 0.9, + float eps = 1e-5, bool training = false); + optional> inferShape(const TensorVec &inputs) const override; + + 
std::string toString() const override; + + // output size will be 3 when training + int numInputs() const override { return 5; } + int numOutputs() const override { return outputs.size(); } + float getEps() const { return eps; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; + + vector inferDataType(const TensorVec &inputs) const override; +}; +} // namespace infini diff --git a/include/operators/concat.h b/include/operators/concat.h index ebca158c..c36e94ce 100644 --- a/include/operators/concat.h +++ b/include/operators/concat.h @@ -19,4 +19,4 @@ class ConcatObj : public OperatorObj { vector getWorkloadVector() const override; vector getOpAttrVector() const override; }; -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/include/operators/element_wise.h b/include/operators/element_wise.h index 7111d50a..93a48f70 100644 --- a/include/operators/element_wise.h +++ b/include/operators/element_wise.h @@ -30,4 +30,4 @@ DEFINE_ELEMENT_WISE_OBJ(Sub, OpType::Sub) DEFINE_ELEMENT_WISE_OBJ(Mul, OpType::Mul) DEFINE_ELEMENT_WISE_OBJ(Div, OpType::Div) DEFINE_ELEMENT_WISE_OBJ(Pow, OpType::Pow) -}; // namespace infini \ No newline at end of file +}; // namespace infini diff --git a/include/operators/extend.h b/include/operators/extend.h index d3ef64fe..d22afad4 100644 --- a/include/operators/extend.h +++ b/include/operators/extend.h @@ -20,4 +20,4 @@ class ExtendObj : public OperatorObj { vector getWorkloadVector() const override; vector getOpAttrVector() const override; }; -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/include/operators/gather.h b/include/operators/gather.h index 1a93f2af..b7356e0e 100644 --- a/include/operators/gather.h +++ b/include/operators/gather.h @@ -21,4 +21,4 @@ class GatherObj : public OperatorObj { vector getWorkloadVector() const override; vector getOpAttrVector() const override; }; -} // namespace infini \ No newline at end 
of file +} // namespace infini diff --git a/include/operators/pad.h b/include/operators/pad.h index d60443eb..66e65d43 100644 --- a/include/operators/pad.h +++ b/include/operators/pad.h @@ -21,4 +21,4 @@ class PadObj : public OperatorObj { vector getWorkloadVector() const override; vector getOpAttrVector() const override; }; -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/include/operators/pooling.h b/include/operators/pooling.h index b0336b46..6eb62ead 100644 --- a/include/operators/pooling.h +++ b/include/operators/pooling.h @@ -51,4 +51,4 @@ class AvgPoolObj : public PoolingObj { : PoolingObj(graph, OpType::AvgPool, input, output, kh, kw, dh, dw, ph, pw, sh, sw) {} }; -}; // namespace infini \ No newline at end of file +}; // namespace infini diff --git a/include/utils/data_generator.h b/include/utils/data_generator.h index 6ae8266e..6a106d2e 100644 --- a/include/utils/data_generator.h +++ b/include/utils/data_generator.h @@ -38,14 +38,14 @@ class IncrementalGenerator : public DataGenerator { void fill(float *data, size_t size) override { fill(data, size); } }; -class OneGenerator : public DataGenerator { +template class ValGenerator : public DataGenerator { public: - virtual ~OneGenerator() {} + virtual ~ValGenerator() {} private: template void fill(T *data, size_t size) { for (size_t i = 0; i < size; i++) { - data[i] = 1; + data[i] = val; } } @@ -54,4 +54,6 @@ class OneGenerator : public DataGenerator { } void fill(float *data, size_t size) override { fill(data, size); } }; -} // namespace infini \ No newline at end of file +typedef ValGenerator<1> OneGenerator; +typedef ValGenerator<0> ZeroGenerator; +} // namespace infini diff --git a/src/kernels/cuda/GBMM.cc b/src/kernels/cuda/GBMM.cc index 303e4553..06002850 100644 --- a/src/kernels/cuda/GBMM.cc +++ b/src/kernels/cuda/GBMM.cc @@ -58,4 +58,4 @@ class GBMMCudnn : public CudaKernelWithoutConfig { REGISTER_KERNEL(Device::CUDA, OpType::GBMM, DataType::Float32, GBMMCudnn, 
"GBMM_cuDNN_CUDA_Float32"); -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/kernels/cuda/batch_norm.cc b/src/kernels/cuda/batch_norm.cc new file mode 100644 index 00000000..7ca75ba6 --- /dev/null +++ b/src/kernels/cuda/batch_norm.cc @@ -0,0 +1,64 @@ +#include "operators/batch_norm.h" +#include "core/kernel.h" +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_runtime.h" +namespace infini { +class BatchNormCudnn : public CudaKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + cudnnStatus_t stat; + void *const inData = (op->getInputs(0)->getRawDataPtr()); + void *const outData = (op->getOutput()->getRawDataPtr()); + void *const meanData = (op->getInputs(1)->getRawDataPtr()); + void *const varData = (op->getInputs(2)->getRawDataPtr()); + void *const scaleData = (op->getInputs(3)->getRawDataPtr()); + void *const biasData = (op->getInputs(4)->getRawDataPtr()); + + auto dims = op->getInputs(0)->getDims(); + if (dims.size() == 2) + IT_TODO_HALT(); + // Only 4D and 5D tensors are supported by + // cudnnBatchNormalizationForwardInference + IT_ASSERT(dims.size() == 4 || dims.size() == 5); + + int dimArray[CUDNN_DIM_MAX], strideArray[CUDNN_DIM_MAX], + dimPArray[CUDNN_DIM_MAX], stridePArray[CUDNN_DIM_MAX]; + for (size_t i = 0; i < dims.size(); ++i) { + dimArray[i] = dims[i]; + strideArray[i] = op->getInputs(0)->getStride()[i]; + dimPArray[i] = op->getInputs(1)->getDims()[i]; + stridePArray[i] = op->getInputs(1)->getStride()[i]; + } + // get inputs + cudnnTensorDescriptor_t inDesc; + checkCudnnError(cudnnCreateTensorDescriptor(&inDesc)); + checkCudnnError(cudnnSetTensorNdDescriptor( + inDesc, CUDNN_DATA_FLOAT, dims.size(), dimArray, strideArray)); + + // get bnScaleBiasMeanVarDesc + cudnnTensorDescriptor_t paraDesc; + checkCudnnError(cudnnCreateTensorDescriptor(&paraDesc)); + 
checkCudnnError(cudnnSetTensorNdDescriptor( + paraDesc, CUDNN_DATA_FLOAT, dims.size(), dimPArray, stridePArray)); + + float alpha = 1.f, beta = 0.f; + // This mode is intended for use after convolutional layers + stat = cudnnBatchNormalizationForwardInference( + context->cudnnHandle(), CUDNN_BATCHNORM_SPATIAL, &alpha, &beta, + inDesc, inData, inDesc, outData, paraDesc, scaleData, biasData, + meanData, varData, op->getEps()); + if (stat != CUDNN_STATUS_SUCCESS) + return; + + // Destories in CUDA does not require sync. But cuDNN does not state + // whether sync is required before destories. + checkCudnnError(cudnnDestroyTensorDescriptor(inDesc)); + checkCudnnError(cudnnDestroyTensorDescriptor(paraDesc)); + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::BatchNorm, DataType::Float32, + BatchNormCudnn, "BatchNorm_cuDNN_CUDA_Float32"); +} // namespace infini diff --git a/src/kernels/cuda/conv.cc b/src/kernels/cuda/conv.cc index e80a0b1e..f6afa20d 100644 --- a/src/kernels/cuda/conv.cc +++ b/src/kernels/cuda/conv.cc @@ -294,4 +294,4 @@ REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float32, convCudnn, "Conv_cuDNN_CUDA_Float32"); REGISTER_CONSTRUCTOR(1, ConvCuDnnPerfRecordObj::from_json); -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/kernels/cuda/conv_transposed.cc b/src/kernels/cuda/conv_transposed.cc index c20eb3fa..11f4599d 100644 --- a/src/kernels/cuda/conv_transposed.cc +++ b/src/kernels/cuda/conv_transposed.cc @@ -285,4 +285,4 @@ class convBackwardDataCudnn : public Kernel { REGISTER_KERNEL(Device::CUDA, OpType::ConvTrans, DataType::Float32, convBackwardDataCudnn, "ConvTranposed_cuDNN_CUDA_Float32"); -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/kernels/cuda/element_wise.cc b/src/kernels/cuda/element_wise.cc index 374fcf5d..c44287e7 100644 --- a/src/kernels/cuda/element_wise.cc +++ b/src/kernels/cuda/element_wise.cc @@ -105,4 +105,4 @@ REGISTER_KERNEL(Device::CUDA, 
OpType::Div, DataType::Float32, ElementWiseCuda, "Div_CUDA_Float32"); REGISTER_KERNEL(Device::CUDA, OpType::Pow, DataType::Float32, ElementWiseCuda, "Pow__CUDA_Float32"); -}; // namespace infini \ No newline at end of file +}; // namespace infini diff --git a/src/kernels/cuda/element_wise.cu b/src/kernels/cuda/element_wise.cu index afe429f4..be7d4495 100644 --- a/src/kernels/cuda/element_wise.cu +++ b/src/kernels/cuda/element_wise.cu @@ -35,4 +35,4 @@ void pow_kernel(float *a, float *b, float *c, int num) { _pow_kernel<<>>(a, b, c, num); } -}; // namespace infini \ No newline at end of file +}; // namespace infini diff --git a/src/kernels/cuda/extend.cu b/src/kernels/cuda/extend.cu index 05cf95cb..03345e96 100644 --- a/src/kernels/cuda/extend.cu +++ b/src/kernels/cuda/extend.cu @@ -22,4 +22,4 @@ void extend_kernel(float *in, float *out, int blockSize, int blockSizeOuter, _extend_kernel<<>>(in, out, blockSize, blockSizeOuter, oSize); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/kernels/cuda/gather.cc b/src/kernels/cuda/gather.cc index 5be2767e..63725524 100644 --- a/src/kernels/cuda/gather.cc +++ b/src/kernels/cuda/gather.cc @@ -4,30 +4,30 @@ #include "cuda/gather.h" namespace infini { - -void initGatherMetaData(GatherMetaData &metaData, const Operator &_op) { - memset(&metaData, 0, sizeof(metaData)); - auto op = as(_op); - auto in = op->getInputs(0); - auto index = op->getInputs(1); - auto out = op->getOutput(); - metaData.indexValue = index->getRawDataPtr(); - metaData.axis = op->getAxis(); - metaData.inNDim = in->getDims().size(); - metaData.outNDim = out->getDims().size(); - metaData.idxNDim = index->getDims().size(); - for (int i = 0; i < metaData.outNDim; ++i) - metaData.outDim[i] = out->getDims()[i]; - for (int i = 0; i < metaData.idxNDim; ++i) { - metaData.idxDim[i] = index->getDims()[i]; - metaData.idxStride[i] = index->getStride()[i]; - } - for (int i = 0; i < metaData.inNDim; ++i) { - metaData.inStride[i] = 
in->getStride()[i]; - } -} - class GatherCuda : public CudaKernelWithoutConfig { + void initGatherMetaData(GatherMetaData &metaData, + const Operator &_op) const { + memset(&metaData, 0, sizeof(metaData)); + auto op = as(_op); + auto in = op->getInputs(0); + auto index = op->getInputs(1); + auto out = op->getOutput(); + metaData.indexValue = index->getRawDataPtr(); + metaData.axis = op->getAxis(); + metaData.inNDim = in->getDims().size(); + metaData.outNDim = out->getDims().size(); + metaData.idxNDim = index->getDims().size(); + for (int i = 0; i < metaData.outNDim; ++i) + metaData.outDim[i] = out->getDims()[i]; + for (int i = 0; i < metaData.idxNDim; ++i) { + metaData.idxDim[i] = index->getDims()[i]; + metaData.idxStride[i] = index->getStride()[i]; + } + for (int i = 0; i < metaData.inNDim; ++i) { + metaData.inStride[i] = in->getStride()[i]; + } + } + void compute(const Operator &op, const RuntimeObj *_context) const override { diff --git a/src/kernels/cuda/gather.cu b/src/kernels/cuda/gather.cu index 00b382d7..0ae69085 100644 --- a/src/kernels/cuda/gather.cu +++ b/src/kernels/cuda/gather.cu @@ -44,4 +44,4 @@ void gather_kernel(float *in, float *out, GatherMetaData metaData, int num) { _gather_kernel<<>>(in, out, metaData, num); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/kernels/cuda/gbmm_g2bmm.cu b/src/kernels/cuda/gbmm_g2bmm.cu index 0002e021..3c5b6399 100644 --- a/src/kernels/cuda/gbmm_g2bmm.cu +++ b/src/kernels/cuda/gbmm_g2bmm.cu @@ -14,4 +14,3 @@ void _sgbmml(float *__restrict__ q, float *__restrict__ k, } } // namespace infini - diff --git a/src/kernels/cuda/matmul.cc b/src/kernels/cuda/matmul.cc index 6add86aa..cee15379 100644 --- a/src/kernels/cuda/matmul.cc +++ b/src/kernels/cuda/matmul.cc @@ -102,4 +102,4 @@ REGISTER_KERNEL(Device::CUDA, OpType::Matmul, DataType::Float32, matmulCublas, "Matmul_cuBLAS_CUDA_Float32"); REGISTER_CONSTRUCTOR(2, MatmulCublasPerfRecordObj::from_json); -}; // namespace infini \ 
No newline at end of file +}; // namespace infini diff --git a/src/kernels/cuda/membound_TVM.cc b/src/kernels/cuda/membound_TVM.cc index 6bf4c4b2..91a8b075 100644 --- a/src/kernels/cuda/membound_TVM.cc +++ b/src/kernels/cuda/membound_TVM.cc @@ -188,4 +188,4 @@ class MemboundTVM : public Kernel { REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32, MemboundTVM, "Memobund_TVM_Ansor"); -}; // namespace infini \ No newline at end of file +}; // namespace infini diff --git a/src/kernels/cuda/pooling.cc b/src/kernels/cuda/pooling.cc index c11e58e3..47cf32e7 100644 --- a/src/kernels/cuda/pooling.cc +++ b/src/kernels/cuda/pooling.cc @@ -72,4 +72,4 @@ REGISTER_KERNEL(Device::CUDA, OpType::MaxPool, DataType::Float32, maxPoolCudnn, "MaxPool_cuDNN_CUDA_Float32"); REGISTER_KERNEL(Device::CUDA, OpType::AvgPool, DataType::Float32, avgPoolCudnn, "AvgPool_cuDNN_CUDA_Float32"); -}; // namespace infini \ No newline at end of file +}; // namespace infini diff --git a/src/kernels/cuda/reshape.cc b/src/kernels/cuda/reshape.cc index 7beb4798..77070c23 100644 --- a/src/kernels/cuda/reshape.cc +++ b/src/kernels/cuda/reshape.cc @@ -18,4 +18,4 @@ REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda, REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda, "Identity_CUDA_Float32"); -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/kernels/cuda/split_concat.cc b/src/kernels/cuda/split_concat.cc index b7a12e5a..5792753c 100644 --- a/src/kernels/cuda/split_concat.cc +++ b/src/kernels/cuda/split_concat.cc @@ -6,35 +6,35 @@ namespace infini { -void initComposedTensorMetadata(ComposedTensorMetadata &metadata, - Tensor tensor) { - int nDims = tensor->getDims().size(); - auto strides = tensor->getStride(); - IT_ASSERT(strides.size() == (size_t)nDims); - for (int i = 0; i < nDims; ++i) { - metadata.dimSize[i] = tensor->getDims().at(i); - metadata.stride[i] = strides.at(i); - } - metadata.data = 
tensor->getRawDataPtr(); -} - -void initElementTensorMetadata(ElementTensorMetadata &metadata, - TensorVec tensors, int idx, int dim, - int &dimBgIdx, int &batchCounter) { - int nTensors = tensors.size(); - for (; batchCounter < BATCH_SIZE && idx + batchCounter < nTensors; - ++batchCounter) { - auto tensor = tensors.at(idx + batchCounter); - auto dimSize = tensor->getDims()[dim]; - metadata.data[batchCounter] = tensor->getRawDataPtr(); - metadata.dimBgNo[batchCounter] = dimBgIdx; - metadata.dimSize[batchCounter] = dimSize; - metadata.nElements[batchCounter] = tensor->size(); - dimBgIdx += dimSize; - } -} - class CudaCompute { + void initComposedTensorMetadata(ComposedTensorMetadata &metadata, + Tensor tensor) const { + int nDims = tensor->getDims().size(); + auto strides = tensor->getStride(); + IT_ASSERT(strides.size() == (size_t)nDims); + for (int i = 0; i < nDims; ++i) { + metadata.dimSize[i] = tensor->getDims().at(i); + metadata.stride[i] = strides.at(i); + } + metadata.data = tensor->getRawDataPtr(); + } + + void initElementTensorMetadata(ElementTensorMetadata &metadata, + TensorVec tensors, int idx, int dim, + int &dimBgIdx, int &batchCounter) const { + int nTensors = tensors.size(); + for (; batchCounter < BATCH_SIZE && idx + batchCounter < nTensors; + ++batchCounter) { + auto tensor = tensors.at(idx + batchCounter); + auto dimSize = tensor->getDims()[dim]; + metadata.data[batchCounter] = tensor->getRawDataPtr(); + metadata.dimBgNo[batchCounter] = dimBgIdx; + metadata.dimSize[batchCounter] = dimSize; + metadata.nElements[batchCounter] = tensor->size(); + dimBgIdx += dimSize; + } + } + public: void do_compute(Tensor composedTensor, TensorVec elementsTensor, int dim, int nDims, bool isSplit) const { @@ -78,4 +78,4 @@ REGISTER_KERNEL(Device::CUDA, OpType::Concat, DataType::Float32, ConcatCuda, "Concat_CUDA_Float32"); REGISTER_KERNEL(Device::CUDA, OpType::Split, DataType::Float32, SplitCuda, "Split_CUDA_Float32"); -} // namespace infini \ No newline at end of 
file +} // namespace infini diff --git a/src/kernels/cuda/split_concat.cu b/src/kernels/cuda/split_concat.cu index e71e7890..86f5524b 100644 --- a/src/kernels/cuda/split_concat.cu +++ b/src/kernels/cuda/split_concat.cu @@ -68,4 +68,4 @@ void split_concat_kernel(const ElementTensorMetadata &eleMeta, isSplit); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/operators/G2BMM.cc b/src/operators/G2BMM.cc index 0cc22d8f..80d63482 100644 --- a/src/operators/G2BMM.cc +++ b/src/operators/G2BMM.cc @@ -46,4 +46,4 @@ vector G2BMMObj::getOpAttrVector() const { return {enum_to_underlying(type), width, dilation, enum_to_underlying(act)}; } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/operators/batch_norm.cc b/src/operators/batch_norm.cc new file mode 100644 index 00000000..b9be9425 --- /dev/null +++ b/src/operators/batch_norm.cc @@ -0,0 +1,72 @@ +#include "operators/batch_norm.h" + +namespace infini { +BatchNormObj::BatchNormObj(GraphObj *graph, Tensor input, Tensor output, + Tensor mean, Tensor var, Tensor scale, Tensor bias, + float momentum, float eps, bool training) + : OperatorObj(OpType::BatchNorm, {input, mean, var, scale, bias}, {output}), + momentum(momentum), eps(eps), training(training) { + if (training) + IT_TODO_HALT(); + + IT_ASSERT(checkValid(graph)); +} + +optional> +BatchNormObj::inferShape(const TensorVec &inputs) const { + auto input = inputs[0]; + auto mean = inputs[1]; + auto var = inputs[2]; + auto scale = inputs[3]; + auto bias = inputs[4]; + if (input->getDims().size() < 2) + return {}; + Shape dims(input->getDims().size(), 1); + dims[1] = input->getDims()[1]; // + if (mean->getDims() != dims || var->getDims() != dims || + scale->getDims() != dims || bias->getDims() != dims) + return {}; + return {{input->getDims()}}; +} + +vector BatchNormObj::inferDataType(const TensorVec &inputs) const { + IT_ASSERT(inputs.size() == 5); + auto index = inputs[1]; + 
IT_ASSERT(inputs[1]->getDType() == DataType::Float32); + IT_ASSERT(inputs[2]->getDType() == DataType::Float32); + IT_ASSERT(inputs[3]->getDType() == DataType::Float32); + IT_ASSERT(inputs[4]->getDType() == DataType::Float32); + return {inputs[0]->getDType()}; +} + +std::string BatchNormObj::toString() const { + std::ostringstream os; + os << "BatchNorm[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "momentum=" << momentum << ","; + os << "eps=" << eps << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "mean=" << inputs[1]->getGuid() << ","; + os << "var=" << inputs[2]->getGuid() << ","; + os << "scale=" << inputs[3]->getGuid() << ","; + os << "bias=" << inputs[4]->getGuid() << ","; + os << "output="; + for (auto output : outputs) + os << output->getGuid() << ","; + return os.str(); +} + +// need eps and momentum? +vector BatchNormObj::getWorkloadVector() const { + vector ret = inputs[0]->getDims(); + ret.emplace(ret.begin(), enum_to_underlying(type)); + return ret; +} + +// need eps and momentum? 
+vector BatchNormObj::getOpAttrVector() const { + return {enum_to_underlying(type)}; +} + +} // namespace infini diff --git a/src/operators/concat.cc b/src/operators/concat.cc index 98a5527a..e98497a1 100644 --- a/src/operators/concat.cc +++ b/src/operators/concat.cc @@ -55,4 +55,4 @@ vector ConcatObj::getOpAttrVector() const { return {enum_to_underlying(type), dim}; } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/operators/conv.cc b/src/operators/conv.cc index e8571f06..240b80fd 100644 --- a/src/operators/conv.cc +++ b/src/operators/conv.cc @@ -180,4 +180,4 @@ void ConvTransposed2dObj::setAuxilaryAttributes(PaddingMode mode) { } } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/operators/element_wise.cc b/src/operators/element_wise.cc index ae87758d..bb13586a 100644 --- a/src/operators/element_wise.cc +++ b/src/operators/element_wise.cc @@ -54,4 +54,4 @@ vector ElementWiseObj::getOpAttrVector() const { return {enum_to_underlying(type)}; } -}; // namespace infini \ No newline at end of file +}; // namespace infini diff --git a/src/operators/extend.cc b/src/operators/extend.cc index 55ef9021..332a76ff 100644 --- a/src/operators/extend.cc +++ b/src/operators/extend.cc @@ -38,4 +38,4 @@ vector ExtendObj::getOpAttrVector() const { return {enum_to_underlying(type), dim, num}; } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/operators/gather.cc b/src/operators/gather.cc index 225db6fd..95e88126 100644 --- a/src/operators/gather.cc +++ b/src/operators/gather.cc @@ -82,4 +82,4 @@ vector GatherObj::getOpAttrVector() const { return {enum_to_underlying(type), axis}; } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc index db109192..4d5debdb 100644 --- a/src/operators/matmul.cc +++ b/src/operators/matmul.cc @@ -46,4 +46,4 @@ vector MatmulObj::getOpAttrVector() 
const { return {enum_to_underlying(type), transA, transB, enum_to_underlying(act)}; } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/operators/membound.cc b/src/operators/membound.cc index 27d7bdef..dc269742 100644 --- a/src/operators/membound.cc +++ b/src/operators/membound.cc @@ -56,4 +56,4 @@ HashType MemBoundObj::getHash() const { return nnet::HashVisitor().dispatch(expr); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/operators/pooling.cc b/src/operators/pooling.cc index 6e87cc94..ad4d1099 100644 --- a/src/operators/pooling.cc +++ b/src/operators/pooling.cc @@ -49,4 +49,4 @@ vector PoolingObj::getOpAttrVector() const { return {enum_to_underlying(type), kh, kw, ph, pw, sh, sw, dh, dw}; } -}; // namespace infini \ No newline at end of file +}; // namespace infini diff --git a/src/operators/reshape.cc b/src/operators/reshape.cc index c4f1b710..fb15681e 100644 --- a/src/operators/reshape.cc +++ b/src/operators/reshape.cc @@ -101,4 +101,4 @@ vector IdentityObj::getWorkloadVector() const { vector IdentityObj::getOpAttrVector() const { return {enum_to_underlying(type)}; } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/operators/split.cc b/src/operators/split.cc index f387d4e4..38c0ffbc 100644 --- a/src/operators/split.cc +++ b/src/operators/split.cc @@ -86,4 +86,4 @@ string SplitObj::toString() const { return os.str(); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_GBMM.cc b/test/kernels/cuda/test_cuda_GBMM.cc index 28a21f90..99d96388 100644 --- a/test/kernels/cuda/test_cuda_GBMM.cc +++ b/test/kernels/cuda/test_cuda_GBMM.cc @@ -34,4 +34,4 @@ TEST(CUDA_GBMM, ShapeInference) { cudaRuntime->run(gCuda); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_batch_norm.cc 
b/test/kernels/cuda/test_cuda_batch_norm.cc new file mode 100644 index 00000000..e3a5ef04 --- /dev/null +++ b/test/kernels/cuda/test_cuda_batch_norm.cc @@ -0,0 +1,54 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/batch_norm.h" +#include "test.h" + +namespace infini { + +TEST(CUDA_BatchNorm, run) { + Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + auto cudaRuntime = make_ref(); + + // Build cpu graph + Graph gCpu = make_ref(cpuRuntime); + auto iCpu = gCpu->addTensor(Shape{1, 3, 2, 2}, DataType::Float32); + auto meanCpu = gCpu->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); + auto varCpu = gCpu->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); + auto scaleCpu = gCpu->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); + auto biasCpu = gCpu->addTensor(Shape{1, 3, 1, 1}, DataType::Float32); + + // Build input data on CPU + gCpu->dataMalloc(); + iCpu->setData(IncrementalGenerator()); + meanCpu->copyData(vector{1, 6, 9}); + varCpu->copyData(vector{4, 1, 9}); + scaleCpu->setData(OneGenerator()); + biasCpu->setData(ZeroGenerator()); + + // Build CUDA graph + Graph g = make_ref(cudaRuntime); + auto i = g->cloneTensor(iCpu); + auto mean = g->cloneTensor(meanCpu); + auto var = g->cloneTensor(varCpu); + auto scale = g->cloneTensor(scaleCpu); + auto bias = g->cloneTensor(biasCpu); + auto op = + g->addOp(i, nullptr, mean, var, scale, bias, 0.9, 0); + + // allocate CUDA memory + g->dataMalloc(); + + // Execute on CUDA + cudaRuntime->run(g); + + // clone CUDA output to CPU + auto o = op->getOutput(); + auto ocpu = o->clone(cpuRuntime); + + // check results on CPU + EXPECT_TRUE(ocpu->equalData(vector{ + -0.5, 0, 0.5, 1, -2, -1, 0, 1, -0.333333, 0, 0.333333, 0.666667})); +} +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_concat.cc b/test/kernels/cuda/test_cuda_concat.cc index 13ccb67a..12d937c9 100644 --- a/test/kernels/cuda/test_cuda_concat.cc +++ 
b/test/kernels/cuda/test_cuda_concat.cc @@ -73,4 +73,4 @@ TEST(Concat, Cuda) { 6, 7, 8, 1, 1, 1, 9, 10, 11, 1, 1, 1})); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_conv.cc b/test/kernels/cuda/test_cuda_conv.cc index 24028d18..2c0a6419 100644 --- a/test/kernels/cuda/test_cuda_conv.cc +++ b/test/kernels/cuda/test_cuda_conv.cc @@ -76,4 +76,4 @@ TEST(cuDNN_Conv, tune) { bool tune = true; cuda->run(gCuda, tune); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_conv_transposed_2d.cc b/test/kernels/cuda/test_cuda_conv_transposed_2d.cc index c8b28aea..5a56ffae 100644 --- a/test/kernels/cuda/test_cuda_conv_transposed_2d.cc +++ b/test/kernels/cuda/test_cuda_conv_transposed_2d.cc @@ -86,4 +86,4 @@ TEST(cuDNN_ConvTransposed, tune) { ASSERT_TRUE(perfData.has_value()); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_element_wise.cc b/test/kernels/cuda/test_cuda_element_wise.cc index 76215b53..b242af48 100644 --- a/test/kernels/cuda/test_cuda_element_wise.cc +++ b/test/kernels/cuda/test_cuda_element_wise.cc @@ -65,4 +65,4 @@ TEST(cuDNN_ElementWise, run) { ExpectOutput{1, 1, 4, 27}); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_extend.cc b/test/kernels/cuda/test_cuda_extend.cc index 75167649..197246c2 100644 --- a/test/kernels/cuda/test_cuda_extend.cc +++ b/test/kernels/cuda/test_cuda_extend.cc @@ -40,4 +40,4 @@ TEST(CUDA_Extend, run) { 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23})); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_gather.cc b/test/kernels/cuda/test_cuda_gather.cc index 09e4b28e..9108d719 100644 --- a/test/kernels/cuda/test_cuda_gather.cc +++ 
b/test/kernels/cuda/test_cuda_gather.cc @@ -241,4 +241,4 @@ TEST(Gather, Cuda) { } } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_matmul.cc b/test/kernels/cuda/test_cuda_matmul.cc index e0372015..26d5e3d2 100644 --- a/test/kernels/cuda/test_cuda_matmul.cc +++ b/test/kernels/cuda/test_cuda_matmul.cc @@ -73,4 +73,4 @@ TEST(cuBLAS_Matmul, tune) { cudaRuntime->run(gCuda, true); } -}; // namespace infini \ No newline at end of file +}; // namespace infini diff --git a/test/kernels/cuda/test_cuda_pad.cc b/test/kernels/cuda/test_cuda_pad.cc index dfe3a188..c0a96f94 100644 --- a/test/kernels/cuda/test_cuda_pad.cc +++ b/test/kernels/cuda/test_cuda_pad.cc @@ -38,4 +38,4 @@ TEST(Pad, Cuda) { 0, 1, 0, 2, 3, 0, 4, 5, 0, 6, 7, 0, 8, 9, 0, 10, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_pooling.cc b/test/kernels/cuda/test_cuda_pooling.cc index 077e3894..3f341591 100644 --- a/test/kernels/cuda/test_cuda_pooling.cc +++ b/test/kernels/cuda/test_cuda_pooling.cc @@ -58,4 +58,4 @@ TEST(cuDNN_AvgPool, run) { 20.444444}); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_reshape.cc b/test/kernels/cuda/test_cuda_reshape.cc index 6ce25a97..4dae2430 100644 --- a/test/kernels/cuda/test_cuda_reshape.cc +++ b/test/kernels/cuda/test_cuda_reshape.cc @@ -95,4 +95,4 @@ TEST(CUDA_Identity, run) { // check results on CPU EXPECT_TRUE(ocpu->equalData(icpu)); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_slice.cc b/test/kernels/cuda/test_cuda_slice.cc index 3cd7da6c..0657a559 100644 --- a/test/kernels/cuda/test_cuda_slice.cc +++ b/test/kernels/cuda/test_cuda_slice.cc @@ -36,4 +36,4 @@ TEST(CUDA_Slice, run) { // check results on CPU 
EXPECT_TRUE(cpuo->equalData(vector{11, 12, 13, 14, 16, 17, 18, 19})); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_split.cc b/test/kernels/cuda/test_cuda_split.cc index 38e409f0..484d2ad3 100644 --- a/test/kernels/cuda/test_cuda_split.cc +++ b/test/kernels/cuda/test_cuda_split.cc @@ -37,4 +37,4 @@ TEST(Split, Cuda) { 12, 13, 14, 15, 16, 17, 18, 19, 32, 33, 34, 35, 36, 37, 38, 39})); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/kernels/cuda/test_perfengine.cc b/test/kernels/cuda/test_perfengine.cc index 52b9e3b3..6230f0a2 100644 --- a/test/kernels/cuda/test_perfengine.cc +++ b/test/kernels/cuda/test_perfengine.cc @@ -47,4 +47,4 @@ TEST(PerfEngine, save_and_load) { std::cout << j1 << std::endl; EXPECT_TRUE(j0 == j1); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/operators/test_batch_norm.cc b/test/operators/test_batch_norm.cc new file mode 100644 index 00000000..d904edd1 --- /dev/null +++ b/test/operators/test_batch_norm.cc @@ -0,0 +1,21 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "operators/batch_norm.h" +#include "test.h" + +namespace infini { +TEST(BatchNorm, ShapeInference) { + Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + { + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({1, 3, 2, 2}, DataType::UInt32); + Tensor mean = g->addTensor({1, 3, 1, 1}, DataType::Float32); + Tensor var = g->addTensor({1, 3, 1, 1}, DataType::Float32); + Tensor scaler = g->addTensor({1, 3, 1, 1}, DataType::Float32); + Tensor bias = g->addTensor({1, 3, 1, 1}, DataType::Float32); + auto op = g->addOp(i, nullptr, mean, var, scaler, bias, + 0.9, 1e-5); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 2})); + } +} +} // namespace infini diff --git a/test/operators/test_concat.cc b/test/operators/test_concat.cc index 15ef074b..8c5da64b 100644 --- a/test/operators/test_concat.cc +++ 
b/test/operators/test_concat.cc @@ -14,4 +14,4 @@ TEST(Concat, ShapeInfer) { EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 9})); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/operators/test_conv.cc b/test/operators/test_conv.cc index 3b2bff4e..059fa9af 100644 --- a/test/operators/test_conv.cc +++ b/test/operators/test_conv.cc @@ -66,4 +66,4 @@ TEST(Conv, NaiveCPU) { EXPECT_TRUE(conv->getOutput()->equalData(ans)); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/operators/test_conv_transposed_2d.cc b/test/operators/test_conv_transposed_2d.cc index 34a8ba17..9ce8d7d8 100644 --- a/test/operators/test_conv_transposed_2d.cc +++ b/test/operators/test_conv_transposed_2d.cc @@ -34,4 +34,4 @@ TEST(ConvTransposed, ShapeInference) { } } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/operators/test_element_wise.cc b/test/operators/test_element_wise.cc index 59fe65a3..68fdc7b9 100644 --- a/test/operators/test_element_wise.cc +++ b/test/operators/test_element_wise.cc @@ -19,4 +19,4 @@ TEST(ElementWise, ShapeInference) { } } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/operators/test_extend.cc b/test/operators/test_extend.cc index c9c079fd..5fbd4d8a 100644 --- a/test/operators/test_extend.cc +++ b/test/operators/test_extend.cc @@ -17,4 +17,4 @@ TEST(Extend, ShapeInference) { } } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/operators/test_gather.cc b/test/operators/test_gather.cc index 2bc8fc2e..32dbac64 100644 --- a/test/operators/test_gather.cc +++ b/test/operators/test_gather.cc @@ -16,4 +16,4 @@ TEST(Gather, ShapeInference) { auto op = g->addOp(i, index, nullptr, 1); EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4})); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git 
a/test/operators/test_matmul.cc b/test/operators/test_matmul.cc index b7b95a78..b17f8a3a 100644 --- a/test/operators/test_matmul.cc +++ b/test/operators/test_matmul.cc @@ -29,4 +29,4 @@ TEST(Matmul, ShapeInference) { } } -}; // namespace infini \ No newline at end of file +}; // namespace infini diff --git a/test/operators/test_pad.cc b/test/operators/test_pad.cc index 23c11afd..3d033927 100644 --- a/test/operators/test_pad.cc +++ b/test/operators/test_pad.cc @@ -22,4 +22,4 @@ TEST(Pad, ShapeInference) { } } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/operators/test_pooling.cc b/test/operators/test_pooling.cc index e13100c2..97715a63 100644 --- a/test/operators/test_pooling.cc +++ b/test/operators/test_pooling.cc @@ -68,4 +68,4 @@ TEST(AvgPool, NaiveCPU) { EXPECT_LT(perfTime, 5); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/operators/test_reshape.cc b/test/operators/test_reshape.cc index a944bbae..a0c016c6 100644 --- a/test/operators/test_reshape.cc +++ b/test/operators/test_reshape.cc @@ -36,4 +36,4 @@ TEST(Identity, ShapeInference) { } } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/operators/test_slice.cc b/test/operators/test_slice.cc index 0a9430fc..a6717b03 100644 --- a/test/operators/test_slice.cc +++ b/test/operators/test_slice.cc @@ -24,4 +24,4 @@ TEST(Slice, ShapeInference) { } } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/test/operators/test_split.cc b/test/operators/test_split.cc index d0e76031..6b336497 100644 --- a/test/operators/test_split.cc +++ b/test/operators/test_split.cc @@ -35,4 +35,4 @@ TEST(Split, ShapeInfer) { } } -} // namespace infini \ No newline at end of file +} // namespace infini