From 6ac106cba4e6873736343716658711935f2d9cfa Mon Sep 17 00:00:00 2001
From: Hardy <100662313+wanghailu0717@users.noreply.github.com>
Date: Fri, 16 Sep 2022 13:58:57 +0800
Subject: [PATCH] Add activation operators and kernels

* add code for activation operation
* add code for activation operation on GPU
* add test code for activation operation
* add code for activation operation
* add code for activation on GPU, using cuDNN
* add code for activation on GPU using cuDNN
* Chore: add constants.h and remove comments

Co-authored-by: wanghailu
Co-authored-by: Liyan Zheng
---
 include/core/constants.h     |   5 +
 include/core/operator.h      |  10 +-
 include/cuda/cuda_unary.h    |  33 ++
 include/operators/unary.h    |  31 ++
 src/kernels/cpu/unary.cc     |  99 ++++++++++++++++++
 src/kernels/cuda/unary.cc    | 195 +++++++++++++++++++++++++++++++++++
 src/kernels/cuda/unary.cu    |  94 +++++++++++++++++
 src/operators/unary.cc       |  35 +++++++
 test/operators/test_unary.cc |  50 +++++++++
 9 files changed, 551 insertions(+), 1 deletion(-)
 create mode 100644 include/core/constants.h
 create mode 100644 include/cuda/cuda_unary.h
 create mode 100644 include/operators/unary.h
 create mode 100644 src/kernels/cpu/unary.cc
 create mode 100644 src/kernels/cuda/unary.cc
 create mode 100644 src/kernels/cuda/unary.cu
 create mode 100644 src/operators/unary.cc
 create mode 100644 test/operators/test_unary.cc

diff --git a/include/core/constants.h b/include/core/constants.h
new file mode 100644
index 00000000..655c8989
--- /dev/null
+++ b/include/core/constants.h
@@ -0,0 +1,5 @@
+#pragma once
+
+namespace infini {
+constexpr double E_CONSTANT = 2.718281828459;
+}
\ No newline at end of file
diff --git a/include/core/operator.h b/include/core/operator.h
index 79e880cf..f65aeaa7 100644
--- a/include/core/operator.h
+++ b/include/core/operator.h
@@ -32,6 +32,10 @@ enum class OpType {
     BatchNorm = 200,
     Softmax,
     Activation,
+    Relu,
+    Sigmoid,
+    Tanh,
+    Abs,
     Resize,
     //
     MemBound = 300,
@@ -75,6 +79,10 @@
             FOP(BatchNorm);
             FOP(Softmax);
             FOP(Activation);
+            FOP(Relu);
+            FOP(Sigmoid);
+            FOP(Tanh);
+            FOP(Abs);
             //
             FOP(MemBound);
         default:
@@ -204,4 +212,4 @@ namespace std {
 template <> struct hash<infini::OpPerfKey> {
     size_t operator()(const infini::OpPerfKey &key) const { return key.hash; }
 };
-} // namespace std
\ No newline at end of file
+} // namespace std
diff --git a/include/cuda/cuda_unary.h b/include/cuda/cuda_unary.h
new file mode 100644
index 00000000..c11912dc
--- /dev/null
+++ b/include/cuda/cuda_unary.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include "operators/unary.h"
+
+namespace infini {
+void softmax_kernel(float *input, float *output, int num);
+void relu_kernel(float *input, float *output, int num);
+void sigmoid_kernel(float *input, float *output, int num);
+void tanh_kernel(float *input, float *output, int num);
+void abs_kernel(float *input, float *output, int num);
+
+void unary_kernel(const Operator &_op) {
+    auto op = as<UnaryObj>(_op);
+    float *const inputData = (op->getInputs(0)->getRawDataPtr<float *>());
+    float *const outputData = (op->getOutput()->getRawDataPtr<float *>());
+
+    auto dim = op->getInputs(0)->getDims();
+    int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
+    if (op->getOpType() == OpType::Softmax)
+        softmax_kernel(inputData, outputData, n * c * h * w);
+    else if (op->getOpType() == OpType::Relu)
+        relu_kernel(inputData, outputData, n * c * h * w);
+    else if (op->getOpType() == OpType::Sigmoid)
+        sigmoid_kernel(inputData, outputData, n * c * h * w);
+    else if (op->getOpType() == OpType::Tanh)
+        tanh_kernel(inputData, outputData, n * c * h * w);
+    else if (op->getOpType() == OpType::Abs)
+        abs_kernel(inputData, outputData, n * c * h * w);
+    else
+        IT_TODO_HALT();
+}
+
+}; // namespace infini
diff --git a/include/operators/unary.h b/include/operators/unary.h
new file mode 100644
index 00000000..e8a9516d
--- /dev/null
+++ b/include/operators/unary.h
@@ -0,0 +1,31 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+class UnaryObj : public OperatorObj {
+  public:
+    UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output);
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+
+    std::string toString() const override;
+    int numInputs() const override { return 1; }
+    int numOutputs() const override { return 1; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+
+#define DEFINE_UNARY_OBJ(prefix, type)                                        \
+    class prefix##Obj : public UnaryObj {                                     \
+      public:                                                                 \
+        prefix##Obj(GraphObj *graph, Tensor input, Tensor output)             \
+            : UnaryObj(type, graph, input, output) {}                         \
+    };
+
+DEFINE_UNARY_OBJ(Relu, OpType::Relu)
+DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid)
+DEFINE_UNARY_OBJ(Tanh, OpType::Tanh)
+DEFINE_UNARY_OBJ(Softmax, OpType::Softmax)
+DEFINE_UNARY_OBJ(Abs, OpType::Abs)
+}; // namespace infini
diff --git a/src/kernels/cpu/unary.cc b/src/kernels/cpu/unary.cc
new file mode 100644
index 00000000..1faa6fef
--- /dev/null
+++ b/src/kernels/cpu/unary.cc
@@ -0,0 +1,99 @@
+#include "operators/unary.h"
+#include "core/constants.h"
+#include "core/kernel.h"
+
+namespace infini {
+template <typename T> class NativeUnary : public Kernel {
+    virtual T doCompute(T val) const = 0;
+    void compute(const Operator &_op, const PerfRecord &record,
+                 const RuntimeObj *context) const override {
+        auto op = as<UnaryObj>(_op);
+        T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
+        T *outptr = op->getOutput()->getRawDataPtr<T *>();
+
+        auto outDim = op->getOutput()->getDims();
+        auto n = op->getOutput()->size();
+        for (size_t offset = 0; offset < n; offset++) {
+            outptr[offset] = doCompute(inptr[offset]);
+        }
+    }
+
+    void compute(const Operator &op, const RuntimeObj *context) const override {
+        compute(op, {}, context);
+    }
+
+    PerfRecord tune(const Operator &op,
+                    const RuntimeObj *context) const override {
+        PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
+        return perfrcd;
+    }
+};
+
+template <typename T> class NaiveSoftmax : public Kernel {
+    void compute(const Operator &_op, const PerfRecord &record,
+                 const RuntimeObj *context) const override {
+        auto op = as<UnaryObj>(_op);
+        T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
+        T *outptr = op->getOutput()->getRawDataPtr<T *>();
+
+        auto outDim = op->getOutput()->getDims();
+        auto n = op->getOutput()->size();
+        auto sum = T(0);
+        for (size_t offset = 0; offset < n; offset++) {
+            sum += pow(E_CONSTANT, inptr[offset]);
+        }
+        for (size_t offset = 0; offset < n; offset++) {
+            outptr[offset] = pow(E_CONSTANT, inptr[offset]) / sum;
+        }
+    }
+
+    void compute(const Operator &op, const RuntimeObj *context) const override {
+        compute(op, {}, context);
+    }
+
+    PerfRecord tune(const Operator &op,
+                    const RuntimeObj *context) const override {
+        PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
+        return perfrcd;
+    }
+};
+
+template <typename T> class NaiveRelu : public NativeUnary<T> {
+    T doCompute(T val) const override { return std::max(T(0), val); }
+};
+template <typename T> class NaiveSigmoid : public NativeUnary<T> {
+    T doCompute(T val) const override {
+        return 1 / (1 + pow(E_CONSTANT, -val));
+    }
+};
+template <typename T> class NaiveTanh : public NativeUnary<T> {
+    T doCompute(T val) const override {
+        return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) /
+               (pow(E_CONSTANT, val) + pow(E_CONSTANT, -val));
+    }
+};
+template <typename T> class NaiveAbs : public NativeUnary<T> {
+    T doCompute(T val) const override { return val < 0 ? -val : val; }
+};
+
+REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::UInt32,
+                NaiveRelu<uint32_t>, "reluNaive_CPU_uint32");
+REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::Float32, NaiveRelu<float>,
+                "reluNaive_CPU_float32");
+REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::UInt32,
+                NaiveSigmoid<uint32_t>, "sigmoidNaive_CPU_uint32");
+REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::Float32,
+                NaiveSigmoid<float>, "sigmoidNaive_CPU_float32");
+REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::UInt32,
+                NaiveTanh<uint32_t>, "tanhNaive_CPU_uint32");
+REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::Float32, NaiveTanh<float>,
+                "tanhNaive_CPU_float32");
+REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::UInt32, NaiveAbs<uint32_t>,
+                "absNaive_CPU_uint32");
+REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs<float>,
+                "absNaive_CPU_float32");
+REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
+                NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
+REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,
+                NaiveSoftmax<float>, "softmaxNaive_CPU_float32");
+}; // namespace infini
diff --git a/src/kernels/cuda/unary.cc b/src/kernels/cuda/unary.cc
new file mode 100644
index 00000000..18858e3e
--- /dev/null
+++ b/src/kernels/cuda/unary.cc
@@ -0,0 +1,195 @@
+#include "operators/unary.h"
+#include "core/kernel.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_unary.h"
+
+namespace infini {
+
+class UnaryCuda : public Kernel {
+    void compute(const Operator &_op, const PerfRecord &record,
+                 const RuntimeObj *_context) const override {
+        unary_kernel(_op);
+    }
+
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        compute(_op, {}, _context);
+    }
+    // Premise: op is idempotent since it is called multiple times.
+    PerfRecord tune(const Operator &_op,
+                    const RuntimeObj *_context) const override {
+        PerfRecord ret;
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        ret.time = timeit([&]() { compute(_op, _context); },
+                          [&]() { context->sync(); });
+        return ret;
+    }
+};
+
+class ActivationCudnn : public Kernel {
+    virtual cudnnActivationMode_t getOpType() const = 0;
+    virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
+    void compute(const Operator &_op, const PerfRecord &record,
+                 const RuntimeObj *_context) const override {
+        auto op = as<UnaryObj>(_op);
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+
+        void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
+
+        cudnnTensorDescriptor_t inputDesc, outputDesc;
+        auto dim = op->getInputs(0)->getDims();
+        if (dim.size() != 4)
+            IT_TODO_HALT();
+        int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
+
+        // get inputs
+        checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
+        checkCudnnError(cudnnSetTensor4dDescriptor(
+            inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
+
+        // get outputs
+        checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc));
+        checkCudnnError(cudnnSetTensor4dDescriptor(
+            outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
+
+        // get op descriptor
+        cudnnActivationDescriptor_t activationDesc;
+        checkCudnnError(cudnnCreateActivationDescriptor(&activationDesc));
+        checkCudnnError(cudnnSetActivationDescriptor(
+            activationDesc, getOpType(), CUDNN_NOT_PROPAGATE_NAN, 0.0));
+
+        auto [alpha, beta] = getAlphBeta();
+        cudnnStatus_t stat = cudnnActivationForward(
+            context->cudnnHandle(), activationDesc, &alpha, inputDesc,
+            inputData, &beta, outputDesc, outputData);
+        if (stat != CUDNN_STATUS_SUCCESS)
+            return;
+
+        // Destroying descriptors in CUDA does not require a sync, but cuDNN
+        // does not state whether a sync is required before destruction.
+        checkCudnnError(cudnnDestroyActivationDescriptor(activationDesc));
+        checkCudnnError(cudnnDestroyTensorDescriptor(outputDesc));
+        checkCudnnError(cudnnDestroyTensorDescriptor(inputDesc));
+    }
+
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        compute(_op, {}, _context);
+    }
+    // Premise: op is idempotent since it is called multiple times.
+    PerfRecord tune(const Operator &_op,
+                    const RuntimeObj *_context) const override {
+        PerfRecord ret;
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        ret.time = timeit([&]() { compute(_op, _context); },
+                          [&]() { context->sync(); });
+        return ret;
+    }
+};
+
+class SoftmaxCudnn : public Kernel {
+    virtual cudnnSoftmaxAlgorithm_t getAlgorithmType() const = 0;
+    virtual cudnnSoftmaxMode_t getModeType() const = 0;
+    virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
+    void compute(const Operator &_op, const PerfRecord &record,
+                 const RuntimeObj *_context) const override {
+        auto op = as<UnaryObj>(_op);
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+
+        void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
+
+        cudnnTensorDescriptor_t inputDesc, outputDesc;
+        auto dim = op->getInputs(0)->getDims();
+        if (dim.size() != 4)
+            IT_TODO_HALT();
+        int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
+
+        // get inputs
+        checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
+        checkCudnnError(cudnnSetTensor4dDescriptor(
+            inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
+
+        // get outputs
+        checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc));
+        checkCudnnError(cudnnSetTensor4dDescriptor(
+            outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
+
+        auto [alpha, beta] = getAlphBeta();
+        cudnnStatus_t stat = cudnnSoftmaxForward(
+            context->cudnnHandle(), getAlgorithmType(), getModeType(), &alpha,
+            inputDesc, inputData, &beta, outputDesc, outputData);
+        if (stat != CUDNN_STATUS_SUCCESS)
+            return;
+
+        // Destroying descriptors in CUDA does not require a sync, but cuDNN
+        // does not state whether a sync is required before destruction.
+        checkCudnnError(cudnnDestroyTensorDescriptor(inputDesc));
+        checkCudnnError(cudnnDestroyTensorDescriptor(outputDesc));
+    }
+
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        compute(_op, {}, _context);
+    }
+    // Premise: op is idempotent since it is called multiple times.
+    PerfRecord tune(const Operator &_op,
+                    const RuntimeObj *_context) const override {
+        PerfRecord ret;
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        ret.time = timeit([&]() { compute(_op, _context); },
+                          [&]() { context->sync(); });
+        return ret;
+    }
+};
+
+class ReluCudnn : public ActivationCudnn {
+    cudnnActivationMode_t getOpType() const override {
+        return CUDNN_ACTIVATION_RELU;
+    }
+};
+
+class SigmoidCudnn : public ActivationCudnn {
+    cudnnActivationMode_t getOpType() const override {
+        return CUDNN_ACTIVATION_SIGMOID;
+    }
+};
+
+class TanhCudnn : public ActivationCudnn {
+    cudnnActivationMode_t getOpType() const override {
+        return CUDNN_ACTIVATION_TANH;
+    }
+};
+
+class NormalSoftmaxCudnn : public SoftmaxCudnn {
+    cudnnSoftmaxAlgorithm_t getAlgorithmType() const override {
+        return CUDNN_SOFTMAX_ACCURATE;
+    }
+    cudnnSoftmaxMode_t getModeType() const override {
+        return CUDNN_SOFTMAX_MODE_INSTANCE;
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32,
+                NormalSoftmaxCudnn, "Softmax_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Relu, DataType::Float32, ReluCudnn,
+                "Relu_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, DataType::Float32, SigmoidCudnn,
+                "Sigmoid_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Tanh, DataType::Float32, TanhCudnn,
+                "Tanh_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
+                "Abs_CUDA_Float32");
+
+// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, UnaryCuda,
+//                 "Softmax_CUDA_Float32");
+// REGISTER_KERNEL(Device::CUDA, OpType::Relu, DataType::Float32, UnaryCuda,
+//                 "Relu_CUDA_Float32");
+// REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, DataType::Float32, UnaryCuda,
+//                 "Sigmoid_CUDA_Float32");
+// REGISTER_KERNEL(Device::CUDA, OpType::Tanh, DataType::Float32, UnaryCuda,
+//                 "Tanh_CUDA_Float32");
+// REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
+//                 "Abs_CUDA_Float32");
+}; // namespace infini
diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu
new file mode 100644
index 00000000..b81d8c63
--- /dev/null
+++ b/src/kernels/cuda/unary.cu
@@ -0,0 +1,94 @@
+#include "core/common.h"
+#include "core/constants.h"
+#include "cuda/cuda_common.h"
+#include <math.h>
+
+using infini::E_CONSTANT;
+constexpr unsigned int num_threads() { return 32 * 4; }
+constexpr int thread_work_size() { return 4; }
+constexpr int block_work_size() { return thread_work_size() * num_threads(); }
+
+__global__ void _softmax_kernel1(float *input, float *output, int n) {
+    float sum = 0.0f;
+    for (size_t i = 0; i < n; ++i) {
+        sum += pow(E_CONSTANT, input[i]);
+    }
+    *output = sum;
+}
+
+__global__ void _softmax_kernel2(float *input, float *output, int n) {
+    float sum = *output;
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    int stride = blockDim.x * gridDim.x;
+    for (int i = index; i < n; i += stride) {
+        output[i] = pow(E_CONSTANT, input[i]) / sum;
+    }
+}
+
+__global__ void _relu_kernel(float *input, float *output, int n) {
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    int stride = blockDim.x * gridDim.x;
+    for (int i = index; i < n; i += stride) {
+        output[i] = max(input[i], float(0));
+    }
+}
+
+__global__ void _sigmoid_kernel(float *input, float *output, int n) {
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    int stride = blockDim.x * gridDim.x;
+    for (int i = index; i < n; i += stride) {
+        output[i] = 1 / (1 + pow(E_CONSTANT, -input[i]));
+    }
+}
+
+__global__ void _tanh_kernel(float *input, float *output, int n) {
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    int stride = blockDim.x * gridDim.x;
+    for (int i = index; i < n; i += stride) {
+        output[i] = (pow(E_CONSTANT, input[i]) - pow(E_CONSTANT, -input[i])) /
+                    (pow(E_CONSTANT, input[i]) + pow(E_CONSTANT, -input[i]));
+    }
+}
+
+__global__ void _abs_kernel(float *input, float *output, int n) {
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    int stride = blockDim.x * gridDim.x;
+    for (int i = index; i < n; i += stride) {
+        output[i] = input[i] < 0 ? -input[i] : input[i];
+    }
+}
+
+namespace infini {
+void softmax_kernel(float *input, float *output, int num) {
+
+    int blocksize = block_work_size();
+    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    _softmax_kernel1<<<1, 1>>>(input, output, num);
+    _softmax_kernel2<<<gridsize, blocksize>>>(input, output, num);
+}
+void relu_kernel(float *input, float *output, int num) {
+
+    int blocksize = block_work_size();
+    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    _relu_kernel<<<gridsize, blocksize>>>(input, output, num);
+}
+void sigmoid_kernel(float *input, float *output, int num) {
+
+    int blocksize = block_work_size();
+    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    _sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
+}
+void tanh_kernel(float *input, float *output, int num) {
+
+    int blocksize = block_work_size();
+    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    _tanh_kernel<<<gridsize, blocksize>>>(input, output, num);
+}
+void abs_kernel(float *input, float *output, int num) {
+
+    int blocksize = block_work_size();
+    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    _abs_kernel<<<gridsize, blocksize>>>(input, output, num);
+}
+
+}; // namespace infini
diff --git a/src/operators/unary.cc b/src/operators/unary.cc
new file mode 100644
index 00000000..cecf0e33
--- /dev/null
+++ b/src/operators/unary.cc
@@ -0,0 +1,35 @@
+#include "operators/unary.h"
+
+namespace infini {
+UnaryObj::UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output)
+    : OperatorObj(type, {input}, {output}) {
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>> UnaryObj::inferShape(const TensorVec &inputs) const {
+    const auto A = inputs[0];
+    return {{A->getDims()}};
+}
+
+std::string UnaryObj::toString() const {
+    std::ostringstream os;
+    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << "input=" << inputs[0]->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> UnaryObj::getWorkloadVector() const {
+    vector<int> ret{enum_to_underlying(type)};
+    const Shape shape = outputs[0]->getDims();
+    ret.insert(ret.end(), shape.begin(), shape.end());
+    return ret;
+}
+
+vector<int> UnaryObj::getOpAttrVector() const {
+    return {enum_to_underlying(type)};
+}
+
+}; // namespace infini
diff --git a/test/operators/test_unary.cc b/test/operators/test_unary.cc
new file mode 100644
index 00000000..3934692f
--- /dev/null
+++ b/test/operators/test_unary.cc
@@ -0,0 +1,50 @@
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/unary.h"
+
+#include "test.h"
+
+namespace infini {
+
+template <class T>
+void testUnary(const std::function<void(void *, size_t, DataType)> &generator,
+               const Shape &shape) {
+    // Runtime
+    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
+    inputCpu->dataMalloc();
+    inputCpu->setData(generator);
+
+    // GPU
+    Graph cudaGraph = make_ref<GraphObj>(cudaRuntime);
+    auto inputGpu = cudaGraph->cloneTensor(inputCpu);
+    auto gpuOp = cudaGraph->addOp<T>(inputGpu, nullptr);
+    cudaGraph->dataMalloc();
+    cudaRuntime->run(cudaGraph);
+    auto outputGpu = gpuOp->getOutput();
+    auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
+    // CPU
+    Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
+    auto cpuOp = cpuGraph->addOp<T>(inputCpu, nullptr);
+    cpuGraph->dataMalloc();
+    cpuRuntime->run(cpuGraph);
+    auto outputCpu = cpuOp->getOutput();
+    // Check
+    EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu));
+}
+
+TEST(Unary, CuDNN) {
+    testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
+    testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
+    testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
+    testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
+    testUnary<SoftmaxObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
+}
+
+} // namespace infini