From 93f86d3f4dbf10280fe78d7c329d06e0680735cf Mon Sep 17 00:00:00 2001
From: zhengly123
Date: Thu, 25 Aug 2022 11:29:16 +0800
Subject: [PATCH] Simplify tensor transfer between CPU and CUDA (#10)

* Add: OP infers data type & Graph clones tensor

* Fix: vecToString format

* Add: static assert for Tensor methods

* Rename: getDataRawPtr -> getRawDataPtr

Co-authored-by: Liyan Zheng
---
 include/core/common.h       |  5 +--
 include/core/graph.h        |  6 ++++
 include/core/operator.h     |  2 ++
 include/core/tensor.h       | 31 ++++++++++++++---
 include/core/tensor_base.h  |  6 ++--
 include/cuda/cuda_utility.h |  2 +-
 include/operators/conv.h    |  2 +-
 include/operators/matmul.h  |  2 +-
 src/core/graph.cc           |  6 +++-
 src/core/operator.cc        | 14 +++++++-
 src/core/runtime.cc         |  4 +--
 src/core/tensor.cc          | 56 +++++++++++-------------
 src/kernels/cpu/conv.cc     |  6 ++--
 src/kernels/cpu/matmul.cc   |  7 ++--
 src/kernels/cuda/conv.cc    | 10 +++---
 src/operators/conv.cc       | 17 +++++----
 src/operators/matmul.cc     |  6 ++--
 test/core/test_graph.cc     |  4 +--
 test/operators/test_conv.cc | 69 ++++++++++++++-----------------
 19 files changed, 137 insertions(+), 118 deletions(-)

diff --git a/include/core/common.h b/include/core/common.h
index 541de97b..7effff81 100644
--- a/include/core/common.h
+++ b/include/core/common.h
@@ -63,9 +63,10 @@ template <typename T> std::string vecToString(const std::vector<T> &vec) {
     ret.append("[");
     for (auto d : vec) {
         ret.append(std::to_string(d));
-        ret.append(", ");
+        ret.append(",");
     }
-    ret.pop_back();
+    if (!vec.empty())
+        ret.pop_back();
     ret.append("]");
     return ret;
 }
diff --git a/include/core/graph.h b/include/core/graph.h
index 3e159792..cd675f58 100644
--- a/include/core/graph.h
+++ b/include/core/graph.h
@@ -17,6 +17,12 @@ class GraphObj : public Object {
     string toString() const override;
 
     Tensor addTensor(Shape dim, DataType dtype = DataType::UInt32);
+    Tensor cloneTensor(const Tensor &tensor) {
+        auto ret = addTensor(tensor->getDims(), tensor->getDType());
+        ret->dataMalloc();
+        ret->copyData(tensor);
+        return ret;
+    }
 
     /**
      * @brief Add an operator and create its outputs. Output tensor arguments
diff --git a/include/core/operator.h b/include/core/operator.h
index 73a578fc..c9523bf8 100644
--- a/include/core/operator.h
+++ b/include/core/operator.h
@@ -138,6 +138,7 @@ class OperatorObj : public Object {
         : type(opType), inputs(inputs), outputs(outputs) {}
     virtual optional<vector<Shape>>
     inferShape(const TensorVec &inputs) const = 0;
+    virtual vector<DataType> inferDataType(const TensorVec &inputs) const;
     /**
      * @brief Constructs outputs (if required) and checks whether the operator
      * is valid.
@@ -180,6 +181,7 @@ class OperatorObj : public Object {
 
   protected:
     optional<vector<Shape>> inferShape() const;
+    vector<DataType> inferDataType() const;
 
   private:
     /**
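Usage sketch (illustrative, not part of the patch): the two additions above
work together. checkValid() can now ask inferDataType() for output dtypes, and
cloneTensor() reproduces a tensor (shape, dtype, and data) inside another
graph, which may live on another runtime. A minimal sketch, assuming the CPU
runtime singleton from src/core/runtime.cc:

    Runtime runtime = CpuRuntimeObj::getInstance();
    Graph g = make_ref<GraphObj>(runtime);
    Tensor src = g->addTensor({2, 3}, DataType::Float32);
    g->dataMalloc(); // allocate src before cloning
    // cloneTensor adds a tensor of the same shape/dtype to g, allocates it,
    // and copies the data; with a CUDA graph this becomes a host/device copy.
    Tensor dst = g->cloneTensor(src);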
diff --git a/include/core/tensor.h b/include/core/tensor.h
index eb8329e7..05925ec7 100644
--- a/include/core/tensor.h
+++ b/include/core/tensor.h
@@ -24,7 +24,7 @@ class TensorObj : public TensorBaseObj {
     size_t getOffset(const Shape &ds) const;
     using TensorBaseObj::getData;
     VType getData(const Shape &pos) const;
-    void dataMalloc(const Runtime &runtime);
+    void dataMalloc();
 
     template <typename T> void copyData(const T *dptr) {
         IT_ASSERT(DataType::get<T>() == dtype);
@@ -45,7 +45,8 @@ class TensorObj : public TensorBaseObj {
         copyData(dataVector.data());
     }
 
-    void copyData(const Tensor &src) { runtime->copyBlob(this, src.get()); }
+    void copyData(const TensorObj *src);
+    void copyData(const Tensor &src) { copyData(src.get()); }
     void setData(
         const std::function<void(void *, size_t, DataType)> &generator) const {
         generator(data->getPtr<void *>(), size(), dtype);
@@ -54,11 +55,33 @@ class TensorObj : public TensorBaseObj {
     void printData() const;
     bool equalData(const Tensor &rhs) const;
 
+    template <typename T> bool equalData(const vector<T> &dataVector) {
+        IT_ASSERT(DataType::get<T>() == dtype);
+        IT_ASSERT(size() == dataVector.size());
+        return equalDataImpl(getRawDataPtr<T *>(), dataVector.data(), size());
+    }
+
   private:
     void printDataFloat() const;
     void printDataUint32_t() const;
-    template <typename T> bool equalDataInt(const Tensor &rhs) const;
-    template <typename T> bool equalDataFloat(const Tensor &rhs) const;
+
+    template <typename T>
+    bool equalDataImpl(const T *a, const T *b, size_t size) const {
+        for (size_t i = 0; i < size; ++i) {
+            if constexpr (std::is_integral_v<T>) {
+                if (a[i] != b[i])
+                    return false;
+            } else if constexpr (std::is_floating_point_v<T>) {
+                if (fabs(a[i] - b[i]) / std::max(fabs(a[i]), fabs(b[i])) >
+                    1e-6) {
+                    printf("Error on %lu: %f %f\n", i, a[i], b[i]);
+                    return false;
+                }
+            } else
+                static_assert(!sizeof(T), "Unsupported data type");
+        }
+        return true;
+    }
 
     // void setDims(const Dim &dms) { dims = dms; }
 
     // bool dataRand(int seed = 0) {
diff --git a/include/core/tensor_base.h b/include/core/tensor_base.h
index ee33e662..b6118477 100644
--- a/include/core/tensor_base.h
+++ b/include/core/tensor_base.h
@@ -32,8 +32,10 @@ class TensorBaseObj : public Object {
         IT_ASSERT(data == nullptr);
         data = blob;
     }
-    Blob getDataPtr() const { return data; }
-    template <typename T> T getDataRawPtr() const {
+    Blob getDataBlob() const { return data; }
+    template <typename T> T getRawDataPtr() const {
+        static_assert(std::is_pointer_v<T>,
+                      "Raw data pointer has a type of pointer");
         IT_ASSERT(data != nullptr);
         return data->getPtr<T>();
     }
diff --git a/include/cuda/cuda_utility.h b/include/cuda/cuda_utility.h
index 354a38ff..85e3478b 100644
--- a/include/cuda/cuda_utility.h
+++ b/include/cuda/cuda_utility.h
@@ -5,7 +5,7 @@ namespace infini {
 void cudaPrintFloat(float *x, int len);
 
 void cudaPrintTensor(const Tensor &tensor) {
-    cudaPrintFloat(tensor->getDataRawPtr<float *>(), tensor->size());
+    cudaPrintFloat(tensor->getRawDataPtr<float *>(), tensor->size());
 }
 } // namespace infini
\ No newline at end of file
diff --git a/include/operators/conv.h b/include/operators/conv.h
index a2b80815..841d1351 100644
--- a/include/operators/conv.h
+++ b/include/operators/conv.h
@@ -36,7 +36,7 @@ class ConvObj : public OperatorObj {
 
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
     std::string toString() const override;
-    int numInputs() const override { return 3; }
+    int numInputs() const override { return 2; }
     int numOutputs() const override { return 1; }
 
     Tensor getBias() const { return inputs[2]; }
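Usage sketch (illustrative, not part of the patch): the new vector overload of
equalData and the stricter getRawDataPtr in use. Note in passing that conv.h
now reports numInputs() == 2 while getBias() still indexes inputs[2]; since
bias is no longer stored among the inputs, getBias() is left dangling after
this change.

    Graph g = make_ref<GraphObj>(CpuRuntimeObj::getInstance());
    Tensor t = g->addTensor({2, 2}, DataType::UInt32);
    g->dataMalloc();
    t->copyData(vector<uint32_t>{1, 2, 3, 4});
    // equalData<T> checks DataType::get<T>() against the tensor's dtype, then
    // compares element-wise: exact for integers, relative 1e-6 for floats.
    IT_ASSERT(t->equalData(vector<uint32_t>{1, 2, 3, 4}));
    // getRawDataPtr must be instantiated with a pointer type; anything else
    // now fails the static_assert at compile time.
    uint32_t *p = t->getRawDataPtr<uint32_t *>();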
diff --git a/include/operators/matmul.h b/include/operators/matmul.h
index 328756b0..40a9af9f 100644
--- a/include/operators/matmul.h
+++ b/include/operators/matmul.h
@@ -33,7 +33,7 @@ class MatmulObj : public OperatorObj {
     std::string toString() const override;
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
 
-    int numInputs() const override { return 3; }
+    int numInputs() const override { return 2; }
     int numOutputs() const override { return 1; }
 
     Tensor getBias() const { return inputs[2]; }
diff --git a/src/core/graph.cc b/src/core/graph.cc
index 0ac489c4..0707920c 100644
--- a/src/core/graph.cc
+++ b/src/core/graph.cc
@@ -6,6 +6,10 @@ void GraphObj::updateConnection() { IT_TODO_HALT(); }
 
 string GraphObj::toString() const {
     std::ostringstream oss;
+    oss << "Graph Tensors:\n";
+    for (const auto &tensor : tensors)
+        oss << tensor << "\n";
+
     oss << "Graph operators:\n";
     for (const auto &op : ops)
         oss << op << "\n";
@@ -14,7 +18,7 @@ string GraphObj::toString() const {
 
 void GraphObj::dataMalloc() {
     for (auto &tensor : tensors) {
-        tensor->dataMalloc(runtime);
+        tensor->dataMalloc();
     }
 }
diff --git a/src/core/operator.cc b/src/core/operator.cc
index e81c004b..ec7df9b1 100644
--- a/src/core/operator.cc
+++ b/src/core/operator.cc
@@ -57,9 +57,10 @@ bool OperatorObj::checkValid(GraphObj *graph) {
     if (shapes.size() != outputs.size())
         return false;
     if (graph) { // if graph != nullptr, outputs should be created
+        auto dataTypes = inferDataType();
        for (size_t i = 0; i < outputs.size(); i++) {
             IT_ASSERT(!outputs[i]);
-            outputs[i] = graph->addTensor(shapes[i]);
+            outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
         }
     } else { // if graph is not empty, check outputs match inferred shapes
         for (size_t i = 0; i < shapes.size(); ++i) {
@@ -74,4 +75,15 @@ optional<vector<Shape>> OperatorObj::inferShape() const {
     return inferShape(inputs);
 }
 
+vector<DataType> OperatorObj::inferDataType(const TensorVec &inputs) const {
+    auto dataType = inputs[0]->getDType();
+    for (const auto &tensor : inputs)
+        IT_ASSERT(dataType == tensor->getDType());
+    return vector(numOutputs(), dataType);
+}
+
+vector<DataType> OperatorObj::inferDataType() const {
+    return inferDataType(inputs);
+}
+
 } // namespace infini
\ No newline at end of file
diff --git a/src/core/runtime.cc b/src/core/runtime.cc
index 4b7e58f6..a97ef48b 100644
--- a/src/core/runtime.cc
+++ b/src/core/runtime.cc
@@ -116,8 +116,8 @@ Blob RuntimeObj::allocBlob(size_t size) {
 }
 
 void RuntimeObj::copyBlob(const TensorObj *dst, const TensorObj *src) const {
-    void *dstPtr = dst->getDataRawPtr<void *>();
-    void *srcPtr = src->getDataRawPtr<void *>();
+    void *dstPtr = dst->getRawDataPtr<void *>();
+    void *srcPtr = src->getRawDataPtr<void *>();
     size_t bytes = dst->getBytes();
     auto dstRuntime = dst->getRuntime();
     auto srcRuntime = src->getRuntime();
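Usage sketch (illustrative, not part of the patch): what the default
inferDataType above buys. Every output inherits the common input dtype, and
mixed input dtypes trip the IT_ASSERT, so operator outputs no longer fall back
to addTensor's UInt32 default. Shapes below are made up, and MatmulObj's
default transpose/bias arguments are assumed:

    Runtime runtime = CpuRuntimeObj::getInstance();
    Graph g = make_ref<GraphObj>(runtime);
    Tensor a = g->addTensor({1, 2, 3}, DataType::Float32);
    Tensor b = g->addTensor({1, 3, 4}, DataType::Float32);
    // checkValid -> inferDataType: the output tensor is created as Float32.
    auto matmul = g->addOp<MatmulObj>(a, b, nullptr);
    IT_ASSERT(matmul->getOutput()->getDType() == DataType::Float32);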
diff --git a/src/core/tensor.cc b/src/core/tensor.cc
index 11be19d7..18fbdf3b 100644
--- a/src/core/tensor.cc
+++ b/src/core/tensor.cc
@@ -11,7 +11,9 @@ VType TensorObj::getData(const Shape &pos) const {
     return getData(getOffset(pos));
 }
 
-string TensorObj::toString() const { return "Tensor " + std::to_string(guid); }
+string TensorObj::toString() const {
+    return "Tensor " + std::to_string(guid) + " shape " + vecToString(shape);
+}
 
 size_t TensorObj::getOffset(const Shape &pos) const {
     auto nDim = pos.size();
@@ -103,50 +105,28 @@ void TensorObj::printDataUint32_t() const {
     }
 }
 
-template <typename T> bool TensorObj::equalDataInt(const Tensor &rhs) const {
-    auto ptr = data->getPtr<T *>();
-    auto ptrRhs = rhs->data->getPtr<T *>();
-    if (shape != rhs->getDims())
-        return false;
-    size_t sz = size();
-    for (size_t i = 0; i < sz; ++i)
-        if (ptr[i] != ptrRhs[i])
-            return false;
-    return true;
-}
-
-template <typename T> bool TensorObj::equalDataFloat(const Tensor &rhs) const {
-    IT_ASSERT(data != nullptr);
-    IT_ASSERT(rhs->data != nullptr);
-    // TODO: deal with data type
-    auto ptr = data->getPtr<T *>();
-    auto ptrRhs = rhs->data->getPtr<T *>();
-    if (shape != rhs->getDims())
-        return false;
-    size_t sz = size();
-    for (size_t i = 0; i < sz; ++i)
-        if (fabs(ptr[i] - ptrRhs[i]) / std::max(fabs(ptr[i]), fabs(ptrRhs[i])) >
-            1e-6) {
-            printf("Error on %lu: %f %f\n", i, ptr[i], ptrRhs[i]);
-            return false;
-        }
-    return true;
-}
-
 bool TensorObj::equalData(const Tensor &rhs) const {
     IT_ASSERT(data != nullptr);
     IT_ASSERT(rhs->data != nullptr);
     IT_ASSERT(getDType() == rhs->getDType());
+    IT_ASSERT(runtime->isCpu());
+    IT_ASSERT(rhs->getRuntime()->isCpu());
+    if (shape != rhs->getDims())
+        return false;
     if (getDType() == DataType::UInt32)
-        return equalDataInt<uint32_t>(rhs);
+        return equalDataImpl(getRawDataPtr<uint32_t *>(),
+                             rhs->getRawDataPtr<uint32_t *>(), size());
     else if (getDType() == DataType::Float32)
-        return equalDataInt<float>(rhs);
+        return equalDataImpl(getRawDataPtr<float *>(),
+                             rhs->getRawDataPtr<float *>(), size());
     else
         IT_TODO_HALT();
 }
 
-void TensorObj::dataMalloc(const Runtime &runtime) {
-    IT_ASSERT(data == nullptr);
+void TensorObj::dataMalloc() {
+    if (data != nullptr)
+        return;
+    // IT_ASSERT(data == nullptr);
     size_t bytesPerElement;
     if (getDType() == DataType::Float32)
         bytesPerElement = sizeof(float);
@@ -155,4 +135,10 @@
     data = runtime->allocBlob(size() * bytesPerElement);
 }
 
+void TensorObj::copyData(const TensorObj *src) {
+    IT_ASSERT(dtype == src->getDType());
+    IT_ASSERT(size() == src->size());
+    runtime->copyBlob(this, src);
+}
+
 }; // namespace infini
\ No newline at end of file
diff --git a/src/kernels/cpu/conv.cc b/src/kernels/cpu/conv.cc
index ac8f7b43..c8ccab4c 100644
--- a/src/kernels/cpu/conv.cc
+++ b/src/kernels/cpu/conv.cc
@@ -7,9 +7,9 @@ template <typename T> class NaiveConv : public Kernel {
     void compute(const Operator &_op, const PerfRecord &record,
                  const RuntimeObj *context) const override {
         auto op = as<ConvObj>(_op);
-        T *iptr = op->getInputs(0)->getDataRawPtr<T *>();
-        T *wptr = op->getInputs(1)->getDataRawPtr<T *>();
-        T *optr = op->getOutput()->getDataRawPtr<T *>();
+        T *iptr = op->getInputs(0)->getRawDataPtr<T *>();
+        T *wptr = op->getInputs(1)->getRawDataPtr<T *>();
+        T *optr = op->getOutput()->getRawDataPtr<T *>();
         auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
         auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
         int cpg = op->getChannelPerGroup();
diff --git a/src/kernels/cpu/matmul.cc b/src/kernels/cpu/matmul.cc
index d7182954..42141d4e 100644
--- a/src/kernels/cpu/matmul.cc
+++ b/src/kernels/cpu/matmul.cc
@@ -7,9 +7,10 @@ template <typename T> class NaiveMatmul : public Kernel {
     void compute(const Operator &_op, const PerfRecord &record,
                  const RuntimeObj *context) const override {
         auto op = as<MatmulObj>(_op);
-        T *A = op->getInputs(0)->getDataRawPtr<T *>();
-        T *B = op->getInputs(1)->getDataRawPtr<T *>();
-        T *C = op->getOutput()->getDataRawPtr<T *>();
+        IT_ASSERT(op->getInputs().size() == 2, "Bias is not supported yet.");
+        T *A = op->getInputs(0)->getRawDataPtr<T *>();
+        T *B = op->getInputs(1)->getRawDataPtr<T *>();
+        T *C = op->getOutput()->getRawDataPtr<T *>();
         IT_ASSERT(op->getTransA() == false && op->getTransB() == false);
         IT_ASSERT(op->getAct() == ActType::None);
         IT_ASSERT(op->getB() == 1);
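Usage sketch (illustrative, not part of the patch): behavioral consequences of
the TensorObj changes above, assuming a CPU/CUDA graph pair like the one in
the test rewritten later in this patch:

    Graph gCpu = make_ref<GraphObj>(CpuRuntimeObj::getInstance());
    Graph gCuda = make_ref<GraphObj>(make_ref<CudaRuntimeObj>());
    Tensor t = gCpu->addTensor({4}, DataType::Float32);
    gCpu->dataMalloc();
    t->copyData(vector<float>{1, 2, 3, 4});
    t->dataMalloc(); // now a silent no-op; previously an assertion failure
    // copyData checks dtype and element count, then defers to
    // RuntimeObj::copyBlob, which picks the copy direction from the two
    // tensors' runtimes (CPU->CPU, CPU<->CUDA, ...).
    Tensor tCuda = gCuda->cloneTensor(t);
    // equalData is CPU-only after this patch (both runtimes must be isCpu()),
    // so device tensors have to be cloned back before comparison.
    EXPECT_TRUE(gCpu->cloneTensor(tCuda)->equalData(t));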
diff --git a/src/kernels/cuda/conv.cc b/src/kernels/cuda/conv.cc
index 8289ea4f..c31868bb 100644
--- a/src/kernels/cuda/conv.cc
+++ b/src/kernels/cuda/conv.cc
@@ -26,12 +26,12 @@ class convCudnn : public Kernel {
     bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
                       const CudaRuntimeObj *context) const {
         cudnnStatus_t stat;
-        void *const inData = (op->getInputs(0)->getDataRawPtr<void *>());
-        void *const knData = (op->getInputs(1)->getDataRawPtr<void *>());
-        if (op->getInputs(2) != nullptr)
+        void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
+        if (op->getInputs().size() > 2) // Bias is not supported yet
             IT_TODO_HALT();
-        // void *const biasData = (op->getInputs(2)->getDataRawPtr<void *>());
-        void *const outData = (op->getOutput()->getDataRawPtr<void *>());
+        // void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
+        void *const outData = (op->getOutput()->getRawDataPtr<void *>());
         const auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
         const int cpg = op->getChannelPerGroup();
diff --git a/src/operators/conv.cc b/src/operators/conv.cc
index e2f9dc27..8c63f241 100644
--- a/src/operators/conv.cc
+++ b/src/operators/conv.cc
@@ -3,20 +3,19 @@
 namespace infini {
 
 ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
-                 int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias,
-                 ActType act)
-    : OperatorObj(OpType::Conv, {input, weight, bias}, {output}), ph(ph),
-      pw(pw), sh(sh), sw(sw), dh(dh), dw(dw), act(act),
-      padding(PaddingMode::Other) {
+                 int ph, int pw, int sh, int sw, int dh, int dw,
+                 [[maybe_unused]] Tensor bias, ActType act)
+    : OperatorObj(OpType::Conv, {input, weight}, {output}), ph(ph), pw(pw),
+      sh(sh), sw(sw), dh(dh), dw(dw), act(act), padding(PaddingMode::Other) {
     setAuxilaryAttributes(PaddingMode::Other);
     IT_ASSERT(checkValid(graph));
 }
 
 ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
-                 PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias,
-                 ActType act)
-    : OperatorObj(OpType::Conv, {input, weight, bias}, {output}), ph(-1),
-      pw(-1), sh(sh), sw(sw), dh(dh), dw(dw), act(act), padding(mode) {
+                 PaddingMode mode, int sh, int sw, int dh, int dw,
+                 [[maybe_unused]] Tensor bias, ActType act)
+    : OperatorObj(OpType::Conv, {input, weight}, {output}), ph(-1), pw(-1),
+      sh(sh), sw(sw), dh(dh), dw(dw), act(act), padding(mode) {
     IT_ASSERT(mode != PaddingMode::Other);
     setAuxilaryAttributes(mode);
     IT_ASSERT(checkValid(graph));
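Usage sketch (illustrative, not part of the patch): the bias parameter is kept
for source compatibility but marked [[maybe_unused]] and dropped, so only
{input, weight} reach OperatorObj and the kernels can assert
getInputs().size() == 2. A hypothetical call site, with the integer arguments
named after the constructor parameters above:

    Tensor i = g->addTensor({1, 3, 4, 4}, DataType::Float32);
    Tensor w = g->addTensor({2, 3, 3, 3}, DataType::Float32);
    auto conv = g->addOp<ConvObj>(i, w, nullptr, /*ph=*/1, /*pw=*/1,
                                  /*sh=*/2, /*sw=*/1, /*dh=*/1, /*dw=*/2);
    IT_ASSERT(conv->getInputs().size() == 2); // bias was not registered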
diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc
index 20f60914..db109192 100644
--- a/src/operators/matmul.cc
+++ b/src/operators/matmul.cc
@@ -3,9 +3,9 @@
 namespace infini {
 
 MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
-                     bool transB, Tensor bias, ActType act)
-    : OperatorObj(OpType::Matmul, {A, B, bias}, {C}), transA(transA),
-      transB(transB), act(act), b(A->getDims()[0]),
+                     bool transB, [[maybe_unused]] Tensor bias, ActType act)
+    : OperatorObj(OpType::Matmul, {A, B}, {C}), transA(transA), transB(transB),
+      act(act), b(A->getDims()[0]),
       m(transA ? A->getDims()[2] : A->getDims()[1]),
       n(transB ? B->getDims()[1] : B->getDims()[2]),
       k(transA ? A->getDims()[1] : A->getDims()[2]) {
diff --git a/test/core/test_graph.cc b/test/core/test_graph.cc
index 391acd3d..a91d8096 100644
--- a/test/core/test_graph.cc
+++ b/test/core/test_graph.cc
@@ -19,7 +19,7 @@ TEST(Graph, build_and_run) {
     runtime->run(g);
     // check answer
     auto ans = make_ref<TensorObj>(Shape{1, 2, 4}, DataType::UInt32, runtime);
-    ans->dataMalloc(runtime);
+    ans->dataMalloc();
     ans->copyData(vector<uint32_t>{38, 44, 50, 56, 83, 98, 113, 128});
     EXPECT_TRUE(o0->equalData(ans));
 }
@@ -41,7 +41,7 @@ TEST(Graph, perf_engine) {
     EXPECT_LT(perfTime, 0.01);
     // check answer
     auto ans = make_ref<TensorObj>(Shape{1, 2, 4}, DataType::UInt32, runtime);
-    ans->dataMalloc(runtime);
+    ans->dataMalloc();
     ans->copyData(vector<uint32_t>{38, 44, 50, 56, 83, 98, 113, 128});
     EXPECT_TRUE(matmul->getOutput()->equalData(ans));
 }
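Worked example (illustrative, not part of the patch): how the MatmulObj
initializer list above unpacks batched 3-D shapes:

    Tensor A = g->addTensor({1, 3, 5}, DataType::UInt32);
    Tensor B = g->addTensor({1, 5, 2}, DataType::UInt32);
    auto mm = g->addOp<MatmulObj>(A, B, nullptr);
    // With transA = transB = false: b = 1 (batch), m = 3, k = 5, n = 2,
    // so inferShape yields an output of shape {1, 3, 2}. Setting
    // transA/transB swaps the last two dims of the corresponding input.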
diff --git a/test/operators/test_conv.cc b/test/operators/test_conv.cc
index f528bb0a..f6397c5c 100644
--- a/test/operators/test_conv.cc
+++ b/test/operators/test_conv.cc
@@ -60,7 +60,7 @@ TEST(Conv, NaiveCPU) {
     // check answer
     auto ans =
         make_ref<TensorObj>(Shape{1, 2, 2, 2}, DataType::UInt32, runtime);
-    ans->dataMalloc(runtime);
+    ans->dataMalloc();
     ans->copyData(
         vector<uint32_t>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
     EXPECT_TRUE(conv->getOutput()->equalData(ans));
@@ -69,52 +69,35 @@ TEST(Conv, NaiveCPU) {
 void testConvCudnn(
     const std::function<void(void *, size_t, DataType)> &generator,
     vector<float> ansVec) {
-    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
-    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    // Construct Runtime and graph for CPU and CUDA
+    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Graph gCpu = make_ref<GraphObj>(cpu);
+    Runtime cuda = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cuda);
+    // Set input data on CPU in a CPU Graph
+    Tensor i0Cpu = gCpu->addTensor({1, 3, 4, 4}, DataType::Float32);
+    Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    gCpu->dataMalloc();
+    i0Cpu->setData(generator);
+    w0Cpu->setData(generator);
+
+    // Copy input tensors from CPU to CUDA
+    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
+    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
     // Build CUDA graph
-    Graph g = make_ref<GraphObj>(cudaRuntime);
-    Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::Float32);
-    Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::Float32);
-    auto conv = g->addOp<ConvObj>(i0, w0, nullptr, 1, 1, 2, 1, 1, 2);
-
+    auto conv =
+        gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 2, 1, 1, 2);
     // allocate CUDA memory
-    g->dataMalloc();
-
-    // Build input and output data on CPU
-    auto cpui0 =
-        make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::Float32, cpuRuntime);
-    cpui0->dataMalloc(cpuRuntime);
-    cpui0->setData(generator);
-
-    auto cpuw0 =
-        make_ref<TensorObj>(Shape{2, 3, 3, 3}, DataType::Float32, cpuRuntime);
-    cpuw0->dataMalloc(cpuRuntime);
-    cpuw0->setData(generator);
-
-    auto ans =
-        make_ref<TensorObj>(Shape{1, 2, 2, 2}, DataType::Float32, cpuRuntime);
-    ans->dataMalloc(cpuRuntime);
-    ans->copyData(ansVec);
-
-    // Copy inputs from CPU to CUDA
-    i0->copyData(cpui0);
-    w0->copyData(cpuw0);
+    gCuda->dataMalloc();
 
     // Execute on CUDA
-    cudaRuntime->run(g);
-    // double perfTime = cudaRuntime->getPerfTime(g);
-    // // The example Conv takes 0.015ms with one core
-    // EXPECT_GT(perfTime, 0);
-    // EXPECT_LT(perfTime, 0.1);
-
-    // copy CUDA output to CPU
-    auto o0 = conv->getOutput();
-    auto cpuo0 =
-        make_ref<TensorObj>(Shape{1, 2, 2, 2}, DataType::Float32, cpuRuntime);
-    cpuo0->dataMalloc(cpuRuntime);
-    cpuo0->copyData(o0);
-
+    cuda->run(gCuda);
+    // copy output from CUDA to CPU
+    auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
     // check results on CPU
-    EXPECT_TRUE(cpuo0->equalData(ans));
+    EXPECT_TRUE(o0Cpu->equalData(ansVec));
+    // print a tensor/operator/graph by print()
+    gCuda->print();
 }
 
 TEST(Conv, cuDNN) {
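Usage sketch (illustrative, not part of the patch): the pattern the rewritten
test demonstrates, reduced to its core — stage data in a CPU graph, clone it
into a CUDA graph, compute, and clone the result back for comparison. A
sketch under the same headers as the test; IncrementalGenerator is assumed to
be one of the repository's data generators:

    Runtime cpu = CpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(cpu);
    Runtime cuda = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cuda);

    Tensor xCpu = gCpu->addTensor({2, 2}, DataType::Float32);
    gCpu->dataMalloc();
    xCpu->setData(IncrementalGenerator());
    Tensor xCuda = gCuda->cloneTensor(xCpu); // host -> device
    // ... build and run operators on gCuda ...
    Tensor back = gCpu->cloneTensor(xCuda);  // device -> host
    EXPECT_TRUE(back->equalData(xCpu));      // equalData requires CPU tensors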