diff --git a/include/core/tensor.h b/include/core/tensor.h
index 9e835a6b..422355e9 100644
--- a/include/core/tensor.h
+++ b/include/core/tensor.h
@@ -1,6 +1,11 @@
 #pragma once
 #include "core/tensor_base.h"
 #include <cmath>
+#include <cstring>
+
+#if USE_CUDA
+#include "cuda/cuda_runtime.h"
+#endif
 
 namespace infini {
 
@@ -55,7 +60,6 @@ class TensorObj : public TensorBaseObj {
         obj->outputOf.reset();
         return obj;
     }
-    // TODO: clarify whether clone copies data
     Tensor clone(Runtime runtime) const {
         auto obj = make_ref<TensorObj>(*this);
         obj->runtime = runtime;
@@ -68,6 +72,24 @@ class TensorObj : public TensorBaseObj {
         }
         return obj;
     }
+    inline std::vector<float> cloneFloats() const {
+        IT_ASSERT(data != nullptr);
+        IT_ASSERT(getDType() == DataType::Float32);
+        std::vector<float> ans(size());
+        auto src = getRawDataPtr<void *>();
+        auto dst = ans.data();
+        auto bytes = getBytes();
+        if (runtime->isCpu()) {
+            memcpy(dst, src, bytes);
+        } else {
+#if USE_CUDA
+            cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost);
+#else
+            IT_TODO_HALT();
+#endif
+        }
+        return ans;
+    }
 
     void printData() const;
     bool equalData(const Tensor &rhs) const;
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 402d5306..2be4971b 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -146,7 +146,7 @@ void init_graph_builder(py::module &m) {
 #endif
     py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
        .def("shape", &TensorObj::getDims, policy::move)
-        .def("printData", &TensorObj::printData, policy::automatic)
+        .def("cloneFloats", &TensorObj::cloneFloats, policy::move)
        .def("src", &TensorObj::getOutputOf, policy::move);
     py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
        .def("op_type", &OperatorObj::getOpType, policy::automatic)
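
Not part of the diff: a minimal C++ usage sketch of the new `cloneFloats` accessor. It assumes `t` is an `infini::Tensor` (`Ref<TensorObj>`) whose Float32 data has already been allocated and filled on some runtime; the helper name `dumpTensor` is hypothetical, only the `cloneFloats` call reflects the diff above.

```cpp
#include "core/tensor.h"
#include <iostream>
#include <vector>

// Illustrative only: `t` must already hold Float32 data (CPU or CUDA runtime).
void dumpTensor(const infini::Tensor &t) {
    // cloneFloats copies the tensor's raw bytes into host memory
    // (memcpy on CPU, cudaMemcpyDeviceToHost under USE_CUDA).
    std::vector<float> host = t->cloneFloats();
    for (float v : host)
        std::cout << v << ' ';
    std::cout << '\n';
}
```

The same data is what the updated `cloneFloats` binding in `ffi_infinitensor.cc` returns to Python as a list of floats, replacing the previous print-only `printData` binding.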