From 3d122aebfe6d091d3452bb763aaf522853ab7789 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 24 Feb 2023 11:19:37 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=AF=BC=E5=87=BA?= =?UTF-8?q?=E6=B5=AE=E7=82=B9=E5=90=91=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- include/core/tensor.h | 24 +++++++++++++++++++++++- src/ffi/ffi_infinitensor.cc | 2 +- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/include/core/tensor.h b/include/core/tensor.h index 9e835a6b..422355e9 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -1,6 +1,11 @@ #pragma once #include "core/tensor_base.h" #include +#include + +#if USE_CUDA +#include "cuda/cuda_runtime.h" +#endif namespace infini { @@ -55,7 +60,6 @@ class TensorObj : public TensorBaseObj { obj->outputOf.reset(); return obj; } - // TODO: clarify whether clone copies data Tensor clone(Runtime runtime) const { auto obj = make_ref(*this); obj->runtime = runtime; @@ -68,6 +72,24 @@ class TensorObj : public TensorBaseObj { } return obj; } + inline std::vector cloneFloats() const { + IT_ASSERT(data != nullptr); + IT_ASSERT(getDType() == DataType::Float32); + std::vector ans(size()); + auto src = getRawDataPtr(); + auto dst = ans.data(); + auto bytes = getBytes(); + if (runtime->isCpu()) { + memcpy(dst, src, bytes); + } else { +#if USE_CUDA + cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost); +#else + IT_TODO_HALT(); +#endif + } + return ans; + } void printData() const; bool equalData(const Tensor &rhs) const; diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 402d5306..2be4971b 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -146,7 +146,7 @@ void init_graph_builder(py::module &m) { #endif py::class_>(m, "Tensor") .def("shape", &TensorObj::getDims, policy::move) - .def("printData", &TensorObj::printData, policy::automatic) + .def("cloneFloats", &TensorObj::cloneFloats, policy::move) .def("src", &TensorObj::getOutputOf, policy::move); py::class_>(m, "Operator") .def("op_type", &OperatorObj::getOpType, policy::automatic)