fix: 支持更多 numpy 类型

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-09-25 11:34:35 +08:00
parent 8ae83caa7c
commit 66ef2a9c61
2 changed files with 39 additions and 9 deletions

@ -1 +1 @@
Subproject commit eda160a4b5338cb812053af9fb09609504eee5bb
Subproject commit c1b42af5a46eed445c8a78d2af9d09138c7b9edc

View File

@ -30,7 +30,8 @@ namespace infini {
class Executor;
} // namespace infini
static std::string getFormat(refactor::common::DataType);
refactor::common::DataType parseNumpyDType(py::dtype const &);
pybind11::dtype buildNumpyDType(refactor::common::DataType);
namespace frontend {
using namespace refactor;
@ -75,8 +76,8 @@ class Compiler {
std::transform(tensor.shape.begin(), tensor.shape.end(), shape.begin(),
[](auto const &d) { return d.value(); });
auto ans = py::array(py::dtype(getFormat(tensor.dataType)),
std::move(shape), nullptr);
auto ans = py::array(buildNumpyDType(tensor.dataType), std::move(shape),
nullptr);
if (tensor.data) {
std::memcpy(ans.mutable_data(), tensor.data->ptr, ans.nbytes());
}
@ -126,7 +127,7 @@ class Executor {
std::transform(edge.tensor->shape.begin(), edge.tensor->shape.end(),
shape.begin(), [](auto const &d) { return d.value(); });
auto ans = py::array(py::dtype(getFormat(edge.tensor->dataType)),
auto ans = py::array(buildNumpyDType(edge.tensor->dataType),
std::move(shape), nullptr);
if (tensor) {
tensor->copyout(ans.mutable_data(), ans.nbytes());
@ -369,15 +370,39 @@ void registerPy(py::module &m) {
}
} // namespace infini
refactor::common::DataType parseNumpyDType(py::dtype const &dt) {
using namespace refactor::common;
#define CASE(T) \
if (dt.is(py::dtype::of<primitive_t<DataType::T>::type>())) { \
return DataType::T; \
}
CASE(F32);
CASE(F64);
CASE(I32);
CASE(I64);
CASE(I8);
CASE(I16);
CASE(U8);
CASE(U16);
CASE(U32);
CASE(U64);
CASE(Bool);
#undef CASE
RUNTIME_ERROR("unsupported data type.");
}
// A helper function that converts DataType to python format string
static std::string getFormat(refactor::common::DataType type) {
pybind11::dtype buildNumpyDType(refactor::common::DataType dt) {
using namespace refactor::common;
#define CASE(T) \
case DataType::T: \
return py::format_descriptor<primitive_t<DataType::T>::type>::format();
return py::dtype::of<primitive_t<DataType::T>::type>();
switch (dt.internal) {
switch (type.internal) {
CASE(F32);
CASE(F64);
CASE(I32);
@ -388,15 +413,20 @@ static std::string getFormat(refactor::common::DataType type) {
CASE(U16);
CASE(U32);
CASE(U64);
CASE(Bool);
case DataType::FP16:
case DataType::BF16:
// Python uses "e" for half precision float type code.
// Check the following link for more information.
// https://docs.python.org/3/library/struct.html#format-characters
return "e";
return py::dtype("e");
default:
RUNTIME_ERROR("unsupported data type.");
}
#undef CASE
}
PYBIND11_MODULE(backend, m) {