diff --git a/include/core/data_type.h b/include/core/data_type.h index 654ce1ce..282544ea 100644 --- a/include/core/data_type.h +++ b/include/core/data_type.h @@ -4,10 +4,19 @@ namespace infini { class DataType { public: + // legacy static const DataType Float32; static const DataType UInt32; - static constexpr size_t sizePerElement[]{sizeof(float), sizeof(uint32_t)}; - static constexpr std::string_view names[]{"Float32", "UInt32"}; + // 这一组恰好与 onnx 的类型对齐: + // + static const DataType UInt8, Int8, UInt16, Int16, Int32, Int64; + static constexpr size_t sizePerElement[]{ + sizeof(float), sizeof(uint32_t), sizeof(uint8_t), sizeof(int8_t), + sizeof(uint16_t), sizeof(int16_t), sizeof(int32_t), sizeof(int64_t)}; + + static constexpr std::string_view names[]{"Float32", "UInt32", "UInt8", + "Int8", "UInt16", "Int16", + "Int32", "Int64"}; private: int index; @@ -29,9 +38,18 @@ class DataType { inline const DataType DataType::Float32(0); inline const DataType DataType::UInt32(1); +inline const DataType DataType::UInt8(2), DataType::Int8(3), + DataType::UInt16(4), DataType::Int16(5), DataType::Int32(6), + DataType::Int64(7); // Method definitions are out of the declaration due to GCC bug: // https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc template <> inline DataType DataType::get() { return Float32; } template <> inline DataType DataType::get() { return UInt32; } +template <> inline DataType DataType::get() { return UInt8; } +template <> inline DataType DataType::get() { return Int8; } +template <> inline DataType DataType::get() { return UInt16; } +template <> inline DataType DataType::get() { return Int16; } +template <> inline DataType DataType::get() { return Int32; } +template <> inline DataType DataType::get() { return Int64; } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 453185ae..bbbaaa17 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -119,7 +119,7 @@ def from_onnx(model: onnx.ModelProto): tensors[node.output[0]] = handler.reshape( tensors[node.input[0]], tensors.get(node.output[0]), - data[node.input[1]].int32_data, + [int(i) for i in data[node.input[1]].int64_data], ) else: raise Exception('Unsupported operator "{}"'.format(node.op_type)) diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 092b1e05..7e36b125 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -1,6 +1,12 @@ import os, onnx, unittest from onnx import TensorProto -from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info +from onnx.helper import ( + make_model, + make_node, + make_tensor, + make_graph, + make_tensor_value_info, +) from onnx.checker import check_model from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime @@ -130,16 +136,20 @@ class TestStringMethods(unittest.TestCase): make_and_import_model(make_graph([flatten], "flatten", [x], [y])) def test_reshape(self): - data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5]) - shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8]) - reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8]) + data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4]) + # shape 对于后端来说并不是一个张量,然而转换中可能没有办法分辨 + # 不知道怎么把 ValueInfoProto 转换成 TensorProto + shape = make_tensor_value_info("shape", TensorProto.INT64, [4]) + shape_data = make_tensor("shape", TensorProto.INT64, [4], [3, 2, 4, 3]) + reshaped = make_tensor_value_info( + "reshaped", TensorProto.FLOAT, shape_data.int64_data + ) reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape") - # FIXME shape 对于 onnx 来说是输入张量,但对于后端来说不是,导入时无法分辨这个情况。 - # tensor 的类型又不支持 INT64,所以这里会报一个错。 - # 如何分辨 onnx 的张量是不是需要作为张量注册? - # make_and_import_model( - # make_graph([reshape], "reshape", [data, shape], [reshaped]) - # ) + # 可以构造一个 shape 只出现在 initializer 里而不出现在 input 里的图, + # 但实际上的图中 initializer 里的必然会出现在 input 里,不知道为什么这样设计 + make_and_import_model( + make_graph([reshape], "reshape", [data, shape], [reshaped], [shape_data]) + ) # see def test_linear(self): diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 80e88a09..cd3b355d 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -88,9 +88,7 @@ Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) { std::move(shape)); return reshaped; } else { - return g - ->addOpWithOutputs(std::move(data), reshaped, - std::move(shape)) + return g->addOp(std::move(data), reshaped, std::move(shape)) ->getOutput(); } } @@ -101,6 +99,18 @@ static DataType dtype_repr_convert(int dtype) { return DataType::Float32; case OnnxDType::UINT32: return DataType::UInt32; + case OnnxDType::UINT8: + return DataType::UInt8; + case OnnxDType::INT8: + return DataType::Int8; + case OnnxDType::UINT16: + return DataType::UInt16; + case OnnxDType::INT16: + return DataType::Int16; + case OnnxDType::INT32: + return DataType::Int32; + case OnnxDType::INT64: + return DataType::Int64; default: IT_ASSERT(false, "Unsupported data type"); } diff --git a/test/core/test_python.cc b/test/core/test_python.cc deleted file mode 100644 index a3933387..00000000 --- a/test/core/test_python.cc +++ /dev/null @@ -1,16 +0,0 @@ -#include -#include -#include - -TEST(Python, pybind) { - namespace fs = std::filesystem; - namespace py = pybind11; - using mod = py::module; - - py::scoped_interpreter _python; - - auto sys_path_append = mod::import("sys").attr("path").attr("append"); - sys_path_append(fs::path(__FILE__).parent_path().c_str()); - auto ans = mod::import("python").attr("inc")(1); - EXPECT_EQ(ans.cast(), 2); -}