forked from jiuyuan/InfiniTensor
feat: 补充 DataType 类型
- 增加了 6 个代数类型,与 onnx 的序号对应 - 现在可以导入 reshape 了 Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
d9e2953425
commit
a7e58bd8d0
|
@ -4,10 +4,19 @@ namespace infini {
|
||||||
|
|
||||||
class DataType {
|
class DataType {
|
||||||
public:
|
public:
|
||||||
|
// legacy
|
||||||
static const DataType Float32;
|
static const DataType Float32;
|
||||||
static const DataType UInt32;
|
static const DataType UInt32;
|
||||||
static constexpr size_t sizePerElement[]{sizeof(float), sizeof(uint32_t)};
|
// 这一组恰好与 onnx 的类型对齐:
|
||||||
static constexpr std::string_view names[]{"Float32", "UInt32"};
|
// <https://onnx.ai/onnx/intro/concepts.html#element-type>
|
||||||
|
static const DataType UInt8, Int8, UInt16, Int16, Int32, Int64;
|
||||||
|
static constexpr size_t sizePerElement[]{
|
||||||
|
sizeof(float), sizeof(uint32_t), sizeof(uint8_t), sizeof(int8_t),
|
||||||
|
sizeof(uint16_t), sizeof(int16_t), sizeof(int32_t), sizeof(int64_t)};
|
||||||
|
|
||||||
|
static constexpr std::string_view names[]{"Float32", "UInt32", "UInt8",
|
||||||
|
"Int8", "UInt16", "Int16",
|
||||||
|
"Int32", "Int64"};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int index;
|
int index;
|
||||||
|
@ -29,9 +38,18 @@ class DataType {
|
||||||
|
|
||||||
inline const DataType DataType::Float32(0);
|
inline const DataType DataType::Float32(0);
|
||||||
inline const DataType DataType::UInt32(1);
|
inline const DataType DataType::UInt32(1);
|
||||||
|
inline const DataType DataType::UInt8(2), DataType::Int8(3),
|
||||||
|
DataType::UInt16(4), DataType::Int16(5), DataType::Int32(6),
|
||||||
|
DataType::Int64(7);
|
||||||
// Method definitions are out of the declaration due to GCC bug:
|
// Method definitions are out of the declaration due to GCC bug:
|
||||||
// https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
|
// https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
|
||||||
template <> inline DataType DataType::get<float>() { return Float32; }
|
template <> inline DataType DataType::get<float>() { return Float32; }
|
||||||
template <> inline DataType DataType::get<uint32_t>() { return UInt32; }
|
template <> inline DataType DataType::get<uint32_t>() { return UInt32; }
|
||||||
|
template <> inline DataType DataType::get<uint8_t>() { return UInt8; }
|
||||||
|
template <> inline DataType DataType::get<int8_t>() { return Int8; }
|
||||||
|
template <> inline DataType DataType::get<uint16_t>() { return UInt16; }
|
||||||
|
template <> inline DataType DataType::get<int16_t>() { return Int16; }
|
||||||
|
template <> inline DataType DataType::get<int32_t>() { return Int32; }
|
||||||
|
template <> inline DataType DataType::get<int64_t>() { return Int64; }
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -119,7 +119,7 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
tensors[node.output[0]] = handler.reshape(
|
tensors[node.output[0]] = handler.reshape(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
data[node.input[1]].int32_data,
|
[int(i) for i in data[node.input[1]].int64_data],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||||
|
|
|
@ -1,6 +1,12 @@
|
||||||
import os, onnx, unittest
|
import os, onnx, unittest
|
||||||
from onnx import TensorProto
|
from onnx import TensorProto
|
||||||
from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info
|
from onnx.helper import (
|
||||||
|
make_model,
|
||||||
|
make_node,
|
||||||
|
make_tensor,
|
||||||
|
make_graph,
|
||||||
|
make_tensor_value_info,
|
||||||
|
)
|
||||||
from onnx.checker import check_model
|
from onnx.checker import check_model
|
||||||
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
|
from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
|
||||||
|
|
||||||
|
@ -130,16 +136,20 @@ class TestStringMethods(unittest.TestCase):
|
||||||
make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
|
make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
|
||||||
|
|
||||||
def test_reshape(self):
|
def test_reshape(self):
|
||||||
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
|
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
|
||||||
shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8])
|
# shape 对于后端来说并不是一个张量,然而转换中可能没有办法分辨
|
||||||
reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8])
|
# 不知道怎么把 ValueInfoProto 转换成 TensorProto
|
||||||
|
shape = make_tensor_value_info("shape", TensorProto.INT64, [4])
|
||||||
|
shape_data = make_tensor("shape", TensorProto.INT64, [4], [3, 2, 4, 3])
|
||||||
|
reshaped = make_tensor_value_info(
|
||||||
|
"reshaped", TensorProto.FLOAT, shape_data.int64_data
|
||||||
|
)
|
||||||
reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
|
reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
|
||||||
# FIXME shape 对于 onnx 来说是输入张量,但对于后端来说不是,导入时无法分辨这个情况。
|
# 可以构造一个 shape 只出现在 initializer 里而不出现在 input 里的图,
|
||||||
# tensor 的类型又不支持 INT64,所以这里会报一个错。
|
# 但实际上的图中 initializer 里的必然会出现在 input 里,不知道为什么这样设计
|
||||||
# 如何分辨 onnx 的张量是不是需要作为张量注册?
|
make_and_import_model(
|
||||||
# make_and_import_model(
|
make_graph([reshape], "reshape", [data, shape], [reshaped], [shape_data])
|
||||||
# make_graph([reshape], "reshape", [data, shape], [reshaped])
|
)
|
||||||
# )
|
|
||||||
|
|
||||||
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
||||||
def test_linear(self):
|
def test_linear(self):
|
||||||
|
|
|
@ -88,9 +88,7 @@ Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) {
|
||||||
std::move(shape));
|
std::move(shape));
|
||||||
return reshaped;
|
return reshaped;
|
||||||
} else {
|
} else {
|
||||||
return g
|
return g->addOp<ReshapeObj>(std::move(data), reshaped, std::move(shape))
|
||||||
->addOpWithOutputs<ReshapeObj>(std::move(data), reshaped,
|
|
||||||
std::move(shape))
|
|
||||||
->getOutput();
|
->getOutput();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -101,6 +99,18 @@ static DataType dtype_repr_convert(int dtype) {
|
||||||
return DataType::Float32;
|
return DataType::Float32;
|
||||||
case OnnxDType::UINT32:
|
case OnnxDType::UINT32:
|
||||||
return DataType::UInt32;
|
return DataType::UInt32;
|
||||||
|
case OnnxDType::UINT8:
|
||||||
|
return DataType::UInt8;
|
||||||
|
case OnnxDType::INT8:
|
||||||
|
return DataType::Int8;
|
||||||
|
case OnnxDType::UINT16:
|
||||||
|
return DataType::UInt16;
|
||||||
|
case OnnxDType::INT16:
|
||||||
|
return DataType::Int16;
|
||||||
|
case OnnxDType::INT32:
|
||||||
|
return DataType::Int32;
|
||||||
|
case OnnxDType::INT64:
|
||||||
|
return DataType::Int64;
|
||||||
default:
|
default:
|
||||||
IT_ASSERT(false, "Unsupported data type");
|
IT_ASSERT(false, "Unsupported data type");
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,16 +0,0 @@
|
||||||
#include <filesystem>
|
|
||||||
#include <pybind11/embed.h>
|
|
||||||
#include <test.h>
|
|
||||||
|
|
||||||
TEST(Python, pybind) {
|
|
||||||
namespace fs = std::filesystem;
|
|
||||||
namespace py = pybind11;
|
|
||||||
using mod = py::module;
|
|
||||||
|
|
||||||
py::scoped_interpreter _python;
|
|
||||||
|
|
||||||
auto sys_path_append = mod::import("sys").attr("path").attr("append");
|
|
||||||
sys_path_append(fs::path(__FILE__).parent_path().c_str());
|
|
||||||
auto ans = mod::import("python").attr("inc")(1);
|
|
||||||
EXPECT_EQ(ans.cast<int>(), 2);
|
|
||||||
}
|
|
Loading…
Reference in New Issue