forked from jiuyuan/InfiniTensor
feat: add more DataType variants

- Added 6 algebraic types whose indices match onnx's element-type numbering (a short sketch of the correspondence follows the header diff below)
- reshape can now be imported

Signed-off-by: YdrMaster <ydrml@hotmail.com>
parent d9e2953425
commit a7e58bd8d0
@@ -4,10 +4,19 @@ namespace infini {
 class DataType {
   public:
     // legacy
     static const DataType Float32;
     static const DataType UInt32;
-    static constexpr size_t sizePerElement[]{sizeof(float), sizeof(uint32_t)};
-    static constexpr std::string_view names[]{"Float32", "UInt32"};
+    // This group aligns exactly with onnx's element types:
+    // <https://onnx.ai/onnx/intro/concepts.html#element-type>
+    static const DataType UInt8, Int8, UInt16, Int16, Int32, Int64;
+    static constexpr size_t sizePerElement[]{
+        sizeof(float), sizeof(uint32_t), sizeof(uint8_t), sizeof(int8_t),
+        sizeof(uint16_t), sizeof(int16_t), sizeof(int32_t), sizeof(int64_t)};
+
+    static constexpr std::string_view names[]{"Float32", "UInt32", "UInt8",
+                                              "Int8", "UInt16", "Int16",
+                                              "Int32", "Int64"};
+
   private:
     int index;
@@ -29,9 +38,18 @@ class DataType {
 inline const DataType DataType::Float32(0);
 inline const DataType DataType::UInt32(1);
+inline const DataType DataType::UInt8(2), DataType::Int8(3),
+    DataType::UInt16(4), DataType::Int16(5), DataType::Int32(6),
+    DataType::Int64(7);
 // Method definitions are out of the declaration due to GCC bug:
 // https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
 template <> inline DataType DataType::get<float>() { return Float32; }
 template <> inline DataType DataType::get<uint32_t>() { return UInt32; }
+template <> inline DataType DataType::get<uint8_t>() { return UInt8; }
+template <> inline DataType DataType::get<int8_t>() { return Int8; }
+template <> inline DataType DataType::get<uint16_t>() { return UInt16; }
+template <> inline DataType DataType::get<int16_t>() { return Int16; }
+template <> inline DataType DataType::get<int32_t>() { return Int32; }
+template <> inline DataType DataType::get<int64_t>() { return Int64; }

 } // namespace infini
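For context (not part of the diff): a minimal Python sketch of the numbering the commit message refers to. The six new DataType indices above (UInt8 = 2 … Int64 = 7) coincide with onnx's element-type enum values, while the two legacy entries (Float32 = 0, UInt32 = 1) do not. The index dictionary is transcribed from the header change; only the comparison itself is illustrative.

from onnx import TensorProto

# Indices assigned in the header above, keyed by DataType name.
new_indices = {"UInt8": 2, "Int8": 3, "UInt16": 4, "Int16": 5, "Int32": 6, "Int64": 7}

# onnx element-type enum values for the same types.
onnx_values = {
    "UInt8": TensorProto.UINT8,
    "Int8": TensorProto.INT8,
    "UInt16": TensorProto.UINT16,
    "Int16": TensorProto.INT16,
    "Int32": TensorProto.INT32,
    "Int64": TensorProto.INT64,
}

for name, index in new_indices.items():
    # Each new index matches onnx's numbering exactly.
    assert index == int(onnx_values[name]), name
print("all six indices match onnx's element-type numbering")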
@@ -119,7 +119,7 @@ def from_onnx(model: onnx.ModelProto):
             tensors[node.output[0]] = handler.reshape(
                 tensors[node.input[0]],
                 tensors.get(node.output[0]),
-                data[node.input[1]].int32_data,
+                [int(i) for i in data[node.input[1]].int64_data],
             )
         else:
             raise Exception('Unsupported operator "{}"'.format(node.op_type))
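As an aside, a small runnable sketch (not part of the change) of why the importer converts the repeated int64_data field into a plain list of Python ints before passing it to handler.reshape; the example values mirror the test below.

from onnx import TensorProto
from onnx.helper import make_tensor

# A shape initializer as onnx stores it: int64_data is a protobuf repeated
# field, so it is normalized to a plain list of ints for the backend handler.
shape_data = make_tensor("shape", TensorProto.INT64, [4], [3, 2, 4, 3])
print(type(shape_data.int64_data))              # repeated int64 container
print([int(i) for i in shape_data.int64_data])  # [3, 2, 4, 3]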
@@ -1,6 +1,12 @@
 import os, onnx, unittest
 from onnx import TensorProto
-from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info
+from onnx.helper import (
+    make_model,
+    make_node,
+    make_tensor,
+    make_graph,
+    make_tensor_value_info,
+)
 from onnx.checker import check_model
 from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
@@ -130,16 +136,20 @@ class TestStringMethods(unittest.TestCase):
         make_and_import_model(make_graph([flatten], "flatten", [x], [y]))

     def test_reshape(self):
-        data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
-        shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8])
-        reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8])
+        data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
+        # For the backend, shape is not a tensor, but the converter may have no way to tell.
+        # Not sure how to convert a ValueInfoProto into a TensorProto.
+        shape = make_tensor_value_info("shape", TensorProto.INT64, [4])
+        shape_data = make_tensor("shape", TensorProto.INT64, [4], [3, 2, 4, 3])
+        reshaped = make_tensor_value_info(
+            "reshaped", TensorProto.FLOAT, shape_data.int64_data
+        )
         reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
-        # FIXME For onnx, shape is an input tensor, but for the backend it is not; the importer cannot tell the two cases apart.
-        # The tensor type does not support INT64 either, so this would raise an error.
-        # How to tell whether an onnx tensor needs to be registered as a tensor?
-        # make_and_import_model(
-        #     make_graph([reshape], "reshape", [data, shape], [reshaped])
-        # )
+        # One could build a graph where shape appears only in the initializer and not in the inputs,
+        # but in real graphs the initializer entries always appear in the inputs too; unclear why it is designed this way.
+        make_and_import_model(
+            make_graph([reshape], "reshape", [data, shape], [reshaped], [shape_data])
+        )

     # see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
     def test_linear(self):
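The new comment in the test mentions that one could build a graph where shape appears only in the initializer and not among the inputs. Below is a hedged sketch of that variant (not part of the test suite); it assumes a recent onnx IR version, in which initializers do not have to be repeated in the graph inputs.

from onnx import TensorProto
from onnx.checker import check_model
from onnx.helper import (
    make_graph,
    make_model,
    make_node,
    make_tensor,
    make_tensor_value_info,
)

# Reshape graph whose "shape" operand exists only as an initializer and is
# not listed among the graph inputs.
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4])
reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 2, 4, 3])
shape_data = make_tensor("shape", TensorProto.INT64, [4], [3, 2, 4, 3])
reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
model = make_model(make_graph([reshape], "reshape", [data], [reshaped], [shape_data]))
check_model(model)  # accepted: the initializer alone supplies "shape"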
@@ -88,9 +88,7 @@ Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) {
                                         std::move(shape));
         return reshaped;
     } else {
-        return g
-            ->addOpWithOutputs<ReshapeObj>(std::move(data), reshaped,
-                                           std::move(shape))
+        return g->addOp<ReshapeObj>(std::move(data), reshaped, std::move(shape))
             ->getOutput();
     }
 }
@@ -101,6 +99,18 @@ static DataType dtype_repr_convert(int dtype) {
         return DataType::Float32;
     case OnnxDType::UINT32:
         return DataType::UInt32;
+    case OnnxDType::UINT8:
+        return DataType::UInt8;
+    case OnnxDType::INT8:
+        return DataType::Int8;
+    case OnnxDType::UINT16:
+        return DataType::UInt16;
+    case OnnxDType::INT16:
+        return DataType::Int16;
+    case OnnxDType::INT32:
+        return DataType::Int32;
+    case OnnxDType::INT64:
+        return DataType::Int64;
     default:
         IT_ASSERT(false, "Unsupported data type");
     }
@@ -1,16 +0,0 @@
-#include <filesystem>
-#include <pybind11/embed.h>
-#include <test.h>
-
-TEST(Python, pybind) {
-    namespace fs = std::filesystem;
-    namespace py = pybind11;
-    using mod = py::module;
-
-    py::scoped_interpreter _python;
-
-    auto sys_path_append = mod::import("sys").attr("path").attr("append");
-    sys_path_append(fs::path(__FILE__).parent_path().c_str());
-    auto ans = mod::import("python").attr("inc")(1);
-    EXPECT_EQ(ans.cast<int>(), 2);
-}