forked from jiuyuan/InfiniTensor
fix: 支持更多 numpy 类型
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
8ae83caa7c
commit
66ef2a9c61
|
@ -1 +1 @@
|
|||
Subproject commit eda160a4b5338cb812053af9fb09609504eee5bb
|
||||
Subproject commit c1b42af5a46eed445c8a78d2af9d09138c7b9edc
|
|
@ -30,7 +30,8 @@ namespace infini {
|
|||
class Executor;
|
||||
} // namespace infini
|
||||
|
||||
static std::string getFormat(refactor::common::DataType);
|
||||
refactor::common::DataType parseNumpyDType(py::dtype const &);
|
||||
pybind11::dtype buildNumpyDType(refactor::common::DataType);
|
||||
|
||||
namespace frontend {
|
||||
using namespace refactor;
|
||||
|
@ -75,8 +76,8 @@ class Compiler {
|
|||
std::transform(tensor.shape.begin(), tensor.shape.end(), shape.begin(),
|
||||
[](auto const &d) { return d.value(); });
|
||||
|
||||
auto ans = py::array(py::dtype(getFormat(tensor.dataType)),
|
||||
std::move(shape), nullptr);
|
||||
auto ans = py::array(buildNumpyDType(tensor.dataType), std::move(shape),
|
||||
nullptr);
|
||||
if (tensor.data) {
|
||||
std::memcpy(ans.mutable_data(), tensor.data->ptr, ans.nbytes());
|
||||
}
|
||||
|
@ -126,7 +127,7 @@ class Executor {
|
|||
std::transform(edge.tensor->shape.begin(), edge.tensor->shape.end(),
|
||||
shape.begin(), [](auto const &d) { return d.value(); });
|
||||
|
||||
auto ans = py::array(py::dtype(getFormat(edge.tensor->dataType)),
|
||||
auto ans = py::array(buildNumpyDType(edge.tensor->dataType),
|
||||
std::move(shape), nullptr);
|
||||
if (tensor) {
|
||||
tensor->copyout(ans.mutable_data(), ans.nbytes());
|
||||
|
@ -369,15 +370,39 @@ void registerPy(py::module &m) {
|
|||
}
|
||||
} // namespace infini
|
||||
|
||||
refactor::common::DataType parseNumpyDType(py::dtype const &dt) {
|
||||
using namespace refactor::common;
|
||||
#define CASE(T) \
|
||||
if (dt.is(py::dtype::of<primitive_t<DataType::T>::type>())) { \
|
||||
return DataType::T; \
|
||||
}
|
||||
|
||||
CASE(F32);
|
||||
CASE(F64);
|
||||
CASE(I32);
|
||||
CASE(I64);
|
||||
CASE(I8);
|
||||
CASE(I16);
|
||||
CASE(U8);
|
||||
CASE(U16);
|
||||
CASE(U32);
|
||||
CASE(U64);
|
||||
CASE(Bool);
|
||||
|
||||
#undef CASE
|
||||
RUNTIME_ERROR("unsupported data type.");
|
||||
}
|
||||
|
||||
// A helper function that converts DataType to python format string
|
||||
static std::string getFormat(refactor::common::DataType type) {
|
||||
pybind11::dtype buildNumpyDType(refactor::common::DataType dt) {
|
||||
using namespace refactor::common;
|
||||
|
||||
#define CASE(T) \
|
||||
case DataType::T: \
|
||||
return py::format_descriptor<primitive_t<DataType::T>::type>::format();
|
||||
return py::dtype::of<primitive_t<DataType::T>::type>();
|
||||
|
||||
switch (dt.internal) {
|
||||
|
||||
switch (type.internal) {
|
||||
CASE(F32);
|
||||
CASE(F64);
|
||||
CASE(I32);
|
||||
|
@ -388,15 +413,20 @@ static std::string getFormat(refactor::common::DataType type) {
|
|||
CASE(U16);
|
||||
CASE(U32);
|
||||
CASE(U64);
|
||||
CASE(Bool);
|
||||
|
||||
case DataType::FP16:
|
||||
case DataType::BF16:
|
||||
// Python uses "e" for half precision float type code.
|
||||
// Check the following link for more information.
|
||||
// https://docs.python.org/3/library/struct.html#format-characters
|
||||
return "e";
|
||||
return py::dtype("e");
|
||||
|
||||
default:
|
||||
RUNTIME_ERROR("unsupported data type.");
|
||||
}
|
||||
|
||||
#undef CASE
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(backend, m) {
|
||||
|
|
Loading…
Reference in New Issue