forked from jiuyuan/InfiniTensor
fix: 按照review意见修改
This commit is contained in:
parent
c4aec6c38b
commit
e936ee2850
|
@ -5,30 +5,43 @@ namespace infini {
|
|||
class DataType {
|
||||
public:
|
||||
// <https://onnx.ai/onnx/intro/concepts.html#element-type>
|
||||
static const DataType Undefine;
|
||||
static const DataType Float32;
|
||||
static const DataType UInt32;
|
||||
static const DataType UInt8;
|
||||
static const DataType Int8;
|
||||
static const DataType UInt16;
|
||||
static const DataType Int16;
|
||||
static const DataType Int32;
|
||||
static const DataType Int64;
|
||||
static const DataType String;
|
||||
static const DataType Bool;
|
||||
static const DataType Float16;
|
||||
static const DataType Double;
|
||||
static const DataType UInt32;
|
||||
static const DataType UInt64;
|
||||
// "sizePerElement" show the DType to cpu_type
|
||||
// DataType::Bool -> int8_t DataType::Float16 -> uint8_t
|
||||
static constexpr size_t sizePerElement[]{
|
||||
sizeof(float), sizeof(uint32_t), sizeof(uint8_t), sizeof(int8_t),
|
||||
sizeof(uint16_t), sizeof(int16_t), sizeof(int32_t), sizeof(int64_t),
|
||||
sizeof(int8_t), sizeof(uint8_t)};
|
||||
// DataType::Bool -> int8_t DataType::Float16 -> uint16_t
|
||||
static constexpr size_t sizePerElement[]{0,
|
||||
sizeof(float),
|
||||
sizeof(uint8_t),
|
||||
sizeof(int8_t),
|
||||
sizeof(uint16_t),
|
||||
sizeof(int16_t),
|
||||
sizeof(int32_t),
|
||||
sizeof(int64_t),
|
||||
sizeof(std::string),
|
||||
sizeof(int8_t),
|
||||
sizeof(uint16_t),
|
||||
sizeof(double),
|
||||
sizeof(uint32_t),
|
||||
sizeof(uint64_t)};
|
||||
|
||||
static constexpr std::string_view names[]{
|
||||
"Float32", "UInt32", "UInt8", "Int8", "UInt16",
|
||||
"Int16", "Int32", "Int64", "Bool", "Float16"};
|
||||
"Undefine", "Float32", "UInt8", "Int8", "UInt16",
|
||||
"Int16", "Int32", "Int64", "String", "Bool",
|
||||
"Float16", "Double", "UInt32", "UInt64"};
|
||||
|
||||
static constexpr std::string_view cpuType[]{
|
||||
"float", "uint32_t", "uint8_t", "int8_t", "uint16_t",
|
||||
"int16_t", "int32_t", "int64_t", "int8_t", "uint8_t"};
|
||||
static constexpr int cpuType[]{-1, 0, 2, 3, 4, 5, 6, 7, -1, 3, 4, 9, 1, 8};
|
||||
|
||||
private:
|
||||
int index;
|
||||
|
@ -41,40 +54,58 @@ class DataType {
|
|||
bool operator==(const DataType &rhs) const { return index == rhs.index; }
|
||||
bool operator<(const DataType &rhs) const { return index < rhs.index; }
|
||||
|
||||
template <typename T> static std::string get() {
|
||||
template <typename T> static int get() {
|
||||
IT_TODO_HALT_MSG("Unsupported data type");
|
||||
}
|
||||
size_t getSize() const { return sizePerElement[index]; }
|
||||
string toString() const { return string(names[index]); }
|
||||
string cpuTypeString() const { return string(cpuType[index]); }
|
||||
int cpuTypeInt() const { return cpuType[index]; }
|
||||
int getIndex() const { return index; }
|
||||
};
|
||||
|
||||
inline const DataType DataType::Float32(0);
|
||||
inline const DataType DataType::UInt32(1);
|
||||
inline const DataType DataType::UInt8(2), DataType::Int8(3),
|
||||
DataType::UInt16(4), DataType::Int16(5), DataType::Int32(6),
|
||||
DataType::Int64(7), DataType::Bool(8), DataType::Float16(9);
|
||||
// to be consistent with onnx
|
||||
// https://github.com/onnx/onnx/blob/aeb21329122b96df1d3ef33b500a35ca140b1431/onnx/onnx.proto#L484
|
||||
inline const DataType DataType::Undefine(0);
|
||||
inline const DataType DataType::Float32(1);
|
||||
inline const DataType DataType::UInt8(2);
|
||||
inline const DataType DataType::Int8(3);
|
||||
inline const DataType DataType::UInt16(4);
|
||||
inline const DataType DataType::Int16(5);
|
||||
inline const DataType DataType::Int32(6);
|
||||
inline const DataType DataType::Int64(7);
|
||||
inline const DataType DataType::String(8);
|
||||
inline const DataType DataType::Bool(9);
|
||||
inline const DataType DataType::Float16(10);
|
||||
inline const DataType DataType::Double(11);
|
||||
inline const DataType DataType::UInt32(12);
|
||||
inline const DataType DataType::UInt64(13);
|
||||
// Method definitions are out of the declaration due to GCC bug:
|
||||
// https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
|
||||
template <> inline std::string DataType::get<float>() { return "float"; }
|
||||
template <> inline std::string DataType::get<uint32_t>() { return "uint32_t"; }
|
||||
template <> inline std::string DataType::get<uint8_t>() { return "uint8_t"; }
|
||||
template <> inline std::string DataType::get<int8_t>() { return "int8_t"; }
|
||||
template <> inline std::string DataType::get<uint16_t>() { return "uint16_t"; }
|
||||
template <> inline std::string DataType::get<int16_t>() { return "int16_t"; }
|
||||
template <> inline std::string DataType::get<int32_t>() { return "int32_t"; }
|
||||
template <> inline std::string DataType::get<int64_t>() { return "int64_t"; }
|
||||
template <> inline int DataType::get<float>() { return 0; }
|
||||
template <> inline int DataType::get<uint32_t>() { return 1; }
|
||||
template <> inline int DataType::get<uint8_t>() { return 2; }
|
||||
template <> inline int DataType::get<int8_t>() { return 3; }
|
||||
template <> inline int DataType::get<uint16_t>() { return 4; }
|
||||
template <> inline int DataType::get<int16_t>() { return 5; }
|
||||
template <> inline int DataType::get<int32_t>() { return 6; }
|
||||
template <> inline int DataType::get<int64_t>() { return 7; }
|
||||
template <> inline int DataType::get<uint64_t>() { return 8; }
|
||||
template <> inline int DataType::get<double>() { return 9; }
|
||||
|
||||
template <int index> struct DT {};
|
||||
template <> struct DT<0> { using t = float; };
|
||||
template <> struct DT<1> { using t = uint32_t; };
|
||||
template <> struct DT<0> { using t = bool; };
|
||||
template <> struct DT<1> { using t = float; };
|
||||
template <> struct DT<2> { using t = uint8_t; };
|
||||
template <> struct DT<3> { using t = int8_t; };
|
||||
template <> struct DT<4> { using t = uint16_t; };
|
||||
template <> struct DT<5> { using t = int16_t; };
|
||||
template <> struct DT<6> { using t = int32_t; };
|
||||
template <> struct DT<7> { using t = int64_t; };
|
||||
template <> struct DT<8> { using t = int8_t; };
|
||||
template <> struct DT<9> { using t = uint16_t; };
|
||||
template <> struct DT<8> { using t = char; };
|
||||
template <> struct DT<9> { using t = int8_t; };
|
||||
template <> struct DT<10> { using t = uint16_t; };
|
||||
template <> struct DT<11> { using t = double; };
|
||||
template <> struct DT<12> { using t = uint32_t; };
|
||||
template <> struct DT<13> { using t = uint64_t; };
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -7,30 +7,6 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
// Use the indices from onnx to reduce delivery overhead,
|
||||
// which comes from onnx but may be not only used for onnx.
|
||||
//
|
||||
// see https://onnx.ai/onnx/intro/concepts.html#element-type
|
||||
enum OnnxDType : int {
|
||||
UNDEFINED = 0,
|
||||
FLOAT,
|
||||
UINT8,
|
||||
INT8,
|
||||
UINT16,
|
||||
INT16,
|
||||
INT32,
|
||||
INT64,
|
||||
STRING,
|
||||
BOOL,
|
||||
FLOAT16,
|
||||
DOUBLE,
|
||||
UINT32,
|
||||
UINT64,
|
||||
COMPLEX64,
|
||||
COMPLEX128,
|
||||
BFLOAT16,
|
||||
};
|
||||
|
||||
class GraphHandlerObj {
|
||||
Graph g;
|
||||
|
||||
|
|
|
@ -32,10 +32,7 @@ class TensorObj : public TensorBaseObj {
|
|||
string toString() const override;
|
||||
|
||||
size_t size() const { return _size; }
|
||||
size_t getBytes() const {
|
||||
size_t usebytes = _size * dtype.getSize();
|
||||
return dtype == DataType::Float16 ? usebytes * 2 : usebytes;
|
||||
}
|
||||
size_t getBytes() const { return _size * dtype.getSize(); }
|
||||
|
||||
Shape getDims() const { return shape; }
|
||||
vector<size_t> getStride() const;
|
||||
|
@ -48,24 +45,20 @@ class TensorObj : public TensorBaseObj {
|
|||
|
||||
// Copy elements from `data`.
|
||||
template <typename T> void copyin(const vector<T> &data) {
|
||||
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeString());
|
||||
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
|
||||
IT_ASSERT(data.size() >= _size);
|
||||
copyin(data.data(), getBytes());
|
||||
}
|
||||
// Copy all the elements to a vector.
|
||||
template <typename T> auto copyout() const {
|
||||
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeString());
|
||||
auto sizeofvec = _size;
|
||||
if (dtype == DataType::Float16) {
|
||||
sizeofvec = _size * 2;
|
||||
}
|
||||
std::vector<T> ans(sizeofvec);
|
||||
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
|
||||
std::vector<T> ans(_size);
|
||||
copyout(ans.data(), getBytes());
|
||||
return ans;
|
||||
}
|
||||
// Copy the element at `pos`.
|
||||
template <typename T> auto copyOne(const vector<int> &pos) const {
|
||||
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeString());
|
||||
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
|
||||
auto offset = getOffset(pos);
|
||||
auto bytes = dtype.getSize();
|
||||
T ans;
|
||||
|
@ -105,7 +98,7 @@ class TensorObj : public TensorBaseObj {
|
|||
bool equalData(const Tensor &rhs, double relativeError = 1e-6) const;
|
||||
|
||||
template <typename T> bool equalData(const vector<T> &dataVector) {
|
||||
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeString());
|
||||
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
|
||||
IT_ASSERT(size() == dataVector.size());
|
||||
return equalDataImpl(getRawDataPtr<T *>(), dataVector.data(), size());
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import backend
|
||||
import struct
|
||||
from onnx import (
|
||||
ModelProto,
|
||||
TensorProto,
|
||||
|
@ -477,7 +478,10 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
# NOTE(constroy): `axes` is an attribute until opset version 13.
|
||||
next((attr.ints for attr in node.attribute if attr.name == "axes"), None),
|
||||
next(
|
||||
(attr.ints for attr in node.attribute if attr.name == "axes"),
|
||||
None,
|
||||
),
|
||||
next((attr.i for attr in node.attribute if attr.name == "keepdims"))
|
||||
!= 0,
|
||||
)
|
||||
|
@ -534,17 +538,7 @@ class OnnxStub:
|
|||
elif tensor.data_type == TensorProto.BOOL:
|
||||
obj.copyin_int8(_parse_data(tensor))
|
||||
elif tensor.data_type == TensorProto.FLOAT16:
|
||||
if len(tensor.int32_data) != 0:
|
||||
list_int32_data = []
|
||||
for element_data in tensor.int32_data:
|
||||
element_byte = element_data.to_bytes(2, "little")
|
||||
list_int32_data.append(element_byte[0])
|
||||
list_int32_data.append(element_byte[1])
|
||||
obj.copyin_uint8(list_int32_data)
|
||||
elif len(tensor.raw_data) != 0:
|
||||
obj.copyin_uint8(list(tensor.raw_data))
|
||||
else :
|
||||
raise Exception("Tensor have no float16 data!")
|
||||
obj.copyin_float16(_parse_data_fp16(tensor))
|
||||
elif tensor.data_type == TensorProto.INT8:
|
||||
obj.copyin_uint8(_parse_data(tensor))
|
||||
else:
|
||||
|
@ -910,5 +904,22 @@ def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[st
|
|||
def _parse_data(tensor: TensorProto) -> List[Any]:
|
||||
return to_array(tensor).flatten().tolist()
|
||||
|
||||
|
||||
def _parse_data_fp16(tensor: TensorProto):
|
||||
list_ = []
|
||||
if len(tensor.int32_data) != 0:
|
||||
for element_data in tensor.int32_data:
|
||||
element_byte = element_data.to_bytes(2, "little")
|
||||
list_.append(element_byte[0] + element_byte[1] * 256)
|
||||
elif len(tensor.raw_data) != 0:
|
||||
list_raw_data = list(tensor.raw_data)
|
||||
list_data = [list_raw_data[i : i + 2] for i in range(0, len(list_raw_data), 2)]
|
||||
for ele in list_data:
|
||||
list_.append(ele[0] + ele[1] * 256)
|
||||
else:
|
||||
raise Exception("Tensor have no float16 data!")
|
||||
return list_
|
||||
|
||||
|
||||
def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
|
||||
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
|
||||
|
|
|
@ -292,27 +292,35 @@ Tensor GraphHandlerObj::pad(Tensor input, Tensor output,
|
|||
}
|
||||
|
||||
static DataType dtype_repr_convert(int dtype) {
|
||||
switch ((OnnxDType)dtype) {
|
||||
case OnnxDType::FLOAT:
|
||||
switch (dtype) {
|
||||
case 0:
|
||||
return DataType::Undefine;
|
||||
case 1:
|
||||
return DataType::Float32;
|
||||
case OnnxDType::UINT32:
|
||||
return DataType::UInt32;
|
||||
case OnnxDType::UINT8:
|
||||
case 2:
|
||||
return DataType::UInt8;
|
||||
case OnnxDType::INT8:
|
||||
case 3:
|
||||
return DataType::Int8;
|
||||
case OnnxDType::UINT16:
|
||||
case 4:
|
||||
return DataType::UInt16;
|
||||
case OnnxDType::INT16:
|
||||
case 5:
|
||||
return DataType::Int16;
|
||||
case OnnxDType::INT32:
|
||||
case 6:
|
||||
return DataType::Int32;
|
||||
case OnnxDType::INT64:
|
||||
case 7:
|
||||
return DataType::Int64;
|
||||
case OnnxDType::BOOL:
|
||||
case 8:
|
||||
return DataType::String;
|
||||
case 9:
|
||||
return DataType::Bool;
|
||||
case OnnxDType::FLOAT16:
|
||||
case 10:
|
||||
return DataType::Float16;
|
||||
case 11:
|
||||
return DataType::Double;
|
||||
case 12:
|
||||
return DataType::UInt32;
|
||||
case 13:
|
||||
return DataType::UInt64;
|
||||
default:
|
||||
IT_ASSERT(false, "Unsupported data type");
|
||||
}
|
||||
|
|
|
@ -71,14 +71,20 @@ void TensorObj::printData() const {
|
|||
if (dtype == DataType(N)) \
|
||||
std::cout << dataToString<DT<N>::t>() << std::endl;
|
||||
|
||||
TRY_PRINT(0) // fmt: new line
|
||||
else TRY_PRINT(1) //
|
||||
else TRY_PRINT(2) //
|
||||
else TRY_PRINT(3) //
|
||||
else TRY_PRINT(4) //
|
||||
else TRY_PRINT(5) //
|
||||
else TRY_PRINT(6) //
|
||||
else TRY_PRINT(7) //
|
||||
TRY_PRINT(0) // fmt: new line
|
||||
else TRY_PRINT(1) //
|
||||
else TRY_PRINT(2) //
|
||||
else TRY_PRINT(3) //
|
||||
else TRY_PRINT(4) //
|
||||
else TRY_PRINT(5) //
|
||||
else TRY_PRINT(6) //
|
||||
else TRY_PRINT(7) //
|
||||
else TRY_PRINT(8) //
|
||||
else TRY_PRINT(9) //
|
||||
else TRY_PRINT(10) //
|
||||
else TRY_PRINT(11) //
|
||||
else TRY_PRINT(12) //
|
||||
else TRY_PRINT(13) //
|
||||
else IT_TODO_HALT();
|
||||
|
||||
#undef TRY_PRINT
|
||||
|
@ -98,14 +104,20 @@ bool TensorObj::equalData(const Tensor &rhs, double relativeError) const {
|
|||
return equalDataImpl(getRawDataPtr<DT<N>::t *>(), \
|
||||
rhs->getRawDataPtr<DT<N>::t *>(), size());
|
||||
|
||||
TEST_EQUAL(0) // fmt: new line
|
||||
else TEST_EQUAL(1) //
|
||||
else TEST_EQUAL(2) //
|
||||
else TEST_EQUAL(3) //
|
||||
else TEST_EQUAL(4) //
|
||||
else TEST_EQUAL(5) //
|
||||
else TEST_EQUAL(6) //
|
||||
else TEST_EQUAL(7) //
|
||||
TEST_EQUAL(0) // fmt: new line
|
||||
else TEST_EQUAL(1) //
|
||||
else TEST_EQUAL(2) //
|
||||
else TEST_EQUAL(3) //
|
||||
else TEST_EQUAL(4) //
|
||||
else TEST_EQUAL(5) //
|
||||
else TEST_EQUAL(6) //
|
||||
else TEST_EQUAL(7) //
|
||||
else TEST_EQUAL(8) //
|
||||
else TEST_EQUAL(9) //
|
||||
else TEST_EQUAL(10) //
|
||||
else TEST_EQUAL(11) //
|
||||
else TEST_EQUAL(12) //
|
||||
else TEST_EQUAL(13) //
|
||||
else IT_TODO_HALT();
|
||||
|
||||
#undef TEST_EQUAL
|
||||
|
|
|
@ -100,26 +100,34 @@ void export_values(py::module &m) {
|
|||
}
|
||||
|
||||
static int tensor_dtype(Tensor t) {
|
||||
if (t->getDType() == DataType::Undefine)
|
||||
return 0;
|
||||
if (t->getDType() == DataType::Float32)
|
||||
return OnnxDType::FLOAT;
|
||||
if (t->getDType() == DataType::UInt32)
|
||||
return OnnxDType::UINT32;
|
||||
return 1;
|
||||
if (t->getDType() == DataType::UInt8)
|
||||
return OnnxDType::UINT8;
|
||||
return 2;
|
||||
if (t->getDType() == DataType::Int8)
|
||||
return OnnxDType::INT8;
|
||||
return 3;
|
||||
if (t->getDType() == DataType::UInt16)
|
||||
return OnnxDType::UINT16;
|
||||
return 4;
|
||||
if (t->getDType() == DataType::Int16)
|
||||
return OnnxDType::INT16;
|
||||
return 5;
|
||||
if (t->getDType() == DataType::Int32)
|
||||
return OnnxDType::INT32;
|
||||
return 6;
|
||||
if (t->getDType() == DataType::Int64)
|
||||
return OnnxDType::INT64;
|
||||
return 7;
|
||||
if (t->getDType() == DataType::String)
|
||||
return 8;
|
||||
if (t->getDType() == DataType::Bool)
|
||||
return OnnxDType::BOOL;
|
||||
return 9;
|
||||
if (t->getDType() == DataType::Float16)
|
||||
return OnnxDType::FLOAT16;
|
||||
return 10;
|
||||
if (t->getDType() == DataType::Double)
|
||||
return 11;
|
||||
if (t->getDType() == DataType::UInt32)
|
||||
return 12;
|
||||
if (t->getDType() == DataType::UInt64)
|
||||
return 13;
|
||||
IT_ASSERT(false, "Unsupported data type");
|
||||
}
|
||||
|
||||
|
@ -288,9 +296,13 @@ void init_graph_builder(py::module &m) {
|
|||
.def("copyin_int64", &TensorObj::copyin<int64_t>, policy::move)
|
||||
.def("copyin_int8", &TensorObj::copyin<int8_t>, policy::move)
|
||||
.def("copyin_uint8", &TensorObj::copyin<uint8_t>, policy::move)
|
||||
.def("copyin_float16", &TensorObj::copyin<uint16_t>, policy::move)
|
||||
.def("copyout_float", &TensorObj::copyout<float>, policy::move)
|
||||
.def("copyout_int32", &TensorObj::copyout<int32_t>, policy::move)
|
||||
.def("copyout_int64", &TensorObj::copyout<int64_t>, policy::move)
|
||||
.def("copyout_int8", &TensorObj::copyout<int8_t>, policy::move)
|
||||
.def("copyout_uint8", &TensorObj::copyout<uint8_t>, policy::move)
|
||||
.def("copyout_float16", &TensorObj::copyout<uint16_t>, policy::move)
|
||||
.def("has_target", &TensorObj::hasTarget, policy::automatic)
|
||||
.def("src", &TensorObj::getSource, policy::move)
|
||||
.def("printData", &TensorObj::printData, policy::automatic);
|
||||
|
|
|
@ -7,9 +7,9 @@ namespace infini {
|
|||
TEST(Handler, matmul) {
|
||||
auto runtime = NativeCpuRuntimeObj::getInstance();
|
||||
auto handler = make_ref<GraphHandlerObj>(runtime);
|
||||
auto i = handler->tensor({1, 2, 3}, OnnxDType::UINT32);
|
||||
auto w = handler->tensor({1, 3, 4}, OnnxDType::UINT32);
|
||||
auto o = handler->tensor({1, 2, 4}, OnnxDType::UINT32);
|
||||
auto i = handler->tensor({1, 2, 3}, DataType::UInt32.getIndex());
|
||||
auto w = handler->tensor({1, 3, 4}, DataType::UInt32.getIndex());
|
||||
auto o = handler->tensor({1, 2, 4}, DataType::UInt32.getIndex());
|
||||
handler->matmul(i, w, o, false, false, nullptr, ActType::None);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue