forked from jiuyuan/InfiniTensor
Fix CMake USE_CUDA (#36)
* Fix: build lib without cuda * Chore: rename GBMM and G2BMM files * Fix: seperate CUDA tests from operator tests * Fix: CMake CMP0104 * Chore: fix typo * Chore: remove unused headers Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
This commit is contained in:
parent
8f67a5cc76
commit
2f8f706f1c
|
@ -73,8 +73,15 @@ if(BUILD_TEST)
|
|||
include_directories(3rd-party/googletest/googletest/include)
|
||||
endif()
|
||||
|
||||
file(GLOB_RECURSE SRC src/*.cc src/*.cu)
|
||||
# Source files
|
||||
file(GLOB_RECURSE SRC src/core/*.cc src/kernels/cpu/*.cc src/nnet/*.cc src/operators/*.cc src/utils/*.cc)
|
||||
|
||||
if(USE_CUDA)
|
||||
file(GLOB_RECURSE SRC_CUDA src/cuda/*.cc src/cuda/*.cu src/kernels/cuda/*.cc src/kernels/cuda/*.cu)
|
||||
list (APPEND SRC ${SRC_CUDA})
|
||||
endif()
|
||||
|
||||
# Libraries
|
||||
add_library(InfiniTensor SHARED ${SRC})
|
||||
if(USE_PROTOBUF)
|
||||
target_link_libraries(InfiniTensor tensor_proto)
|
||||
|
@ -93,6 +100,8 @@ if(USE_CUDA)
|
|||
set(CMAKE_CUDA_HOST_COMPILER
|
||||
${CMAKE_CXX_COMPILER}
|
||||
CACHE STRING "Set cuda host compiler path")
|
||||
# CMP0104 requires CUDA_ARCHITECTURES
|
||||
set_target_properties(InfiniTensor PROPERTIES CUDA_ARCHITECTURES "70;80")
|
||||
enable_language(CUDA)
|
||||
# TODO: find_package seems unnecessary for CMake >= 3.8
|
||||
find_package(CUDA REQUIRED)
|
||||
|
@ -123,6 +132,9 @@ if(BUILD_TEST)
|
|||
if(BUILD_TEST_CORE)
|
||||
build_test(test/core/*.cc)
|
||||
build_test(test/operators/*.cc)
|
||||
if (USE_CUDA)
|
||||
build_test(test/kernels/cuda/*.cc)
|
||||
endif()
|
||||
endif()
|
||||
if(BUILD_TEST_PET)
|
||||
build_test(test/pet/*.cc)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
#include "core/common.h"
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_profiler_api.h>
|
||||
#include <cudnn.h>
|
||||
#include <curand.h>
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
#ifndef CUSTOM_OPS_CUH
|
||||
#define CUSTOM_OPS_CUH
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
|
||||
|
@ -5797,6 +5796,4 @@ inline void sgbmml(float *__restrict__ q, float *__restrict__ k,
|
|||
}
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
||||
#endif // CUSTOM_OPS_CUH
|
||||
} // namespace infini
|
|
@ -1,5 +1,4 @@
|
|||
#ifndef CUSTOM_OPS_H
|
||||
#define CUSTOM_OPS_H
|
||||
#pragma once
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
@ -10,5 +9,3 @@ void _sgbmml(float *__restrict__ q, float *__restrict__ k,
|
|||
float *__restrict__ y, int bs, int n, int m, int w, int d);
|
||||
|
||||
} // namespace infini
|
||||
|
||||
#endif // CUSTOM_OPS_H
|
|
@ -4,10 +4,6 @@
|
|||
#include "core/perf_engine.h"
|
||||
#include <chrono>
|
||||
#include <cstring>
|
||||
#include <cuda.h>
|
||||
#include <cuda_profiler_api.h>
|
||||
#include <cudnn.h>
|
||||
#include <curand.h>
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#include "operators/G2BMM.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "custom_ops.h"
|
||||
#include "cuda/gbmm_g2bmm.h"
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#include "operators/GBMM.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "custom_ops.h"
|
||||
#include "cuda/gbmm_g2bmm.h"
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#include "custom_ops.cuh"
|
||||
#include "custom_ops.h"
|
||||
#include "cuda/gbmm_g2bmm.cuh"
|
||||
#include "cuda/gbmm_g2bmm.h"
|
||||
|
||||
namespace infini {
|
||||
|
|
@ -20,10 +20,12 @@
|
|||
#include "nnet/Visitor/HashVisitor.h"
|
||||
#include "nnet/Visitor/MergeMemboundMutator.h"
|
||||
#include "nnet/Visitor/Serializer.h"
|
||||
#include "nnet/test.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
// avoid dependence of "nnet/test.h"
|
||||
bool checkExprsEquvivalence(VecExpr exprs);
|
||||
|
||||
class SaveStateGuard {
|
||||
Derivator &derivator;
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
#include "operators/G2BMM.h"
|
||||
#include "custom_ops.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
#include "operators/GBMM.h"
|
||||
#include "custom_ops.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
namespace infini {
|
||||
using ExpectOutput = vector<float>;
|
||||
|
||||
TEST(G2BMM, ShapeInference) {
|
||||
TEST(CUDA_G2BMM, ShapeInference) {
|
||||
const int bs = 1, seqlen = 10000, w = 1000, featlen = 512, heads = 8, d = 4;
|
||||
const int hidden = featlen, hiddenPerHead = hidden / heads;
|
||||
auto cpuRuntime = CpuRuntimeObj::getInstance();
|
|
@ -9,7 +9,7 @@
|
|||
namespace infini {
|
||||
using ExpectOutput = vector<float>;
|
||||
|
||||
TEST(GBMM, ShapeInference) {
|
||||
TEST(CUDA_GBMM, ShapeInference) {
|
||||
const int bs = 1, seqlen = 10000, w = 1000, featlen = 512, heads = 8, d = 4;
|
||||
const int hidden = featlen, hiddenPerHead = hidden / heads;
|
||||
auto cpuRuntime = CpuRuntimeObj::getInstance();
|
|
@ -0,0 +1,79 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/conv.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void testConvCudnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
vector<float> ansVec) {
|
||||
// Construct Runtime and graph for CPU and CUDA
|
||||
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({1, 3, 4, 4}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(generator);
|
||||
w0Cpu->setData(generator);
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv =
|
||||
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 2, 1, 1, 2);
|
||||
// allocate CUDA memory
|
||||
gCuda->dataMalloc();
|
||||
// Execute on CUDA
|
||||
cuda->run(gCuda);
|
||||
// copy output from CUDA to CPU
|
||||
auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(o0Cpu->equalData(ansVec));
|
||||
// print a tensor/operator/graph by print()
|
||||
gCuda->print();
|
||||
}
|
||||
|
||||
TEST(cuDNN_Conv, run) {
|
||||
testConvCudnn(OneGenerator(),
|
||||
vector<float>{12, 12, 18, 18, 12, 12, 18, 18});
|
||||
testConvCudnn(
|
||||
IncrementalGenerator(),
|
||||
vector<float>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
|
||||
}
|
||||
|
||||
TEST(cuDNN_Conv, tune) {
|
||||
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({1, 3, 224, 224}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(IncrementalGenerator());
|
||||
w0Cpu->setData(IncrementalGenerator());
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv =
|
||||
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1);
|
||||
// allocate CUDA memory
|
||||
gCuda->dataMalloc();
|
||||
// Execute on CUDA
|
||||
bool tune = true;
|
||||
cuda->run(gCuda, tune);
|
||||
}
|
||||
} // namespace infini
|
|
@ -0,0 +1,89 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/perf_engine.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/conv.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void testConvTransposedCudnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
vector<float> ansVec) {
|
||||
const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4};
|
||||
const int stride = 1, padding = 0, dilation = 1;
|
||||
// Construct Runtime and graph for CPU and CUDA
|
||||
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({N, F, H, H}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({F, C, R, S}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(generator);
|
||||
w0Cpu->setData(generator);
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv = gCuda->addOp<ConvTransposed2dObj>(i0Cuda, w0Cuda, nullptr,
|
||||
padding, padding, stride,
|
||||
stride, dilation, dilation);
|
||||
gCuda->dataMalloc();
|
||||
// Execute on CUDA
|
||||
cuda->run(gCuda);
|
||||
// copy output from CUDA to CPU
|
||||
auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(o0Cpu->equalData(ansVec));
|
||||
}
|
||||
|
||||
TEST(cuDNN_ConvTransposed, run) {
|
||||
testConvTransposedCudnn(IncrementalGenerator(),
|
||||
vector<float>{0., 0., 1., 2., 3., 0., 6.,
|
||||
12., 18., 16., 8., 30., 36., 42.,
|
||||
32., 16., 54., 60., 66., 48., 24.,
|
||||
62., 67., 72., 45.});
|
||||
}
|
||||
|
||||
TEST(cuDNN_ConvTransposed, tune) {
|
||||
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({1, 448, 2, 2}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({448, 256, 4, 4}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(IncrementalGenerator());
|
||||
w0Cpu->setData(IncrementalGenerator());
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv = gCuda->addOp<ConvTransposed2dObj>(i0Cuda, w0Cuda, nullptr);
|
||||
// allocate CUDA memory
|
||||
gCuda->dataMalloc();
|
||||
// Execute on CUDA
|
||||
bool tune = true;
|
||||
cuda->run(gCuda, tune);
|
||||
// print a tensor/operator/graph by print()
|
||||
gCuda->print();
|
||||
// check record
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{Device::CUDA, conv->getOpType(), DataType::Float32};
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
|
||||
std::optional<PerfRecord> perfData =
|
||||
PerfEngine::getInstance().getPerfData(perfKey);
|
||||
ASSERT_TRUE(perfData.has_value());
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,68 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/element_wise.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
using ExpectOutput = vector<float>;
|
||||
template <class T>
|
||||
void testElementWiseCudnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape, const ExpectOutput &ansVec) {
|
||||
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor acpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
acpu->dataMalloc();
|
||||
acpu->setData(generator);
|
||||
|
||||
Tensor bcpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
bcpu->dataMalloc();
|
||||
bcpu->setData(generator);
|
||||
|
||||
// Build CUDA graph
|
||||
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||
auto a = g->cloneTensor(acpu);
|
||||
auto b = g->cloneTensor(bcpu);
|
||||
auto op = g->addOp<T>(a, b, nullptr);
|
||||
|
||||
// allocate CUDA memory
|
||||
g->dataMalloc();
|
||||
|
||||
// Execute on CUDA
|
||||
cudaRuntime->run(g);
|
||||
|
||||
// clone CUDA output to CPU
|
||||
auto c = op->getOutput();
|
||||
auto ccpu = c->clone(cpuRuntime);
|
||||
// cudaPrintTensor(c);
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(ccpu->equalData(ansVec));
|
||||
}
|
||||
|
||||
TEST(cuDNN_ElementWise, run) {
|
||||
testElementWiseCudnn<AddObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22});
|
||||
testElementWiseCudnn<SubObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
testElementWiseCudnn<MulObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
|
||||
|
||||
testElementWiseCudnn<DivObj>(
|
||||
OneGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
||||
|
||||
testElementWiseCudnn<PowObj>(IncrementalGenerator(), Shape{1, 2, 2, 1},
|
||||
ExpectOutput{1, 1, 4, 27});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,76 @@
|
|||
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/matmul.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
using ExpectOutput = vector<float>;
|
||||
|
||||
void testMatmulCuda(
|
||||
const std::function<void(void *, size_t, DataType)> &generatorA,
|
||||
const std::function<void(void *, size_t, DataType)> &generatorB,
|
||||
bool transA, bool transB, const Shape &shapeA, const Shape &shapeB,
|
||||
const ExpectOutput &ansVec) {
|
||||
auto cpuRuntime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(cpuRuntime);
|
||||
auto ACpu = gCpu->addTensor(shapeA, DataType::Float32);
|
||||
auto BCpu = gCpu->addTensor(shapeB, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
ACpu->setData(generatorA);
|
||||
BCpu->setData(generatorB);
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
auto gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
auto ACuda = gCuda->cloneTensor(ACpu);
|
||||
auto BCuda = gCuda->cloneTensor(BCpu);
|
||||
auto matmul =
|
||||
gCuda->addOp<MatmulObj>(ACuda, BCuda, nullptr, transA, transB);
|
||||
|
||||
// allocate CUDA memory
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
auto CCpu = gCpu->cloneTensor(matmul->getOutput());
|
||||
// CCpu->printData();
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(CCpu->equalData(ansVec));
|
||||
// print a tensor/operator/graph by print()
|
||||
// gCuda->print();
|
||||
}
|
||||
|
||||
TEST(cuBLAS_Matmul, run) {
|
||||
testMatmulCuda(IncrementalGenerator(), OneGenerator(), false, false,
|
||||
Shape{1, 3, 5}, Shape{1, 5, 2},
|
||||
ExpectOutput{10, 10, 35, 35, 60, 60});
|
||||
testMatmulCuda(IncrementalGenerator(), IncrementalGenerator(), true, false,
|
||||
Shape{2, 3, 4}, Shape{2, 3, 2},
|
||||
ExpectOutput{40, 52, 46, 61, 52, 70, 58, 79, 400, 448, 424,
|
||||
475, 448, 502, 472, 529});
|
||||
}
|
||||
|
||||
TEST(cuBLAS_Matmul, tune) {
|
||||
auto cpuRuntime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(cpuRuntime);
|
||||
auto ACpu = gCpu->addTensor(Shape{1, 3, 5}, DataType::Float32);
|
||||
auto BCpu = gCpu->addTensor(Shape{1, 5, 2}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
ACpu->setData(IncrementalGenerator());
|
||||
BCpu->setData(IncrementalGenerator());
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
auto gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
auto ACuda = gCuda->cloneTensor(ACpu);
|
||||
auto BCuda = gCuda->cloneTensor(BCpu);
|
||||
auto matmul = gCuda->addOp<MatmulObj>(ACuda, BCuda, nullptr);
|
||||
|
||||
// allocate CUDA memory
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda, true);
|
||||
}
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,61 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/pooling.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
using KDPS = vector<int>;
|
||||
using ExpectOutput = vector<float>;
|
||||
|
||||
template <class T>
|
||||
void testPoolCudnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape, const KDPS &kdps, const ExpectOutput &ansVec) {
|
||||
EXPECT_TRUE(kdps.size() == 8);
|
||||
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor i0cpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
i0cpu->dataMalloc();
|
||||
i0cpu->setData(generator);
|
||||
|
||||
// Build CUDA graph
|
||||
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||
auto i0 = g->cloneTensor(i0cpu);
|
||||
auto pool = g->addOp<T>(i0, nullptr, kdps[0], kdps[1], kdps[2], kdps[3],
|
||||
kdps[4], kdps[5], kdps[6], kdps[7]);
|
||||
|
||||
// allocate CUDA memory
|
||||
g->dataMalloc();
|
||||
|
||||
// Execute on CUDA
|
||||
cudaRuntime->run(g);
|
||||
|
||||
// clone CUDA output to CPU
|
||||
auto o0 = pool->getOutput();
|
||||
auto cpuo0 = o0->clone(cpuRuntime);
|
||||
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(cpuo0->equalData(ansVec));
|
||||
}
|
||||
|
||||
TEST(cuDNN_MaxPool, run) {
|
||||
testPoolCudnn<MaxPoolObj>(IncrementalGenerator(), Shape{1, 2, 5, 5},
|
||||
KDPS{3, 3, 1, 1, 1, 1, 2, 2},
|
||||
ExpectOutput{6, 8, 9, 16, 18, 19, 21, 23, 24, 31,
|
||||
33, 34, 41, 43, 44, 46, 48, 49});
|
||||
}
|
||||
|
||||
TEST(cuDNN_AvgPool, run) {
|
||||
testPoolCudnn<AvgPoolObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 5, 5}, KDPS{3, 3, 1, 1, 1, 1, 2, 2},
|
||||
ExpectOutput{1.333333, 3.0000, 2.666667, 7.0000, 12.0000, 9.0000,
|
||||
8.0000, 13.0000, 9.333333, 12.44444, 19.666667, 13.777778,
|
||||
23.666667, 37.0000, 25.666667, 19.111111, 29.666667,
|
||||
20.444444});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -39,7 +39,7 @@ void testUnary(const std::function<void(void *, size_t, DataType)> &generator,
|
|||
EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu));
|
||||
}
|
||||
|
||||
TEST(Unary, CuDNN) {
|
||||
TEST(cuDNN_Unary, run) {
|
||||
testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<SoftmaxObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
|
@ -1,8 +1,6 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/conv.h"
|
||||
|
||||
#include "test.h"
|
||||
|
@ -68,71 +66,4 @@ TEST(Conv, NaiveCPU) {
|
|||
EXPECT_TRUE(conv->getOutput()->equalData(ans));
|
||||
}
|
||||
|
||||
void testConvCudnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
vector<float> ansVec) {
|
||||
// Construct Runtime and graph for CPU and CUDA
|
||||
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({1, 3, 4, 4}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(generator);
|
||||
w0Cpu->setData(generator);
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv =
|
||||
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 2, 1, 1, 2);
|
||||
// allocate CUDA memory
|
||||
gCuda->dataMalloc();
|
||||
// Execute on CUDA
|
||||
cuda->run(gCuda);
|
||||
// copy output from CUDA to CPU
|
||||
auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(o0Cpu->equalData(ansVec));
|
||||
// print a tensor/operator/graph by print()
|
||||
gCuda->print();
|
||||
}
|
||||
|
||||
TEST(Conv, cuDNN) {
|
||||
testConvCudnn(OneGenerator(),
|
||||
vector<float>{12, 12, 18, 18, 12, 12, 18, 18});
|
||||
testConvCudnn(
|
||||
IncrementalGenerator(),
|
||||
vector<float>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
|
||||
}
|
||||
|
||||
TEST(Conv, tune) {
|
||||
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({1, 3, 224, 224}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(IncrementalGenerator());
|
||||
w0Cpu->setData(IncrementalGenerator());
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv =
|
||||
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1);
|
||||
// allocate CUDA memory
|
||||
gCuda->dataMalloc();
|
||||
// Execute on CUDA
|
||||
bool tune = true;
|
||||
cuda->run(gCuda, tune);
|
||||
}
|
||||
} // namespace infini
|
|
@ -2,8 +2,6 @@
|
|||
#include "core/kernel.h"
|
||||
#include "core/perf_engine.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/conv.h"
|
||||
|
||||
#include "test.h"
|
||||
|
@ -36,80 +34,4 @@ TEST(ConvTransposed, ShapeInference) {
|
|||
}
|
||||
}
|
||||
|
||||
void testConvTransposedCudnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
vector<float> ansVec) {
|
||||
const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4};
|
||||
const int stride = 1, padding = 0, dilation = 1;
|
||||
// Construct Runtime and graph for CPU and CUDA
|
||||
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({N, F, H, H}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({F, C, R, S}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(generator);
|
||||
w0Cpu->setData(generator);
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv = gCuda->addOp<ConvTransposed2dObj>(i0Cuda, w0Cuda, nullptr,
|
||||
padding, padding, stride,
|
||||
stride, dilation, dilation);
|
||||
gCuda->dataMalloc();
|
||||
// Execute on CUDA
|
||||
cuda->run(gCuda);
|
||||
// copy output from CUDA to CPU
|
||||
auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(o0Cpu->equalData(ansVec));
|
||||
}
|
||||
|
||||
TEST(ConvTransposed, cuDNN) {
|
||||
testConvTransposedCudnn(IncrementalGenerator(),
|
||||
vector<float>{0., 0., 1., 2., 3., 0., 6.,
|
||||
12., 18., 16., 8., 30., 36., 42.,
|
||||
32., 16., 54., 60., 66., 48., 24.,
|
||||
62., 67., 72., 45.});
|
||||
}
|
||||
|
||||
TEST(ConvTransposed, tune) {
|
||||
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({1, 448, 2, 2}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({448, 256, 4, 4}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(IncrementalGenerator());
|
||||
w0Cpu->setData(IncrementalGenerator());
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv = gCuda->addOp<ConvTransposed2dObj>(i0Cuda, w0Cuda, nullptr);
|
||||
// allocate CUDA memory
|
||||
gCuda->dataMalloc();
|
||||
// Execute on CUDA
|
||||
bool tune = true;
|
||||
cuda->run(gCuda, tune);
|
||||
// print a tensor/operator/graph by print()
|
||||
gCuda->print();
|
||||
// check record
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{Device::CUDA, conv->getOpType(), DataType::Float32};
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
|
||||
std::optional<PerfRecord> perfData =
|
||||
PerfEngine::getInstance().getPerfData(perfKey);
|
||||
ASSERT_TRUE(perfData.has_value());
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -1,8 +1,6 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/element_wise.h"
|
||||
|
||||
#include "test.h"
|
||||
|
@ -20,103 +18,5 @@ TEST(ElementWise, ShapeInference) {
|
|||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
|
||||
}
|
||||
}
|
||||
/*
|
||||
template <typename T>
|
||||
void test_element_wise(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const vector<uint32_t> &ans) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i0 = g->addTensor({1, 3, 2, 2}, DataType::UInt32);
|
||||
Tensor i1 = g->addTensor({2, 3, 1, 2}, DataType::UInt32);
|
||||
auto op = g->addOp<T>(i0, i1, nullptr);
|
||||
|
||||
g->dataMalloc();
|
||||
i0->setData(generator);
|
||||
i1->setData(generator);
|
||||
runtime->run(g, true, true);
|
||||
// check answer
|
||||
EXPECT_TRUE(op->getOutput()->equalData(ans));
|
||||
}
|
||||
|
||||
TEST(ElementWise, NaiveCPU) {
|
||||
test_element_wise<AddObj>(IncrementalGenerator(),
|
||||
vector<uint32_t>{0, 2, 2, 4, 6, 8, 8, 10,
|
||||
12, 14, 14, 16, 6, 8, 8, 10,
|
||||
12, 14, 14, 16, 18, 20, 20, 22});
|
||||
test_element_wise<SubObj>(
|
||||
IncrementalGenerator(),
|
||||
vector<uint32_t>{0, 0, 2, 2,
|
||||
2, 2, 4, 4,
|
||||
4, 4, 6, 6,
|
||||
4294967290, 4294967290, 4294967292, 4294967292,
|
||||
4294967292, 4294967292, 4294967294, 4294967294,
|
||||
4294967294, 4294967294, 0, 0});
|
||||
test_element_wise<MulObj>(
|
||||
IncrementalGenerator(),
|
||||
vector<uint32_t>{0, 1, 0, 3, 8, 15, 12, 21, 32, 45, 40, 55,
|
||||
0, 7, 12, 21, 32, 45, 48, 63, 80, 99, 100, 121});
|
||||
test_element_wise<DivObj>(OneGenerator(),
|
||||
vector<uint32_t>{
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
});
|
||||
}
|
||||
*/
|
||||
|
||||
template <class T>
|
||||
void testElementWiseCudnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape, const ExpectOutput &ansVec) {
|
||||
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor acpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
acpu->dataMalloc();
|
||||
acpu->setData(generator);
|
||||
|
||||
Tensor bcpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
bcpu->dataMalloc();
|
||||
bcpu->setData(generator);
|
||||
|
||||
// Build CUDA graph
|
||||
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||
auto a = g->cloneTensor(acpu);
|
||||
auto b = g->cloneTensor(bcpu);
|
||||
auto op = g->addOp<T>(a, b, nullptr);
|
||||
|
||||
// allocate CUDA memory
|
||||
g->dataMalloc();
|
||||
|
||||
// Execute on CUDA
|
||||
cudaRuntime->run(g);
|
||||
|
||||
// clone CUDA output to CPU
|
||||
auto c = op->getOutput();
|
||||
auto ccpu = c->clone(cpuRuntime);
|
||||
// cudaPrintTensor(c);
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(ccpu->equalData(ansVec));
|
||||
}
|
||||
|
||||
TEST(ElementWise, CuDNN) {
|
||||
testElementWiseCudnn<AddObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22});
|
||||
testElementWiseCudnn<SubObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
testElementWiseCudnn<MulObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
|
||||
|
||||
testElementWiseCudnn<DivObj>(
|
||||
OneGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
||||
|
||||
testElementWiseCudnn<PowObj>(IncrementalGenerator(), Shape{1, 2, 2, 1},
|
||||
ExpectOutput{1, 1, 4, 27});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -2,8 +2,6 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/matmul.h"
|
||||
|
||||
#include "test.h"
|
||||
|
@ -30,66 +28,5 @@ TEST(Matmul, ShapeInference) {
|
|||
EXPECT_EQ(C->getDims(), (Shape{3, 4, 2}));
|
||||
}
|
||||
}
|
||||
void testMatmulCuda(
|
||||
const std::function<void(void *, size_t, DataType)> &generatorA,
|
||||
const std::function<void(void *, size_t, DataType)> &generatorB,
|
||||
bool transA, bool transB, const Shape &shapeA, const Shape &shapeB,
|
||||
const ExpectOutput &ansVec) {
|
||||
auto cpuRuntime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(cpuRuntime);
|
||||
auto ACpu = gCpu->addTensor(shapeA, DataType::Float32);
|
||||
auto BCpu = gCpu->addTensor(shapeB, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
ACpu->setData(generatorA);
|
||||
BCpu->setData(generatorB);
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
auto gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
auto ACuda = gCuda->cloneTensor(ACpu);
|
||||
auto BCuda = gCuda->cloneTensor(BCpu);
|
||||
auto matmul =
|
||||
gCuda->addOp<MatmulObj>(ACuda, BCuda, nullptr, transA, transB);
|
||||
|
||||
// allocate CUDA memory
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
auto CCpu = gCpu->cloneTensor(matmul->getOutput());
|
||||
// CCpu->printData();
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(CCpu->equalData(ansVec));
|
||||
// print a tensor/operator/graph by print()
|
||||
// gCuda->print();
|
||||
}
|
||||
|
||||
TEST(Matmul, cuBlas) {
|
||||
testMatmulCuda(IncrementalGenerator(), OneGenerator(), false, false,
|
||||
Shape{1, 3, 5}, Shape{1, 5, 2},
|
||||
ExpectOutput{10, 10, 35, 35, 60, 60});
|
||||
testMatmulCuda(IncrementalGenerator(), IncrementalGenerator(), true, false,
|
||||
Shape{2, 3, 4}, Shape{2, 3, 2},
|
||||
ExpectOutput{40, 52, 46, 61, 52, 70, 58, 79, 400, 448, 424,
|
||||
475, 448, 502, 472, 529});
|
||||
}
|
||||
|
||||
TEST(Matmul, tune) {
|
||||
auto cpuRuntime = CpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(cpuRuntime);
|
||||
auto ACpu = gCpu->addTensor(Shape{1, 3, 5}, DataType::Float32);
|
||||
auto BCpu = gCpu->addTensor(Shape{1, 5, 2}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
ACpu->setData(IncrementalGenerator());
|
||||
BCpu->setData(IncrementalGenerator());
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
auto gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
auto ACuda = gCuda->cloneTensor(ACpu);
|
||||
auto BCuda = gCuda->cloneTensor(BCpu);
|
||||
auto matmul = gCuda->addOp<MatmulObj>(ACuda, BCuda, nullptr);
|
||||
|
||||
// allocate CUDA memory
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda, true);
|
||||
}
|
||||
|
||||
}; // namespace infini
|
|
@ -1,7 +1,5 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/pooling.h"
|
||||
#include "test.h"
|
||||
|
||||
|
@ -70,53 +68,4 @@ TEST(AvgPool, NaiveCPU) {
|
|||
EXPECT_LT(perfTime, 5);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void testPoolCudnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape, const KDPS &kdps, const ExpectOutput &ansVec) {
|
||||
EXPECT_TRUE(kdps.size() == 8);
|
||||
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor i0cpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
i0cpu->dataMalloc();
|
||||
i0cpu->setData(generator);
|
||||
|
||||
// Build CUDA graph
|
||||
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||
auto i0 = g->cloneTensor(i0cpu);
|
||||
auto pool = g->addOp<T>(i0, nullptr, kdps[0], kdps[1], kdps[2], kdps[3],
|
||||
kdps[4], kdps[5], kdps[6], kdps[7]);
|
||||
|
||||
// allocate CUDA memory
|
||||
g->dataMalloc();
|
||||
|
||||
// Execute on CUDA
|
||||
cudaRuntime->run(g);
|
||||
|
||||
// clone CUDA output to CPU
|
||||
auto o0 = pool->getOutput();
|
||||
auto cpuo0 = o0->clone(cpuRuntime);
|
||||
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(cpuo0->equalData(ansVec));
|
||||
}
|
||||
|
||||
TEST(MaxPool, CuDNN) {
|
||||
testPoolCudnn<MaxPoolObj>(IncrementalGenerator(), Shape{1, 2, 5, 5},
|
||||
KDPS{3, 3, 1, 1, 1, 1, 2, 2},
|
||||
ExpectOutput{6, 8, 9, 16, 18, 19, 21, 23, 24, 31,
|
||||
33, 34, 41, 43, 44, 46, 48, 49});
|
||||
}
|
||||
|
||||
TEST(AvgPool, CuDNN) {
|
||||
testPoolCudnn<AvgPoolObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 5, 5}, KDPS{3, 3, 1, 1, 1, 1, 2, 2},
|
||||
ExpectOutput{1.333333, 3.0000, 2.666667, 7.0000, 12.0000, 9.0000,
|
||||
8.0000, 13.0000, 9.333333, 12.44444, 19.666667, 13.777778,
|
||||
23.666667, 37.0000, 25.666667, 19.111111, 29.666667,
|
||||
20.444444});
|
||||
}
|
||||
|
||||
} // namespace infini
|
Loading…
Reference in New Issue