From 2f8f706f1c264e8bc539c7497088607bef78c705 Mon Sep 17 00:00:00 2001
From: zhengly123
Date: Wed, 21 Sep 2022 12:28:00 +0800
Subject: [PATCH] Fix CMake USE_CUDA (#36)

* Fix: build lib without cuda
* Chore: rename GBMM and G2BMM files
* Fix: separate CUDA tests from operator tests
* Fix: CMake CMP0104
* Chore: fix typo
* Chore: remove unused headers

Co-authored-by: Liyan Zheng
---
 CMakeLists.txt                                |  14 ++-
 include/cuda/cuda_common.h                    |   1 +
 .../{custom_ops.cuh => cuda/gbmm_g2bmm.cuh}   |   7 +-
 include/{custom_ops.h => cuda/gbmm_g2bmm.h}   |   5 +-
 src/core/runtime.cc                           |   4 -
 src/kernels/cuda/G2BMM.cc                     |   2 +-
 src/kernels/cuda/GBMM.cc                      |   2 +-
 .../cuda/gbmm_g2bmm.cu}                       |   4 +-
 src/nnet/derivator.cc                         |   4 +-
 src/operators/G2BMM.cc                        |   1 -
 src/operators/GBMM.cc                         |   1 -
 .../cuda/test_cuda_G2BMM.cc}                  |   2 +-
 .../cuda/test_cuda_GBMM.cc}                   |   2 +-
 test/kernels/cuda/test_cuda_conv.cc           |  79 ++++++++++++++
 .../cuda/test_cuda_conv_transposed_2d.cc      |  89 ++++++++++++++++
 test/kernels/cuda/test_cuda_element_wise.cc   |  68 ++++++++++++
 test/kernels/cuda/test_cuda_matmul.cc         |  76 +++++++++++++
 test/kernels/cuda/test_cuda_pooling.cc        |  61 +++++++++++
 .../cuda/test_cuda_unary.cc}                  |   2 +-
 test/operators/test_conv.cc                   |  69 ------------
 test/operators/test_conv_transposed_2d.cc     |  78 --------------
 test/operators/test_element_wise.cc           | 100 ------------------
 test/operators/test_matmul.cc                 |  63 -----------
 test/operators/test_pooling.cc                |  51 ---------
 24 files changed, 400 insertions(+), 385 deletions(-)
 rename include/{custom_ops.cuh => cuda/gbmm_g2bmm.cuh} (99%)
 rename include/{custom_ops.h => cuda/gbmm_g2bmm.h} (82%)
 rename src/{custom_ops.cu => kernels/cuda/gbmm_g2bmm.cu} (86%)
 rename test/{operators/test_G2BMM.cc => kernels/cuda/test_cuda_G2BMM.cc} (97%)
 rename test/{operators/test_GBMM.cc => kernels/cuda/test_cuda_GBMM.cc} (95%)
 create mode 100644 test/kernels/cuda/test_cuda_conv.cc
 create mode 100644 test/kernels/cuda/test_cuda_conv_transposed_2d.cc
 create mode 100644 test/kernels/cuda/test_cuda_element_wise.cc
 create mode 100644 test/kernels/cuda/test_cuda_matmul.cc
 create mode 100644 test/kernels/cuda/test_cuda_pooling.cc
 rename test/{operators/test_unary.cc => kernels/cuda/test_cuda_unary.cc} (98%)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 30ef1bb1..4fdf5e30 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -73,8 +73,15 @@ if(BUILD_TEST)
   include_directories(3rd-party/googletest/googletest/include)
 endif()
 
-file(GLOB_RECURSE SRC src/*.cc src/*.cu)
+# Source files
+file(GLOB_RECURSE SRC src/core/*.cc src/kernels/cpu/*.cc src/nnet/*.cc src/operators/*.cc src/utils/*.cc)
+if(USE_CUDA)
+  file(GLOB_RECURSE SRC_CUDA src/cuda/*.cc src/cuda/*.cu src/kernels/cuda/*.cc src/kernels/cuda/*.cu)
+  list (APPEND SRC ${SRC_CUDA})
+endif()
+
+# Libraries
 add_library(InfiniTensor SHARED ${SRC})
 if(USE_PROTOBUF)
   target_link_libraries(InfiniTensor tensor_proto)
@@ -93,6 +100,8 @@ if(USE_CUDA)
   set(CMAKE_CUDA_HOST_COMPILER
       ${CMAKE_CXX_COMPILER}
       CACHE STRING "Set cuda host compiler path")
+  # CMP0104 requires CUDA_ARCHITECTURES
+  set_target_properties(InfiniTensor PROPERTIES CUDA_ARCHITECTURES "70;80")
   enable_language(CUDA)
   # TODO: find_package seems unnecessary for CMake >= 3.8
   find_package(CUDA REQUIRED)
@@ -123,6 +132,9 @@ if(BUILD_TEST)
   if(BUILD_TEST_CORE)
     build_test(test/core/*.cc)
     build_test(test/operators/*.cc)
+    if (USE_CUDA)
+      build_test(test/kernels/cuda/*.cc)
+    endif()
   endif()
   if(BUILD_TEST_PET)
     build_test(test/pet/*.cc)
diff --git a/include/cuda/cuda_common.h b/include/cuda/cuda_common.h
index fdf7b1de..9b4680e2 100644
--- a/include/cuda/cuda_common.h
+++ b/include/cuda/cuda_common.h
@@ -2,6 +2,7 @@
 #include "core/common.h"
 #include
 #include
+#include
 #include
 #include
 
diff --git a/include/custom_ops.cuh b/include/cuda/gbmm_g2bmm.cuh
similarity index 99%
rename from include/custom_ops.cuh
rename to include/cuda/gbmm_g2bmm.cuh
index d3d00e94..626f7b4c 100644
--- a/include/custom_ops.cuh
+++ b/include/cuda/gbmm_g2bmm.cuh
@@ -1,5 +1,4 @@
-#ifndef CUSTOM_OPS_CUH
-#define CUSTOM_OPS_CUH
+#pragma once
 
 #include
 
@@ -5797,6 +5796,4 @@ inline void sgbmml(float *__restrict__ q, float *__restrict__ k,
     }
 }
 
-} // namespace infini
-
-#endif // CUSTOM_OPS_CUH
+} // namespace infini
\ No newline at end of file
diff --git a/include/custom_ops.h b/include/cuda/gbmm_g2bmm.h
similarity index 82%
rename from include/custom_ops.h
rename to include/cuda/gbmm_g2bmm.h
index fa943a33..00132e79 100644
--- a/include/custom_ops.h
+++ b/include/cuda/gbmm_g2bmm.h
@@ -1,5 +1,4 @@
-#ifndef CUSTOM_OPS_H
-#define CUSTOM_OPS_H
+#pragma once
 
 namespace infini {
 
@@ -10,5 +9,3 @@ void _sgbmml(float *__restrict__ q, float *__restrict__ k,
              float *__restrict__ y, int bs, int n, int m, int w, int d);
 
 } // namespace infini
-
-#endif // CUSTOM_OPS_H
diff --git a/src/core/runtime.cc b/src/core/runtime.cc
index 449c997e..faf3d0e6 100644
--- a/src/core/runtime.cc
+++ b/src/core/runtime.cc
@@ -4,10 +4,6 @@
 #include "core/perf_engine.h"
 #include
 #include
-#include
-#include
-#include
-#include
 
 namespace infini {
 
diff --git a/src/kernels/cuda/G2BMM.cc b/src/kernels/cuda/G2BMM.cc
index 160dfe41..cb69f76a 100644
--- a/src/kernels/cuda/G2BMM.cc
+++ b/src/kernels/cuda/G2BMM.cc
@@ -1,7 +1,7 @@
 #include "operators/G2BMM.h"
 #include "cuda/cuda_kernel_wihtout_config.h"
 #include "cuda/cuda_runtime.h"
-#include "custom_ops.h"
+#include "cuda/gbmm_g2bmm.h"
 #include
 #include
 #include
diff --git a/src/kernels/cuda/GBMM.cc b/src/kernels/cuda/GBMM.cc
index 0b02a1f9..303e4553 100644
--- a/src/kernels/cuda/GBMM.cc
+++ b/src/kernels/cuda/GBMM.cc
@@ -1,7 +1,7 @@
 #include "operators/GBMM.h"
 #include "cuda/cuda_kernel_wihtout_config.h"
 #include "cuda/cuda_runtime.h"
-#include "custom_ops.h"
+#include "cuda/gbmm_g2bmm.h"
 #include
 #include
 #include
diff --git a/src/custom_ops.cu b/src/kernels/cuda/gbmm_g2bmm.cu
similarity index 86%
rename from src/custom_ops.cu
rename to src/kernels/cuda/gbmm_g2bmm.cu
index f5ac1df3..0002e021 100644
--- a/src/custom_ops.cu
+++ b/src/kernels/cuda/gbmm_g2bmm.cu
@@ -1,5 +1,5 @@
-#include "custom_ops.cuh"
-#include "custom_ops.h"
+#include "cuda/gbmm_g2bmm.cuh"
+#include "cuda/gbmm_g2bmm.h"
 
 namespace infini {
 
diff --git a/src/nnet/derivator.cc b/src/nnet/derivator.cc
index 11165934..7a2e61d4 100644
--- a/src/nnet/derivator.cc
+++ b/src/nnet/derivator.cc
@@ -20,10 +20,12 @@
 #include "nnet/Visitor/HashVisitor.h"
 #include "nnet/Visitor/MergeMemboundMutator.h"
 #include "nnet/Visitor/Serializer.h"
-#include "nnet/test.h"
 
 namespace nnet {
 
+// avoid dependence of "nnet/test.h"
+bool checkExprsEquvivalence(VecExpr exprs);
+
 class SaveStateGuard {
     Derivator &derivator;
diff --git a/src/operators/G2BMM.cc b/src/operators/G2BMM.cc
index a6f8606d..0cc22d8f 100644
--- a/src/operators/G2BMM.cc
+++ b/src/operators/G2BMM.cc
@@ -1,5 +1,4 @@
 #include "operators/G2BMM.h"
-#include "custom_ops.h"
 
 namespace infini {
 
diff --git a/src/operators/GBMM.cc b/src/operators/GBMM.cc
index 4d734593..227c48f7 100644
--- a/src/operators/GBMM.cc
+++ b/src/operators/GBMM.cc
@@ -1,5 +1,4 @@
 #include "operators/GBMM.h"
-#include "custom_ops.h"
 
 namespace infini {
 
diff --git a/test/operators/test_G2BMM.cc b/test/kernels/cuda/test_cuda_G2BMM.cc
similarity index 97%
rename from test/operators/test_G2BMM.cc
rename to test/kernels/cuda/test_cuda_G2BMM.cc
index ec8aae00..24ee5a4d 100644
--- a/test/operators/test_G2BMM.cc
+++ b/test/kernels/cuda/test_cuda_G2BMM.cc
@@ -10,7 +10,7 @@ namespace infini {
 using ExpectOutput = vector<float>;
 
-TEST(G2BMM, ShapeInference) {
+TEST(CUDA_G2BMM, ShapeInference) {
     const int bs = 1, seqlen = 10000, w = 1000, featlen = 512, heads = 8,
               d = 4;
     const int hidden = featlen, hiddenPerHead = hidden / heads;
     auto cpuRuntime = CpuRuntimeObj::getInstance();
diff --git a/test/operators/test_GBMM.cc b/test/kernels/cuda/test_cuda_GBMM.cc
similarity index 95%
rename from test/operators/test_GBMM.cc
rename to test/kernels/cuda/test_cuda_GBMM.cc
index f31fdd55..28a21f90 100644
--- a/test/operators/test_GBMM.cc
+++ b/test/kernels/cuda/test_cuda_GBMM.cc
@@ -9,7 +9,7 @@ namespace infini {
 using ExpectOutput = vector<float>;
 
-TEST(GBMM, ShapeInference) {
+TEST(CUDA_GBMM, ShapeInference) {
     const int bs = 1, seqlen = 10000, w = 1000, featlen = 512, heads = 8,
               d = 4;
     const int hidden = featlen, hiddenPerHead = hidden / heads;
     auto cpuRuntime = CpuRuntimeObj::getInstance();
diff --git a/test/kernels/cuda/test_cuda_conv.cc b/test/kernels/cuda/test_cuda_conv.cc
new file mode 100644
index 00000000..24028d18
--- /dev/null
+++ b/test/kernels/cuda/test_cuda_conv.cc
@@ -0,0 +1,79 @@
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/conv.h"
+
+#include "test.h"
+
+namespace infini {
+
+void testConvCudnn(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    vector<float> ansVec) {
+    // Construct Runtime and graph for CPU and CUDA
+    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Graph gCpu = make_ref<GraphObj>(cpu);
+    Runtime cuda = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cuda);
+    // Set input data on CPU in a CPU Graph
+    Tensor i0Cpu = gCpu->addTensor({1, 3, 4, 4}, DataType::Float32);
+    Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    gCpu->dataMalloc();
+    i0Cpu->setData(generator);
+    w0Cpu->setData(generator);
+
+    // Copy input tensors from CPU to CUDA
+    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
+    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
+    // Build CUDA graph
+    auto conv =
+        gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 2, 1, 1, 2);
+    // allocate CUDA memory
+    gCuda->dataMalloc();
+    // Execute on CUDA
+    cuda->run(gCuda);
+    // copy output from CUDA to CPU
+    auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
+    // check results on CPU
+    EXPECT_TRUE(o0Cpu->equalData(ansVec));
+    // print a tensor/operator/graph by print()
+    gCuda->print();
+}
+
+TEST(cuDNN_Conv, run) {
+    testConvCudnn(OneGenerator(),
+                  vector<float>{12, 12, 18, 18, 12, 12, 18, 18});
+    testConvCudnn(
+        IncrementalGenerator(),
+        vector<float>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
+}
+
+TEST(cuDNN_Conv, tune) {
+    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Graph gCpu = make_ref<GraphObj>(cpu);
+    Runtime cuda = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cuda);
+    // Set input data on CPU in a CPU Graph
+    Tensor i0Cpu = gCpu->addTensor({1, 3, 224, 224}, DataType::Float32);
+    Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    gCpu->dataMalloc();
+    i0Cpu->setData(IncrementalGenerator());
+    w0Cpu->setData(IncrementalGenerator());
+
+    // Copy input tensors from CPU to CUDA
+    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
+    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
+    // Build CUDA graph
+    auto conv =
+        gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1);
+    // allocate CUDA memory
+    gCuda->dataMalloc();
+    // Execute on CUDA
+    bool tune = true;
+    cuda->run(gCuda, tune);
+}
+} // namespace infini
\ No newline at end of file
diff --git a/test/kernels/cuda/test_cuda_conv_transposed_2d.cc b/test/kernels/cuda/test_cuda_conv_transposed_2d.cc
new file mode 100644
index 00000000..c8b28aea
--- /dev/null
+++ b/test/kernels/cuda/test_cuda_conv_transposed_2d.cc
@@ -0,0 +1,89 @@
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/perf_engine.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/conv.h"
+
+#include "test.h"
+
+namespace infini {
+
+void testConvTransposedCudnn(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    vector<float> ansVec) {
+    const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4};
+    const int stride = 1, padding = 0, dilation = 1;
+    // Construct Runtime and graph for CPU and CUDA
+    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Graph gCpu = make_ref<GraphObj>(cpu);
+    Runtime cuda = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cuda);
+    // Set input data on CPU in a CPU Graph
+    Tensor i0Cpu = gCpu->addTensor({N, F, H, H}, DataType::Float32);
+    Tensor w0Cpu = gCpu->addTensor({F, C, R, S}, DataType::Float32);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    gCpu->dataMalloc();
+    i0Cpu->setData(generator);
+    w0Cpu->setData(generator);
+
+    // Copy input tensors from CPU to CUDA
+    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
+    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
+    // Build CUDA graph
+    auto conv = gCuda->addOp<ConvTransposed2dObj>(i0Cuda, w0Cuda, nullptr,
+                                                  padding, padding, stride,
+                                                  stride, dilation, dilation);
+    gCuda->dataMalloc();
+    // Execute on CUDA
+    cuda->run(gCuda);
+    // copy output from CUDA to CPU
+    auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
+    // check results on CPU
+    EXPECT_TRUE(o0Cpu->equalData(ansVec));
+}
+
+TEST(cuDNN_ConvTransposed, run) {
+    testConvTransposedCudnn(IncrementalGenerator(),
+                            vector<float>{0.,  0.,  1.,  2.,  3.,  0.,  6.,
+                                          12., 18., 16., 8.,  30., 36., 42.,
+                                          32., 16., 54., 60., 66., 48., 24.,
+                                          62., 67., 72., 45.});
+}
+
+TEST(cuDNN_ConvTransposed, tune) {
+    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Graph gCpu = make_ref<GraphObj>(cpu);
+    Runtime cuda = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cuda);
+    // Set input data on CPU in a CPU Graph
+    Tensor i0Cpu = gCpu->addTensor({1, 448, 2, 2}, DataType::Float32);
+    Tensor w0Cpu = gCpu->addTensor({448, 256, 4, 4}, DataType::Float32);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    gCpu->dataMalloc();
+    i0Cpu->setData(IncrementalGenerator());
+    w0Cpu->setData(IncrementalGenerator());
+
+    // Copy input tensors from CPU to CUDA
+    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
+    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
+    // Build CUDA graph
+    auto conv = gCuda->addOp<ConvTransposed2dObj>(i0Cuda, w0Cuda, nullptr);
+    // allocate CUDA memory
+    gCuda->dataMalloc();
+    // Execute on CUDA
+    bool tune = true;
+    cuda->run(gCuda, tune);
+    // print a tensor/operator/graph by print()
+    gCuda->print();
+    // check record
+    auto kernelAttrs =
+        KernelAttrs{Device::CUDA, conv->getOpType(), DataType::Float32};
+    auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
+    std::optional<PerfRecord> perfData =
+        PerfEngine::getInstance().getPerfData(perfKey);
+    ASSERT_TRUE(perfData.has_value());
+}
+
+} // namespace infini
\ No newline at end of file
diff --git a/test/kernels/cuda/test_cuda_element_wise.cc b/test/kernels/cuda/test_cuda_element_wise.cc
new file mode 100644
index 00000000..76215b53
--- /dev/null
+++ b/test/kernels/cuda/test_cuda_element_wise.cc
@@ -0,0 +1,68 @@
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/element_wise.h"
+
+#include "test.h"
+
+namespace infini {
+
+using ExpectOutput = vector<float>;
+template <class T>
+void testElementWiseCudnn(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    const Shape &shape, const ExpectOutput &ansVec) {
+    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor acpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
+    acpu->dataMalloc();
+    acpu->setData(generator);
+
+    Tensor bcpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
+    bcpu->dataMalloc();
+    bcpu->setData(generator);
+
+    // Build CUDA graph
+    Graph g = make_ref<GraphObj>(cudaRuntime);
+    auto a = g->cloneTensor(acpu);
+    auto b = g->cloneTensor(bcpu);
+    auto op = g->addOp<T>(a, b, nullptr);
+
+    // allocate CUDA memory
+    g->dataMalloc();
+
+    // Execute on CUDA
+    cudaRuntime->run(g);
+
+    // clone CUDA output to CPU
+    auto c = op->getOutput();
+    auto ccpu = c->clone(cpuRuntime);
+    // cudaPrintTensor(c);
+    // check results on CPU
+    EXPECT_TRUE(ccpu->equalData(ansVec));
+}
+
+TEST(cuDNN_ElementWise, run) {
+    testElementWiseCudnn<AddObj>(
+        IncrementalGenerator(), Shape{1, 2, 2, 3},
+        ExpectOutput{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22});
+    testElementWiseCudnn<SubObj>(
+        IncrementalGenerator(), Shape{1, 2, 2, 3},
+        ExpectOutput{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+    testElementWiseCudnn<MulObj>(
+        IncrementalGenerator(), Shape{1, 2, 2, 3},
+        ExpectOutput{0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
+
+    testElementWiseCudnn<DivObj>(
+        OneGenerator(), Shape{1, 2, 2, 3},
+        ExpectOutput{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
+
+    testElementWiseCudnn<PowObj>(IncrementalGenerator(), Shape{1, 2, 2, 1},
+                                 ExpectOutput{1, 1, 4, 27});
+}
+
+} // namespace infini
\ No newline at end of file
diff --git a/test/kernels/cuda/test_cuda_matmul.cc b/test/kernels/cuda/test_cuda_matmul.cc
new file mode 100644
index 00000000..e0372015
--- /dev/null
+++ b/test/kernels/cuda/test_cuda_matmul.cc
@@ -0,0 +1,76 @@
+
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/matmul.h"
+
+#include "test.h"
+
+namespace infini {
+using ExpectOutput = vector<float>;
+
+void testMatmulCuda(
+    const std::function<void(void *, size_t, DataType)> &generatorA,
+    const std::function<void(void *, size_t, DataType)> &generatorB,
+    bool transA, bool transB, const Shape &shapeA, const Shape &shapeB,
+    const ExpectOutput &ansVec) {
+    auto cpuRuntime = CpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(cpuRuntime);
+    auto ACpu = gCpu->addTensor(shapeA, DataType::Float32);
+    auto BCpu = gCpu->addTensor(shapeB, DataType::Float32);
+    gCpu->dataMalloc();
+    ACpu->setData(generatorA);
+    BCpu->setData(generatorB);
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    auto gCuda = make_ref<GraphObj>(cudaRuntime);
+    auto ACuda = gCuda->cloneTensor(ACpu);
+    auto BCuda = gCuda->cloneTensor(BCpu);
+    auto matmul =
+        gCuda->addOp<MatmulObj>(ACuda, BCuda, nullptr, transA, transB);
+
+    // allocate CUDA memory
+    gCuda->dataMalloc();
+    cudaRuntime->run(gCuda);
+
+    auto CCpu = gCpu->cloneTensor(matmul->getOutput());
+    // CCpu->printData();
+    // check results on CPU
+    EXPECT_TRUE(CCpu->equalData(ansVec));
+    // print a tensor/operator/graph by print()
+    // gCuda->print();
+}
+
+TEST(cuBLAS_Matmul, run) {
+    testMatmulCuda(IncrementalGenerator(), OneGenerator(), false, false,
+                   Shape{1, 3, 5}, Shape{1, 5, 2},
+                   ExpectOutput{10, 10, 35, 35, 60, 60});
+    testMatmulCuda(IncrementalGenerator(), IncrementalGenerator(), true, false,
+                   Shape{2, 3, 4}, Shape{2, 3, 2},
+                   ExpectOutput{40, 52, 46, 61, 52, 70, 58, 79, 400, 448, 424,
+                                475, 448, 502, 472, 529});
+}
+
+TEST(cuBLAS_Matmul, tune) {
+    auto cpuRuntime = CpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(cpuRuntime);
+    auto ACpu = gCpu->addTensor(Shape{1, 3, 5}, DataType::Float32);
+    auto BCpu = gCpu->addTensor(Shape{1, 5, 2}, DataType::Float32);
+    gCpu->dataMalloc();
+    ACpu->setData(IncrementalGenerator());
+    BCpu->setData(IncrementalGenerator());
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    auto gCuda = make_ref<GraphObj>(cudaRuntime);
+    auto ACuda = gCuda->cloneTensor(ACpu);
+    auto BCuda = gCuda->cloneTensor(BCpu);
+    auto matmul = gCuda->addOp<MatmulObj>(ACuda, BCuda, nullptr);
+
+    // allocate CUDA memory
+    gCuda->dataMalloc();
+    cudaRuntime->run(gCuda, true);
+}
+
+}; // namespace infini
\ No newline at end of file
diff --git a/test/kernels/cuda/test_cuda_pooling.cc b/test/kernels/cuda/test_cuda_pooling.cc
new file mode 100644
index 00000000..077e3894
--- /dev/null
+++ b/test/kernels/cuda/test_cuda_pooling.cc
@@ -0,0 +1,61 @@
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/pooling.h"
+#include "test.h"
+
+namespace infini {
+using KDPS = vector<int>;
+using ExpectOutput = vector<float>;
+
+template <class T>
+void testPoolCudnn(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    const Shape &shape, const KDPS &kdps, const ExpectOutput &ansVec) {
+    EXPECT_TRUE(kdps.size() == 8);
+    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor i0cpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
+    i0cpu->dataMalloc();
+    i0cpu->setData(generator);
+
+    // Build CUDA graph
+    Graph g = make_ref<GraphObj>(cudaRuntime);
+    auto i0 = g->cloneTensor(i0cpu);
+    auto pool = g->addOp<T>(i0, nullptr, kdps[0], kdps[1], kdps[2], kdps[3],
+                            kdps[4], kdps[5], kdps[6], kdps[7]);
+
+    // allocate CUDA memory
+    g->dataMalloc();
+
+    // Execute on CUDA
+    cudaRuntime->run(g);
+
+    // clone CUDA output to CPU
+    auto o0 = pool->getOutput();
+    auto cpuo0 = o0->clone(cpuRuntime);
+
+    // check results on CPU
+    EXPECT_TRUE(cpuo0->equalData(ansVec));
+}
+
+TEST(cuDNN_MaxPool, run) {
+    testPoolCudnn<MaxPoolObj>(IncrementalGenerator(), Shape{1, 2, 5, 5},
+                              KDPS{3, 3, 1, 1, 1, 1, 2, 2},
+                              ExpectOutput{6, 8, 9, 16, 18, 19, 21, 23, 24, 31,
+                                           33, 34, 41, 43, 44, 46, 48, 49});
+}
+
+TEST(cuDNN_AvgPool, run) {
+    testPoolCudnn<AvgPoolObj>(
+        IncrementalGenerator(), Shape{1, 2, 5, 5},
+        KDPS{3, 3, 1, 1, 1, 1, 2, 2},
+        ExpectOutput{1.333333, 3.0000, 2.666667, 7.0000, 12.0000, 9.0000,
+                     8.0000, 13.0000, 9.333333, 12.44444, 19.666667, 13.777778,
+                     23.666667, 37.0000, 25.666667, 19.111111, 29.666667,
+                     20.444444});
+}
+
+} // namespace infini
\ No newline at end of file
diff --git a/test/operators/test_unary.cc b/test/kernels/cuda/test_cuda_unary.cc
similarity index 98%
rename from test/operators/test_unary.cc
rename to test/kernels/cuda/test_cuda_unary.cc
index 3934692f..8b463121 100644
--- a/test/operators/test_unary.cc
+++ b/test/kernels/cuda/test_cuda_unary.cc
@@ -39,7 +39,7 @@ void testUnary(const std::function<void(void *, size_t, DataType)> &generator,
     EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu));
 }
 
-TEST(Unary, CuDNN) {
+TEST(cuDNN_Unary, run) {
     testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3});
    testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3});
diff --git a/test/operators/test_conv.cc b/test/operators/test_conv.cc
index 060571cb..3b2bff4e 100644
--- a/test/operators/test_conv.cc
+++ b/test/operators/test_conv.cc
@@ -1,8 +1,6 @@
 #include "core/graph.h"
 #include "core/kernel.h"
 #include "core/runtime.h"
-#include "cuda/cuda_runtime.h"
-#include "cuda/cuda_utility.h"
 #include "operators/conv.h"
 
 #include "test.h"
@@ -68,71 +66,4 @@ TEST(Conv, NaiveCPU) {
     EXPECT_TRUE(conv->getOutput()->equalData(ans));
 }
 
-void testConvCudnn(
-    const std::function<void(void *, size_t, DataType)> &generator,
-    vector<float> ansVec) {
-    // Construct Runtime and graph for CPU and CUDA
-    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
-    Graph gCpu = make_ref<GraphObj>(cpu);
-    Runtime cuda = make_ref<CudaRuntimeObj>();
-    Graph gCuda = make_ref<GraphObj>(cuda);
-    // Set input data on CPU in a CPU Graph
-    Tensor i0Cpu = gCpu->addTensor({1, 3, 4, 4}, DataType::Float32);
-    Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32);
-    // Malloc data for all tensors in a graph. Do we need implicit allocation?
-    gCpu->dataMalloc();
-    i0Cpu->setData(generator);
-    w0Cpu->setData(generator);
-
-    // Copy input tensors from CPU to CUDA
-    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
-    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
-    // Build CUDA graph
-    auto conv =
-        gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 2, 1, 1, 2);
-    // allocate CUDA memory
-    gCuda->dataMalloc();
-    // Execute on CUDA
-    cuda->run(gCuda);
-    // copy output from CUDA to CPU
-    auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
-    // check results on CPU
-    EXPECT_TRUE(o0Cpu->equalData(ansVec));
-    // print a tensor/operator/graph by print()
-    gCuda->print();
-}
-
-TEST(Conv, cuDNN) {
-    testConvCudnn(OneGenerator(),
-                  vector<float>{12, 12, 18, 18, 12, 12, 18, 18});
-    testConvCudnn(
-        IncrementalGenerator(),
-        vector<float>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
-}
-
-TEST(Conv, tune) {
-    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
-    Graph gCpu = make_ref<GraphObj>(cpu);
-    Runtime cuda = make_ref<CudaRuntimeObj>();
-    Graph gCuda = make_ref<GraphObj>(cuda);
-    // Set input data on CPU in a CPU Graph
-    Tensor i0Cpu = gCpu->addTensor({1, 3, 224, 224}, DataType::Float32);
-    Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32);
-    // Malloc data for all tensors in a graph. Do we need implicit allocation?
-    gCpu->dataMalloc();
-    i0Cpu->setData(IncrementalGenerator());
-    w0Cpu->setData(IncrementalGenerator());
-
-    // Copy input tensors from CPU to CUDA
-    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
-    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
-    // Build CUDA graph
-    auto conv =
-        gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1);
-    // allocate CUDA memory
-    gCuda->dataMalloc();
-    // Execute on CUDA
-    bool tune = true;
-    cuda->run(gCuda, tune);
-}
 } // namespace infini
\ No newline at end of file
diff --git a/test/operators/test_conv_transposed_2d.cc b/test/operators/test_conv_transposed_2d.cc
index 1806e3a2..34a8ba17 100644
--- a/test/operators/test_conv_transposed_2d.cc
+++ b/test/operators/test_conv_transposed_2d.cc
@@ -2,8 +2,6 @@
 #include "core/kernel.h"
 #include "core/perf_engine.h"
 #include "core/runtime.h"
-#include "cuda/cuda_runtime.h"
-#include "cuda/cuda_utility.h"
 #include "operators/conv.h"
 
 #include "test.h"
@@ -36,80 +34,4 @@ TEST(ConvTransposed, ShapeInference) {
     }
 }
 
-void testConvTransposedCudnn(
-    const std::function<void(void *, size_t, DataType)> &generator,
-    vector<float> ansVec) {
-    const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4};
-    const int stride = 1, padding = 0, dilation = 1;
-    // Construct Runtime and graph for CPU and CUDA
-    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
-    Graph gCpu = make_ref<GraphObj>(cpu);
-    Runtime cuda = make_ref<CudaRuntimeObj>();
-    Graph gCuda = make_ref<GraphObj>(cuda);
-    // Set input data on CPU in a CPU Graph
-    Tensor i0Cpu = gCpu->addTensor({N, F, H, H}, DataType::Float32);
-    Tensor w0Cpu = gCpu->addTensor({F, C, R, S}, DataType::Float32);
-    // Malloc data for all tensors in a graph. Do we need implicit allocation?
-    gCpu->dataMalloc();
-    i0Cpu->setData(generator);
-    w0Cpu->setData(generator);
-
-    // Copy input tensors from CPU to CUDA
-    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
-    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
-    // Build CUDA graph
-    auto conv = gCuda->addOp<ConvTransposed2dObj>(i0Cuda, w0Cuda, nullptr,
-                                                  padding, padding, stride,
-                                                  stride, dilation, dilation);
-    gCuda->dataMalloc();
-    // Execute on CUDA
-    cuda->run(gCuda);
-    // copy output from CUDA to CPU
-    auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
-    // check results on CPU
-    EXPECT_TRUE(o0Cpu->equalData(ansVec));
-}
-
-TEST(ConvTransposed, cuDNN) {
-    testConvTransposedCudnn(IncrementalGenerator(),
-                            vector<float>{0.,  0.,  1.,  2.,  3.,  0.,  6.,
-                                          12., 18., 16., 8.,  30., 36., 42.,
-                                          32., 16., 54., 60., 66., 48., 24.,
-                                          62., 67., 72., 45.});
-}
-
-TEST(ConvTransposed, tune) {
-    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
-    Graph gCpu = make_ref<GraphObj>(cpu);
-    Runtime cuda = make_ref<CudaRuntimeObj>();
-    Graph gCuda = make_ref<GraphObj>(cuda);
-    // Set input data on CPU in a CPU Graph
-    Tensor i0Cpu = gCpu->addTensor({1, 448, 2, 2}, DataType::Float32);
-    Tensor w0Cpu = gCpu->addTensor({448, 256, 4, 4}, DataType::Float32);
-    // Malloc data for all tensors in a graph. Do we need implicit allocation?
-    gCpu->dataMalloc();
-    i0Cpu->setData(IncrementalGenerator());
-    w0Cpu->setData(IncrementalGenerator());
-
-    // Copy input tensors from CPU to CUDA
-    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
-    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
-    // Build CUDA graph
-    auto conv = gCuda->addOp<ConvTransposed2dObj>(i0Cuda, w0Cuda, nullptr);
-    // allocate CUDA memory
-    gCuda->dataMalloc();
-    // Execute on CUDA
-    bool tune = true;
-    cuda->run(gCuda, tune);
-    // print a tensor/operator/graph by print()
-    gCuda->print();
-    // check record
-    auto kernelAttrs =
-        KernelAttrs{Device::CUDA, conv->getOpType(), DataType::Float32};
-    auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
-    std::optional<PerfRecord> perfData =
-        PerfEngine::getInstance().getPerfData(perfKey);
-    ASSERT_TRUE(perfData.has_value());
-}
-
 } // namespace infini
\ No newline at end of file
diff --git a/test/operators/test_element_wise.cc b/test/operators/test_element_wise.cc
index 1bdc5f4e..59fe65a3 100644
--- a/test/operators/test_element_wise.cc
+++ b/test/operators/test_element_wise.cc
@@ -1,8 +1,6 @@
 #include "core/graph.h"
 #include "core/kernel.h"
 #include "core/runtime.h"
-#include "cuda/cuda_runtime.h"
-#include "cuda/cuda_utility.h"
 #include "operators/element_wise.h"
 
 #include "test.h"
@@ -20,103 +18,5 @@ TEST(ElementWise, ShapeInference) {
         EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
     }
 }
-/*
-template <class T>
-void test_element_wise(
-    const std::function<void(void *, size_t, DataType)> &generator,
-    const vector<uint32_t> &ans) {
-    Runtime runtime = CpuRuntimeObj::getInstance();
-    Graph g = make_ref<GraphObj>(runtime);
-    Tensor i0 = g->addTensor({1, 3, 2, 2}, DataType::UInt32);
-    Tensor i1 = g->addTensor({2, 3, 1, 2}, DataType::UInt32);
-    auto op = g->addOp<T>(i0, i1, nullptr);
-
-    g->dataMalloc();
-    i0->setData(generator);
-    i1->setData(generator);
-    runtime->run(g, true, true);
-    // check answer
-    EXPECT_TRUE(op->getOutput()->equalData(ans));
-}
-
-TEST(ElementWise, NaiveCPU) {
-    test_element_wise<AddObj>(IncrementalGenerator(),
-                              vector<uint32_t>{0, 2, 2, 4, 6, 8, 8, 10,
-                                               12, 14, 14, 16, 6, 8, 8, 10,
-                                               12, 14, 14, 16, 18, 20, 20, 22});
-    test_element_wise<SubObj>(
-        IncrementalGenerator(),
-        vector<uint32_t>{0, 0, 2, 2,
-                         2, 2, 4, 4,
-                         4, 4, 6, 6,
-                         4294967290, 4294967290, 4294967292, 4294967292,
-                         4294967292, 4294967292, 4294967294, 4294967294,
-                         4294967294, 4294967294, 0, 0});
-    test_element_wise<MulObj>(
-        IncrementalGenerator(),
-        vector<uint32_t>{0, 1, 0, 3, 8, 15, 12, 21, 32, 45, 40, 55,
-                         0, 7, 12, 21, 32, 45, 48, 63, 80, 99, 100, 121});
-    test_element_wise<DivObj>(OneGenerator(),
-                              vector<uint32_t>{
-                                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
-                                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
-                              });
-}
-*/
-
-template <class T>
-void testElementWiseCudnn(
-    const std::function<void(void *, size_t, DataType)> &generator,
-    const Shape &shape, const ExpectOutput &ansVec) {
-    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
-    auto cudaRuntime = make_ref<CudaRuntimeObj>();
-
-    // Build input data on CPU
-    Tensor acpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
-    acpu->dataMalloc();
-    acpu->setData(generator);
-
-    Tensor bcpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
-    bcpu->dataMalloc();
-    bcpu->setData(generator);
-
-    // Build CUDA graph
-    Graph g = make_ref<GraphObj>(cudaRuntime);
-    auto a = g->cloneTensor(acpu);
-    auto b = g->cloneTensor(bcpu);
-    auto op = g->addOp<T>(a, b, nullptr);
-
-    // allocate CUDA memory
-    g->dataMalloc();
-
-    // Execute on CUDA
-    cudaRuntime->run(g);
-
-    // clone CUDA output to CPU
-    auto c = op->getOutput();
-    auto ccpu = c->clone(cpuRuntime);
-    // cudaPrintTensor(c);
-    // check results on CPU
-    EXPECT_TRUE(ccpu->equalData(ansVec));
-}
-
-TEST(ElementWise, CuDNN) {
-    testElementWiseCudnn<AddObj>(
-        IncrementalGenerator(), Shape{1, 2, 2, 3},
-        ExpectOutput{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22});
-    testElementWiseCudnn<SubObj>(
-        IncrementalGenerator(), Shape{1, 2, 2, 3},
-        ExpectOutput{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
-    testElementWiseCudnn<MulObj>(
-        IncrementalGenerator(), Shape{1, 2, 2, 3},
-        ExpectOutput{0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
-
-    testElementWiseCudnn<DivObj>(
-        OneGenerator(), Shape{1, 2, 2, 3},
-        ExpectOutput{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
-
-    testElementWiseCudnn<PowObj>(IncrementalGenerator(), Shape{1, 2, 2, 1},
-                                 ExpectOutput{1, 1, 4, 27});
-}
 
 } // namespace infini
\ No newline at end of file
diff --git a/test/operators/test_matmul.cc b/test/operators/test_matmul.cc
index 97df77c2..b7b95a78 100644
--- a/test/operators/test_matmul.cc
+++ b/test/operators/test_matmul.cc
@@ -2,8 +2,6 @@
 #include "core/graph.h"
 #include "core/kernel.h"
 #include "core/runtime.h"
-#include "cuda/cuda_runtime.h"
-#include "cuda/cuda_utility.h"
 #include "operators/matmul.h"
 
 #include "test.h"
@@ -30,66 +28,5 @@ TEST(Matmul, ShapeInference) {
         EXPECT_EQ(C->getDims(), (Shape{3, 4, 2}));
     }
 }
-void testMatmulCuda(
-    const std::function<void(void *, size_t, DataType)> &generatorA,
-    const std::function<void(void *, size_t, DataType)> &generatorB,
-    bool transA, bool transB, const Shape &shapeA, const Shape &shapeB,
-    const ExpectOutput &ansVec) {
-    auto cpuRuntime = CpuRuntimeObj::getInstance();
-    Graph gCpu = make_ref<GraphObj>(cpuRuntime);
-    auto ACpu = gCpu->addTensor(shapeA, DataType::Float32);
-    auto BCpu = gCpu->addTensor(shapeB, DataType::Float32);
-    gCpu->dataMalloc();
-    ACpu->setData(generatorA);
-    BCpu->setData(generatorB);
-
-    auto cudaRuntime = make_ref<CudaRuntimeObj>();
-    auto gCuda = make_ref<GraphObj>(cudaRuntime);
-    auto ACuda = gCuda->cloneTensor(ACpu);
-    auto BCuda = gCuda->cloneTensor(BCpu);
-    auto matmul =
-        gCuda->addOp<MatmulObj>(ACuda, BCuda, nullptr, transA, transB);
-
-    // allocate CUDA memory
-    gCuda->dataMalloc();
-    cudaRuntime->run(gCuda);
-
-    auto CCpu = gCpu->cloneTensor(matmul->getOutput());
-    // CCpu->printData();
-    // check results on CPU
-    EXPECT_TRUE(CCpu->equalData(ansVec));
-    // print a tensor/operator/graph by print()
-    // gCuda->print();
-}
-
-TEST(Matmul, cuBlas) {
-    testMatmulCuda(IncrementalGenerator(), OneGenerator(), false, false,
-                   Shape{1, 3, 5}, Shape{1, 5, 2},
-                   ExpectOutput{10, 10, 35, 35, 60, 60});
-    testMatmulCuda(IncrementalGenerator(), IncrementalGenerator(), true, false,
-                   Shape{2, 3, 4}, Shape{2, 3, 2},
-                   ExpectOutput{40, 52, 46, 61, 52, 70, 58, 79, 400, 448, 424,
-                                475, 448, 502, 472, 529});
-}
-
-TEST(Matmul, tune) {
-    auto cpuRuntime = CpuRuntimeObj::getInstance();
-    Graph gCpu = make_ref<GraphObj>(cpuRuntime);
-    auto ACpu = gCpu->addTensor(Shape{1, 3, 5}, DataType::Float32);
-    auto BCpu = gCpu->addTensor(Shape{1, 5, 2}, DataType::Float32);
-    gCpu->dataMalloc();
-    ACpu->setData(IncrementalGenerator());
-    BCpu->setData(IncrementalGenerator());
-
-    auto cudaRuntime = make_ref<CudaRuntimeObj>();
-    auto gCuda = make_ref<GraphObj>(cudaRuntime);
-    auto ACuda = gCuda->cloneTensor(ACpu);
-    auto BCuda = gCuda->cloneTensor(BCpu);
-    auto matmul = gCuda->addOp<MatmulObj>(ACuda, BCuda, nullptr);
-
-    // allocate CUDA memory
-    gCuda->dataMalloc();
-    cudaRuntime->run(gCuda, true);
-}
 
 }; // namespace infini
\ No newline at end of file
diff --git a/test/operators/test_pooling.cc b/test/operators/test_pooling.cc
index 8d1c0c00..e13100c2 100644
--- a/test/operators/test_pooling.cc
+++ b/test/operators/test_pooling.cc
@@ -1,7 +1,5 @@
 #include "core/graph.h"
 #include "core/runtime.h"
-#include "cuda/cuda_runtime.h"
-#include "cuda/cuda_utility.h"
 #include "operators/pooling.h"
 #include "test.h"
 
@@ -70,53 +68,4 @@ TEST(AvgPool, NaiveCPU) {
     EXPECT_LT(perfTime, 5);
 }
 
-template <class T>
-void testPoolCudnn(
-    const std::function<void(void *, size_t, DataType)> &generator,
-    const Shape &shape, const KDPS &kdps, const ExpectOutput &ansVec) {
-    EXPECT_TRUE(kdps.size() == 8);
-    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
-    auto cudaRuntime = make_ref<CudaRuntimeObj>();
-
-    // Build input data on CPU
-    Tensor i0cpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
-    i0cpu->dataMalloc();
-    i0cpu->setData(generator);
-
-    // Build CUDA graph
-    Graph g = make_ref<GraphObj>(cudaRuntime);
-    auto i0 = g->cloneTensor(i0cpu);
-    auto pool = g->addOp<T>(i0, nullptr, kdps[0], kdps[1], kdps[2], kdps[3],
-                            kdps[4], kdps[5], kdps[6], kdps[7]);
-
-    // allocate CUDA memory
-    g->dataMalloc();
-
-    // Execute on CUDA
-    cudaRuntime->run(g);
-
-    // clone CUDA output to CPU
-    auto o0 = pool->getOutput();
-    auto cpuo0 = o0->clone(cpuRuntime);
-
-    // check results on CPU
-    EXPECT_TRUE(cpuo0->equalData(ansVec));
-}
-
-TEST(MaxPool, CuDNN) {
-    testPoolCudnn<MaxPoolObj>(IncrementalGenerator(), Shape{1, 2, 5, 5},
-                              KDPS{3, 3, 1, 1, 1, 1, 2, 2},
-                              ExpectOutput{6, 8, 9, 16, 18, 19, 21, 23, 24, 31,
-                                           33, 34, 41, 43, 44, 46, 48, 49});
-}
-
-TEST(AvgPool, CuDNN) {
-    testPoolCudnn<AvgPoolObj>(
-        IncrementalGenerator(), Shape{1, 2, 5, 5}, KDPS{3, 3, 1, 1, 1, 1, 2, 2},
-        ExpectOutput{1.333333, 3.0000, 2.666667, 7.0000, 12.0000, 9.0000,
-                     8.0000, 13.0000, 9.333333, 12.44444, 19.666667, 13.777778,
-                     23.666667, 37.0000, 25.666667, 19.111111, 29.666667,
-                     20.444444});
-}
-
 } // namespace infini
\ No newline at end of file