From af08df32d2fe631529a59e77c1c7a4cfc19fdf21 Mon Sep 17 00:00:00 2001 From: zhengly123 Date: Tue, 23 Aug 2022 16:55:59 +0800 Subject: [PATCH] Extended DataType class and Runtime interaction (#9) * Add: DataType class * Add: data-type-oblivious tensor interface * Rename: copyBlobToCPU Co-authored-by: Liyan Zheng --- include/core/common.h | 5 +-- include/core/data_type.h | 34 +++++++++++++++++++ include/core/operator.h | 2 -- include/core/runtime.h | 47 +++++++++++++++++++++++--- include/core/tensor.h | 46 +++++++++++++------------- include/core/tensor_base.h | 34 +++---------------- include/cuda/cuda_runtime.h | 13 ++++++++ include/test.h | 10 ++---- src/core/blob.cc | 1 + src/core/graph.cc | 2 +- src/core/runtime.cc | 33 +++++++++++++++++++ src/core/tensor.cc | 66 +++++++++++++++++++++---------------- src/core/tensor_base.cc | 4 +-- src/cuda/cuda_runtime.cc | 2 ++ test/core/test_graph.cc | 8 ++--- test/operators/test_conv.cc | 33 ++++++++++--------- 16 files changed, 223 insertions(+), 117 deletions(-) create mode 100644 include/core/data_type.h diff --git a/include/core/common.h b/include/core/common.h index e296c331..541de97b 100644 --- a/include/core/common.h +++ b/include/core/common.h @@ -45,9 +45,10 @@ using HashType = uint64_t; // compatible with std::hash std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \ "] Assertion failed (" + #name + "): " + #info)) #define _IT_ASSERT_1(name) _IT_ASSERT_2(name, ""); - #define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__) -#define IT_TODO_HALT() IT_ASSERT(false, "Unimplemented") + +#define IT_TODO_HALT() _IT_ASSERT_2(false, "Unimplemented") +#define IT_TODO_HALT_MSG(msg) _IT_ASSERT_2(false, msg) #define IT_TODO_SKIP() puts("Unimplemented " __FILE__ ":" __LINE__) // Other utilities diff --git a/include/core/data_type.h b/include/core/data_type.h new file mode 100644 index 00000000..173600fb --- /dev/null +++ b/include/core/data_type.h @@ -0,0 +1,34 @@ +#include "core/common.h" + +namespace infini { + +class DataType { + public: + static const DataType Float32; + static const DataType UInt32; + static constexpr size_t sizePerElement[]{sizeof(float), sizeof(uint32_t)}; + static constexpr std::string_view names[]{"Float32", "UInt32"}; + + private: + int index; + + public: + constexpr DataType(int index) : index(index) {} + bool operator==(const DataType &rhs) const { return index == rhs.index; } + bool operator<(const DataType &rhs) const { return index < rhs.index; } + + template static DataType get() { + IT_TODO_HALT_MSG("Unsupported data type"); + } + size_t getSize() const { return sizePerElement[index]; } + string toString() const { return string(names[index]); } +}; + +inline const DataType DataType::Float32(0); +inline const DataType DataType::UInt32(1); +// Method definitions are out of the declaration due to GCC bug: +// https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc +template <> inline DataType DataType::get() { return Float32; } +template <> inline DataType DataType::get() { return UInt32; } + +} // namespace infini \ No newline at end of file diff --git a/include/core/operator.h b/include/core/operator.h index 3be1c57f..73a578fc 100644 --- a/include/core/operator.h +++ b/include/core/operator.h @@ -37,8 +37,6 @@ enum class OpType { MemBound = 300, }; -enum class Device { CPU = 1, CUDA }; - using KernelAttrs = std::tuple; class OpRegistry { diff --git a/include/core/runtime.h b/include/core/runtime.h index da15578e..9df2f9c3 100644 --- a/include/core/runtime.h +++ b/include/core/runtime.h @@ -1,9 +1,33 @@ #pragma once -#include "core/graph.h" -#include "core/kernel.h" -#include "core/perf_engine.h" +#include "core/common.h" +#include "core/ref.h" +#include namespace infini { +/***************** Forward declaration begin *****************/ +class TensorBaseObj; +class TensorObj; +class OperatorObj; +class GraphObj; +class RuntimeObj; +class BlobObj; + +using TensorBase = Ref; +using Tensor = Ref; +using Operator = Ref; +using Graph = Ref; +using Runtime = Ref; +using Blob = Ref; +enum class OpType; + +using TensorVec = vector; +using OpVec = vector; + +using VType = uint32_t; + +enum class Device { CPU = 1, CUDA }; +/***************** Forward declaration end *****************/ + class RuntimeObj : public std::enable_shared_from_this { protected: Device device; @@ -37,17 +61,27 @@ class RuntimeObj : public std::enable_shared_from_this { */ double getPerfTime(const Graph &graph, bool profiling = false) const; Blob allocBlob(size_t size); + bool isCpu() const { return device == Device::CPU; } + bool isCuda() const { return device == Device::CUDA; } + void copyBlob(const TensorObj *dst, const TensorObj *src) const; protected: void printProfilingData(double totTime, const std::map &opTime, const std::map &opCnt) const; + virtual void copyBlobFromCPU(void *dst, void *src, size_t bytes) const = 0; + virtual void copyBlobToCPU(void *dst, void *src, size_t bytes) const = 0; + virtual void copyBlobInsideRuntime(void *dst, void *src, + size_t bytes) const = 0; }; -// TODO: change inheritance relation class CpuRuntimeObj : public RuntimeObj { public: CpuRuntimeObj() : RuntimeObj(Device::CPU) {} + static Ref &getInstance() { + static Ref instance = make_ref(); + return instance; + } void run(const Graph &graph, bool tune = false, bool profiling = false) const override; @@ -57,6 +91,11 @@ class CpuRuntimeObj : public RuntimeObj { return calloc((size + sizeof(uint64_t) - 1) / sizeof(uint64_t), sizeof(uint64_t)); }; + + void copyBlobFromCPU(void *dst, void *src, size_t bytes) const override; + void copyBlobToCPU(void *dst, void *src, size_t bytes) const override; + void copyBlobInsideRuntime(void *dst, void *src, + size_t bytes) const override; }; } // namespace infini \ No newline at end of file diff --git a/include/core/tensor.h b/include/core/tensor.h index 67dddc95..eb8329e7 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -12,11 +12,12 @@ class TensorObj : public TensorBaseObj { Shape shape; public: - TensorObj(const Shape &shape, DataType dtype); + TensorObj(const Shape &shape, DataType dtype, Runtime runtime); virtual ~TensorObj() {} string toString() const override; size_t size() const; + size_t getBytes() const; Shape getDims() const { return shape; } @@ -24,39 +25,40 @@ class TensorObj : public TensorBaseObj { using TensorBaseObj::getData; VType getData(const Shape &pos) const; void dataMalloc(const Runtime &runtime); - // void copyData(VType *dptr); - template void copyData(const T *dptr); - void copyData(vector dataVector); - void copyData(vector dataVector); - void printData() const; - // TODO: merge these methods - bool equalData(const Tensor &rhs) const; - template bool equalData(const Tensor &rhs) const { + + template void copyData(const T *dptr) { + IT_ASSERT(DataType::get() == dtype); IT_ASSERT(data != nullptr); - IT_ASSERT(rhs->data != nullptr); - // TODO: deal with data type + if (!runtime->isCpu()) + IT_TODO_HALT(); auto ptr = data->getPtr(); - auto ptrRhs = rhs->data->getPtr(); - if (shape != rhs->getDims()) - return false; size_t sz = size(); - for (size_t i = 0; i < sz; ++i) - if (fabs(ptr[i] - ptrRhs[i]) / - std::max(fabs(ptr[i]), fabs(ptrRhs[i])) > - 1e-6) { - printf("Error on %lu: %f %f\n", i, ptr[i], ptrRhs[i]); - return false; - } - return true; +#pragma omp parallel for + for (size_t i = 0; i < sz; ++i) { + ptr[i] = dptr[i]; + } } + + template void copyData(vector dataVector) { + IT_ASSERT(DataType::get() == dtype); + IT_ASSERT(dataVector.size() >= size()); + copyData(dataVector.data()); + } + + void copyData(const Tensor &src) { runtime->copyBlob(this, src.get()); } void setData( const std::function &generator) const { generator(data->getPtr(), size(), dtype); } + void printData() const; + bool equalData(const Tensor &rhs) const; + private: void printDataFloat() const; void printDataUint32_t() const; + template bool equalDataInt(const Tensor &rhs) const; + template bool equalDataFloat(const Tensor &rhs) const; // void setDims(const Dim &dms) { dims = dms; } // bool dataRand(int seed = 0) { diff --git a/include/core/tensor_base.h b/include/core/tensor_base.h index d38458a7..ee33e662 100644 --- a/include/core/tensor_base.h +++ b/include/core/tensor_base.h @@ -1,34 +1,11 @@ #pragma once #include "core/blob.h" +#include "core/data_type.h" #include "core/object.h" -#include "core/ref.h" +#include "core/runtime.h" namespace infini { -class TensorBaseObj; -class TensorObj; -class OperatorObj; -class GraphObj; -class RuntimeObj; -class BlobObj; - -using TensorBase = Ref; -using Tensor = Ref; -using Operator = Ref; -using Graph = Ref; -using Runtime = Ref; -using Blob = Ref; - -using TensorVec = vector; -using OpVec = vector; - -using VType = uint32_t; - -enum class DataType { - Float32, - UInt32, -}; - class TensorBaseObj : public Object { public: // enum TensorType { @@ -45,12 +22,10 @@ class TensorBaseObj : public Object { vector> inputOf; WRef outputOf; Blob data; - // ComputeState computed; - // static int random_seed[256 * 16]; - // static bool random_inited; + Runtime runtime; public: - TensorBaseObj(int dim, DataType dtype); + TensorBaseObj(int dim, DataType dtype, Runtime runtime); virtual ~TensorBaseObj() {} void dataMalloc(const Blob &blob) { @@ -65,6 +40,7 @@ class TensorBaseObj : public Object { VType getData(size_t offset) const; DataType getDType() const { return dtype; } + Runtime getRuntime() const { return runtime; } // uint64_t getHash() const { return hash; } diff --git a/include/cuda/cuda_runtime.h b/include/cuda/cuda_runtime.h index 9ad15f0c..ee625699 100644 --- a/include/cuda/cuda_runtime.h +++ b/include/cuda/cuda_runtime.h @@ -43,6 +43,19 @@ class CudaRuntimeObj : public RuntimeObj { return workspace; } + void copyBlobFromCPU(void *dst, void *src, size_t bytes) const override { + checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyHostToDevice)); + } + + void copyBlobToCPU(void *dst, void *src, size_t bytes) const override { + checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost)); + } + + void copyBlobInsideRuntime(void *dst, void *src, + size_t bytes) const override { + checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToDevice)); + } + private: void runWithoutSync(const Graph &graph) const; }; diff --git a/include/test.h b/include/test.h index 35131fae..052b6abd 100644 --- a/include/test.h +++ b/include/test.h @@ -14,16 +14,12 @@ class DataGenerator { public: virtual ~DataGenerator() {} void operator()(void *data, size_t size, DataType dataType) { - switch (dataType) { - case DataType::UInt32: + if (dataType == DataType::UInt32) fill(reinterpret_cast(data), size); - break; - case DataType::Float32: + else if (dataType == DataType::Float32) fill(reinterpret_cast(data), size); - break; - default: + else IT_TODO_HALT(); - } } }; diff --git a/src/core/blob.cc b/src/core/blob.cc index 72b00be4..a5a71f30 100644 --- a/src/core/blob.cc +++ b/src/core/blob.cc @@ -1,3 +1,4 @@ +#include "core/blob.h" #include "core/runtime.h" namespace infini { diff --git a/src/core/graph.cc b/src/core/graph.cc index 13e4678f..0ac489c4 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -19,7 +19,7 @@ void GraphObj::dataMalloc() { } Tensor GraphObj::addTensor(Shape dim, DataType dtype) { - Tensor tensor = make_ref(dim, dtype); + Tensor tensor = make_ref(dim, dtype, runtime); tensors.emplace_back(tensor); return tensor; } diff --git a/src/core/runtime.cc b/src/core/runtime.cc index 38e3cc46..4b7e58f6 100644 --- a/src/core/runtime.cc +++ b/src/core/runtime.cc @@ -1,6 +1,9 @@ #include "core/runtime.h" #include "core/blob.h" +#include "core/kernel.h" +#include "core/perf_engine.h" #include +#include #include #include #include @@ -112,4 +115,34 @@ Blob RuntimeObj::allocBlob(size_t size) { return make_ref(shared_from_this(), alloc(size)); } +void RuntimeObj::copyBlob(const TensorObj *dst, const TensorObj *src) const { + void *dstPtr = dst->getDataRawPtr(); + void *srcPtr = src->getDataRawPtr(); + size_t bytes = dst->getBytes(); + auto dstRuntime = dst->getRuntime(); + auto srcRuntime = src->getRuntime(); + + if (dstRuntime.get() == srcRuntime.get()) { + dstRuntime->copyBlobInsideRuntime(dstPtr, srcPtr, bytes); + } else if (src->getRuntime()->isCpu()) { + dstRuntime->copyBlobFromCPU(dstPtr, srcPtr, bytes); + } else if (dst->getRuntime()->isCpu()) { + srcRuntime->copyBlobToCPU(dstPtr, srcPtr, bytes); + } else + IT_TODO_HALT(); +} + +void CpuRuntimeObj::copyBlobFromCPU(void *dst, void *src, size_t bytes) const { + copyBlobInsideRuntime(dst, src, bytes); +} + +void CpuRuntimeObj::copyBlobToCPU(void *dst, void *src, size_t bytes) const { + copyBlobInsideRuntime(dst, src, bytes); +} + +void CpuRuntimeObj::copyBlobInsideRuntime(void *dst, void *src, + size_t bytes) const { + memcpy(dst, src, bytes); +} + } // namespace infini \ No newline at end of file diff --git a/src/core/tensor.cc b/src/core/tensor.cc index ec6991ae..11be19d7 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -4,8 +4,8 @@ namespace infini { -TensorObj::TensorObj(const Shape &shape, DataType dtype) - : TensorBaseObj(shape.size(), dtype), shape(shape) {} +TensorObj::TensorObj(const Shape &shape, DataType dtype, Runtime runtime) + : TensorBaseObj(shape.size(), dtype, runtime), shape(shape) {} VType TensorObj::getData(const Shape &pos) const { return getData(getOffset(pos)); @@ -34,29 +34,12 @@ size_t TensorObj::size() const { return ret; } -template void TensorObj::copyData(const T *dptr) { - // TODO: cuda - IT_ASSERT(data != nullptr); - auto ptr = data->getPtr(); - size_t sz = size(); -#pragma omp parallel for - for (size_t i = 0; i < sz; ++i) { - ptr[i] = dptr[i]; - } -} - -void TensorObj::copyData(vector dataVector) { - IT_ASSERT(dataVector.size() >= size()); - copyData(dataVector.data()); -} - -void TensorObj::copyData(vector dataVector) { - IT_ASSERT(dataVector.size() >= size()); - copyData(dataVector.data()); -} +size_t TensorObj::getBytes() const { return size() * dtype.getSize(); } void TensorObj::printData() const { IT_ASSERT(data != nullptr); + if (!runtime->isCpu()) + IT_TODO_HALT(); if (dtype == DataType::Float32) printDataFloat(); else if (dtype == DataType::UInt32) @@ -120,12 +103,9 @@ void TensorObj::printDataUint32_t() const { } } -bool TensorObj::equalData(const Tensor &rhs) const { - IT_ASSERT(data != nullptr); - IT_ASSERT(rhs->data != nullptr); - // TODO: deal with data type - auto ptr = data->getPtr(); - auto ptrRhs = rhs->data->getPtr(); +template bool TensorObj::equalDataInt(const Tensor &rhs) const { + auto ptr = data->getPtr(); + auto ptrRhs = rhs->data->getPtr(); if (shape != rhs->getDims()) return false; size_t sz = size(); @@ -135,6 +115,36 @@ bool TensorObj::equalData(const Tensor &rhs) const { return true; } +template bool TensorObj::equalDataFloat(const Tensor &rhs) const { + IT_ASSERT(data != nullptr); + IT_ASSERT(rhs->data != nullptr); + // TODO: deal with data type + auto ptr = data->getPtr(); + auto ptrRhs = rhs->data->getPtr(); + if (shape != rhs->getDims()) + return false; + size_t sz = size(); + for (size_t i = 0; i < sz; ++i) + if (fabs(ptr[i] - ptrRhs[i]) / std::max(fabs(ptr[i]), fabs(ptrRhs[i])) > + 1e-6) { + printf("Error on %lu: %f %f\n", i, ptr[i], ptrRhs[i]); + return false; + } + return true; +} + +bool TensorObj::equalData(const Tensor &rhs) const { + IT_ASSERT(data != nullptr); + IT_ASSERT(rhs->data != nullptr); + IT_ASSERT(getDType() == rhs->getDType()); + if (getDType() == DataType::UInt32) + return equalDataInt(rhs); + else if (getDType() == DataType::Float32) + return equalDataInt(rhs); + else + IT_TODO_HALT(); +} + void TensorObj::dataMalloc(const Runtime &runtime) { IT_ASSERT(data == nullptr); size_t bytesPerElement; diff --git a/src/core/tensor_base.cc b/src/core/tensor_base.cc index 9414361b..98409322 100644 --- a/src/core/tensor_base.cc +++ b/src/core/tensor_base.cc @@ -3,8 +3,8 @@ #include "core/runtime.h" namespace infini { -TensorBaseObj::TensorBaseObj(int dim, DataType dtype) - : dim(dim), dtype(dtype) {} +TensorBaseObj::TensorBaseObj(int dim, DataType dtype, Runtime runtime) + : dim(dim), dtype(dtype), runtime(runtime) {} VType TensorBaseObj::getData(size_t offset) const { // TODO: check cuda array diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc index 59fcabe2..1e91db50 100644 --- a/src/cuda/cuda_runtime.cc +++ b/src/cuda/cuda_runtime.cc @@ -1,4 +1,6 @@ #include "cuda/cuda_runtime.h" +#include "core/kernel.h" +#include "core/perf_engine.h" namespace infini { diff --git a/test/core/test_graph.cc b/test/core/test_graph.cc index ad1752f8..391acd3d 100644 --- a/test/core/test_graph.cc +++ b/test/core/test_graph.cc @@ -7,7 +7,7 @@ namespace infini { TEST(Graph, build_and_run) { - Runtime runtime = make_ref(); + Runtime runtime = CpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); @@ -18,14 +18,14 @@ TEST(Graph, build_and_run) { g->addOpWithOutputs(i0, w0, o0); runtime->run(g); // check answer - auto ans = make_ref(Shape{1, 2, 4}, DataType::UInt32); + auto ans = make_ref(Shape{1, 2, 4}, DataType::UInt32, runtime); ans->dataMalloc(runtime); ans->copyData(vector{38, 44, 50, 56, 83, 98, 113, 128}); EXPECT_TRUE(o0->equalData(ans)); } TEST(Graph, perf_engine) { - Runtime runtime = make_ref(); + Runtime runtime = CpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); @@ -40,7 +40,7 @@ TEST(Graph, perf_engine) { EXPECT_GT(perfTime, 0); EXPECT_LT(perfTime, 0.01); // check answer - auto ans = make_ref(Shape{1, 2, 4}, DataType::UInt32); + auto ans = make_ref(Shape{1, 2, 4}, DataType::UInt32, runtime); ans->dataMalloc(runtime); ans->copyData(vector{38, 44, 50, 56, 83, 98, 113, 128}); EXPECT_TRUE(matmul->getOutput()->equalData(ans)); diff --git a/test/operators/test_conv.cc b/test/operators/test_conv.cc index f5c6f7d9..f528bb0a 100644 --- a/test/operators/test_conv.cc +++ b/test/operators/test_conv.cc @@ -8,7 +8,7 @@ namespace infini { TEST(Conv, ShapeInference) { - auto runtime = make_ref(); + Runtime runtime = CpuRuntimeObj::getInstance(); // Padding modes { Graph g = make_ref(runtime); @@ -43,7 +43,7 @@ TEST(Conv, ShapeInference) { } TEST(Conv, NaiveCPU) { - auto runtime = make_ref(); + Runtime runtime = CpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::UInt32); Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::UInt32); @@ -58,7 +58,8 @@ TEST(Conv, NaiveCPU) { EXPECT_GT(perfTime, 0); EXPECT_LT(perfTime, 0.1); // check answer - auto ans = make_ref(Shape{1, 2, 2, 2}, DataType::UInt32); + auto ans = + make_ref(Shape{1, 2, 2, 2}, DataType::UInt32, runtime); ans->dataMalloc(runtime); ans->copyData( vector{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656}); @@ -68,7 +69,7 @@ TEST(Conv, NaiveCPU) { void testConvCudnn( const std::function &generator, vector ansVec) { - auto cpuRuntime = make_ref(); + Runtime cpuRuntime = CpuRuntimeObj::getInstance(); auto cudaRuntime = make_ref(); // Build CUDA graph Graph g = make_ref(cudaRuntime); @@ -80,23 +81,24 @@ void testConvCudnn( g->dataMalloc(); // Build input and output data on CPU - auto cpui0 = make_ref(Shape{1, 3, 4, 4}, DataType::Float32); + auto cpui0 = + make_ref(Shape{1, 3, 4, 4}, DataType::Float32, cpuRuntime); cpui0->dataMalloc(cpuRuntime); cpui0->setData(generator); - auto cpuw0 = make_ref(Shape{2, 3, 3, 3}, DataType::Float32); + auto cpuw0 = + make_ref(Shape{2, 3, 3, 3}, DataType::Float32, cpuRuntime); cpuw0->dataMalloc(cpuRuntime); cpuw0->setData(generator); - auto ans = make_ref(Shape{1, 2, 2, 2}, DataType::Float32); + auto ans = + make_ref(Shape{1, 2, 2, 2}, DataType::Float32, cpuRuntime); ans->dataMalloc(cpuRuntime); ans->copyData(ansVec); // Copy inputs from CPU to CUDA - cudaMemcpy(i0->getDataRawPtr(), cpui0->getDataRawPtr(), - cpui0->size() * sizeof(float), cudaMemcpyHostToDevice); - cudaMemcpy(w0->getDataRawPtr(), cpuw0->getDataRawPtr(), - cpuw0->size() * sizeof(float), cudaMemcpyHostToDevice); + i0->copyData(cpui0); + w0->copyData(cpuw0); // Execute on CUDA cudaRuntime->run(g); // double perfTime = cudaRuntime->getPerfTime(g); @@ -106,14 +108,13 @@ void testConvCudnn( // copy CUDA output to CPU auto o0 = conv->getOutput(); - auto cpuo0 = make_ref(Shape{1, 2, 2, 2}, DataType::Float32); + auto cpuo0 = + make_ref(Shape{1, 2, 2, 2}, DataType::Float32, cpuRuntime); cpuo0->dataMalloc(cpuRuntime); - cudaMemcpy(cpuo0->getDataRawPtr(), - conv->getOutput()->getDataRawPtr(), - cpuo0->size() * sizeof(float), cudaMemcpyDeviceToHost); + cpuo0->copyData(o0); // check results on CPU - EXPECT_TRUE(cpuo0->equalData(ans)); + EXPECT_TRUE(cpuo0->equalData(ans)); } TEST(Conv, cuDNN) {