feat: 支持导出浮点向量

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-24 11:19:37 +08:00
parent cf9bdb0562
commit 3d122aebfe
2 changed files with 24 additions and 2 deletions

View File

@ -1,6 +1,11 @@
#pragma once
#include "core/tensor_base.h"
#include <cmath>
#include <cstring>
#if USE_CUDA
#include "cuda/cuda_runtime.h"
#endif
namespace infini {
@ -55,7 +60,6 @@ class TensorObj : public TensorBaseObj {
obj->outputOf.reset();
return obj;
}
// TODO: clarify whether clone copies data
Tensor clone(Runtime runtime) const {
auto obj = make_ref<TensorObj>(*this);
obj->runtime = runtime;
@ -68,6 +72,24 @@ class TensorObj : public TensorBaseObj {
}
return obj;
}
inline std::vector<float> cloneFloats() const {
IT_ASSERT(data != nullptr);
IT_ASSERT(getDType() == DataType::Float32);
std::vector<float> ans(size());
auto src = getRawDataPtr<void *>();
auto dst = ans.data();
auto bytes = getBytes();
if (runtime->isCpu()) {
memcpy(dst, src, bytes);
} else {
#if USE_CUDA
cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost);
#else
IT_TODO_HALT();
#endif
}
return ans;
}
void printData() const;
bool equalData(const Tensor &rhs) const;

View File

@ -146,7 +146,7 @@ void init_graph_builder(py::module &m) {
#endif
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
.def("shape", &TensorObj::getDims, policy::move)
.def("printData", &TensorObj::printData, policy::automatic)
.def("cloneFloats", &TensorObj::cloneFloats, policy::move)
.def("src", &TensorObj::getOutputOf, policy::move);
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
.def("op_type", &OperatorObj::getOpType, policy::automatic)