forked from jiuyuan/InfiniTensor
feat: 支持导出浮点向量
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
cf9bdb0562
commit
3d122aebfe
|
@ -1,6 +1,11 @@
|
|||
#pragma once
|
||||
#include "core/tensor_base.h"
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
|
||||
#if USE_CUDA
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#endif
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
@ -55,7 +60,6 @@ class TensorObj : public TensorBaseObj {
|
|||
obj->outputOf.reset();
|
||||
return obj;
|
||||
}
|
||||
// TODO: clarify whether clone copies data
|
||||
Tensor clone(Runtime runtime) const {
|
||||
auto obj = make_ref<TensorObj>(*this);
|
||||
obj->runtime = runtime;
|
||||
|
@ -68,6 +72,24 @@ class TensorObj : public TensorBaseObj {
|
|||
}
|
||||
return obj;
|
||||
}
|
||||
inline std::vector<float> cloneFloats() const {
|
||||
IT_ASSERT(data != nullptr);
|
||||
IT_ASSERT(getDType() == DataType::Float32);
|
||||
std::vector<float> ans(size());
|
||||
auto src = getRawDataPtr<void *>();
|
||||
auto dst = ans.data();
|
||||
auto bytes = getBytes();
|
||||
if (runtime->isCpu()) {
|
||||
memcpy(dst, src, bytes);
|
||||
} else {
|
||||
#if USE_CUDA
|
||||
cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost);
|
||||
#else
|
||||
IT_TODO_HALT();
|
||||
#endif
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
void printData() const;
|
||||
bool equalData(const Tensor &rhs) const;
|
||||
|
|
|
@ -146,7 +146,7 @@ void init_graph_builder(py::module &m) {
|
|||
#endif
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
||||
.def("shape", &TensorObj::getDims, policy::move)
|
||||
.def("printData", &TensorObj::printData, policy::automatic)
|
||||
.def("cloneFloats", &TensorObj::cloneFloats, policy::move)
|
||||
.def("src", &TensorObj::getOutputOf, policy::move);
|
||||
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
||||
.def("op_type", &OperatorObj::getOpType, policy::automatic)
|
||||
|
|
Loading…
Reference in New Issue