feat: 补充 DataType 类型

- 增加了 6 个代数类型,与 onnx 的序号对应
- 现在可以导入 reshape 了

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-14 11:27:57 +08:00
parent d9e2953425
commit a7e58bd8d0
5 changed files with 55 additions and 33 deletions

View File

@ -4,10 +4,19 @@ namespace infini {
class DataType { class DataType {
public: public:
// legacy
static const DataType Float32; static const DataType Float32;
static const DataType UInt32; static const DataType UInt32;
static constexpr size_t sizePerElement[]{sizeof(float), sizeof(uint32_t)}; // 这一组恰好与 onnx 的类型对齐:
static constexpr std::string_view names[]{"Float32", "UInt32"}; // <https://onnx.ai/onnx/intro/concepts.html#element-type>
static const DataType UInt8, Int8, UInt16, Int16, Int32, Int64;
static constexpr size_t sizePerElement[]{
sizeof(float), sizeof(uint32_t), sizeof(uint8_t), sizeof(int8_t),
sizeof(uint16_t), sizeof(int16_t), sizeof(int32_t), sizeof(int64_t)};
static constexpr std::string_view names[]{"Float32", "UInt32", "UInt8",
"Int8", "UInt16", "Int16",
"Int32", "Int64"};
private: private:
int index; int index;
@ -29,9 +38,18 @@ class DataType {
inline const DataType DataType::Float32(0); inline const DataType DataType::Float32(0);
inline const DataType DataType::UInt32(1); inline const DataType DataType::UInt32(1);
inline const DataType DataType::UInt8(2), DataType::Int8(3),
DataType::UInt16(4), DataType::Int16(5), DataType::Int32(6),
DataType::Int64(7);
// Method definitions are out of the declaration due to GCC bug: // Method definitions are out of the declaration due to GCC bug:
// https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc // https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
template <> inline DataType DataType::get<float>() { return Float32; } template <> inline DataType DataType::get<float>() { return Float32; }
template <> inline DataType DataType::get<uint32_t>() { return UInt32; } template <> inline DataType DataType::get<uint32_t>() { return UInt32; }
template <> inline DataType DataType::get<uint8_t>() { return UInt8; }
template <> inline DataType DataType::get<int8_t>() { return Int8; }
template <> inline DataType DataType::get<uint16_t>() { return UInt16; }
template <> inline DataType DataType::get<int16_t>() { return Int16; }
template <> inline DataType DataType::get<int32_t>() { return Int32; }
template <> inline DataType DataType::get<int64_t>() { return Int64; }
} // namespace infini } // namespace infini

View File

@ -119,7 +119,7 @@ def from_onnx(model: onnx.ModelProto):
tensors[node.output[0]] = handler.reshape( tensors[node.output[0]] = handler.reshape(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0]), tensors.get(node.output[0]),
data[node.input[1]].int32_data, [int(i) for i in data[node.input[1]].int64_data],
) )
else: else:
raise Exception('Unsupported operator "{}"'.format(node.op_type)) raise Exception('Unsupported operator "{}"'.format(node.op_type))

View File

@ -1,6 +1,12 @@
import os, onnx, unittest import os, onnx, unittest
from onnx import TensorProto from onnx import TensorProto
from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info from onnx.helper import (
make_model,
make_node,
make_tensor,
make_graph,
make_tensor_value_info,
)
from onnx.checker import check_model from onnx.checker import check_model
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
@ -130,16 +136,20 @@ class TestStringMethods(unittest.TestCase):
make_and_import_model(make_graph([flatten], "flatten", [x], [y])) make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
def test_reshape(self): def test_reshape(self):
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5]) data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8]) # shape 对于后端来说并不是一个张量,然而转换中可能没有办法分辨
reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8]) # 不知道怎么把 ValueInfoProto 转换成 TensorProto
shape = make_tensor_value_info("shape", TensorProto.INT64, [4])
shape_data = make_tensor("shape", TensorProto.INT64, [4], [3, 2, 4, 3])
reshaped = make_tensor_value_info(
"reshaped", TensorProto.FLOAT, shape_data.int64_data
)
reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape") reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
# FIXME shape 对于 onnx 来说是输入张量,但对于后端来说不是,导入时无法分辨这个情况。 # 可以构造一个 shape 只出现在 initializer 里而不出现在 input 里的图,
# tensor 的类型又不支持 INT64所以这里会报一个错。 # 但实际上的图中 initializer 里的必然会出现在 input 里,不知道为什么这样设计
# 如何分辨 onnx 的张量是不是需要作为张量注册? make_and_import_model(
# make_and_import_model( make_graph([reshape], "reshape", [data, shape], [reshaped], [shape_data])
# make_graph([reshape], "reshape", [data, shape], [reshaped]) )
# )
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression> # see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
def test_linear(self): def test_linear(self):

View File

@ -88,9 +88,7 @@ Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) {
std::move(shape)); std::move(shape));
return reshaped; return reshaped;
} else { } else {
return g return g->addOp<ReshapeObj>(std::move(data), reshaped, std::move(shape))
->addOpWithOutputs<ReshapeObj>(std::move(data), reshaped,
std::move(shape))
->getOutput(); ->getOutput();
} }
} }
@ -101,6 +99,18 @@ static DataType dtype_repr_convert(int dtype) {
return DataType::Float32; return DataType::Float32;
case OnnxDType::UINT32: case OnnxDType::UINT32:
return DataType::UInt32; return DataType::UInt32;
case OnnxDType::UINT8:
return DataType::UInt8;
case OnnxDType::INT8:
return DataType::Int8;
case OnnxDType::UINT16:
return DataType::UInt16;
case OnnxDType::INT16:
return DataType::Int16;
case OnnxDType::INT32:
return DataType::Int32;
case OnnxDType::INT64:
return DataType::Int64;
default: default:
IT_ASSERT(false, "Unsupported data type"); IT_ASSERT(false, "Unsupported data type");
} }

View File

@ -1,16 +0,0 @@
#include <filesystem>
#include <pybind11/embed.h>
#include <test.h>
TEST(Python, pybind) {
namespace fs = std::filesystem;
namespace py = pybind11;
using mod = py::module;
py::scoped_interpreter _python;
auto sys_path_append = mod::import("sys").attr("path").attr("append");
sys_path_append(fs::path(__FILE__).parent_path().c_str());
auto ans = mod::import("python").attr("inc")(1);
EXPECT_EQ(ans.cast<int>(), 2);
}