From beba9c16c4f875cac74ae3d7f6fd73e8c38f48e5 Mon Sep 17 00:00:00 2001 From: wanghailu Date: Fri, 21 Oct 2022 14:46:42 +0800 Subject: [PATCH] add code for test resnet --- test/kernels/cuda/test_compare.cc | 74 +++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 test/kernels/cuda/test_compare.cc diff --git a/test/kernels/cuda/test_compare.cc b/test/kernels/cuda/test_compare.cc new file mode 100644 index 00000000..f12b27a9 --- /dev/null +++ b/test/kernels/cuda/test_compare.cc @@ -0,0 +1,74 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/element_wise.h" + +#include "test.h" + +namespace infini { + +using ExpectOutput = vector; +template +void testResnet( + const std::function &generator, + const Shape &shape) { + Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + auto cudaRuntime = make_ref(); + + // Build input data on CPU + Tensor acpu = make_ref(shape, DataType::Float32, cpuRuntime); + acpu->dataMalloc(); + acpu->setData(generator); + + Tensor bcpu = make_ref(shape, DataType::Float32, cpuRuntime); + bcpu->dataMalloc(); + bcpu->setData(generator); + + Tensor ccpu = make_ref(shape, DataType::Float32, cpuRuntime); + ccpu->dataMalloc(); + ccpu->setData(generator); + + Graph g = make_ref(cudaRuntime); + Graph cg = make_ref(cpuRuntime); + + auto a = g->cloneTensor(acpu); + auto b = g->cloneTensor(bcpu); + auto c = g->cloneTensor(ccpu); + + auto op = g->addOpWithOutputs(a, b, c); + op = g->addOpWithOutputs(c, b, c); + auto cop = cg->addOpWithOutputs(acpu, bcpu, ccpu); + cop = cg->addOpWithOutputs(ccpu, bcpu, ccpu); + + // allocate CUDA memory + g->dataMalloc(); + cg->dataMalloc(); + + // Execute on CUDA + cudaRuntime->run(g); + cpuRuntime->run(cg); + + // clone CUDA output to CPU + auto gpu2cpu = c->clone(cpuRuntime); + // cudaPrintTensor(c); + // check results on CPU + ccpu->printData(); + EXPECT_TRUE(gpu2cpu->equalData(ccpu)); +} + +TEST(cuDNN_ElementWise, run) { + testResnet( + IncrementalGenerator(), Shape{1, 2, 2, 3}); + testResnet( + IncrementalGenerator(), Shape{1, 2, 2, 3}); + testResnet( + IncrementalGenerator(), Shape{1, 2, 2, 3}); + // testResnet( + // IncrementalGenerator(), Shape{1, 2, 2, 3}); + // testResnet( + // IncrementalGenerator(), Shape{1, 2, 2, 3}); +} + +} // namespace infini