fix: 按照review意见修改

This commit is contained in:
zhangyunze 2023-07-24 16:58:38 +08:00
parent c4aec6c38b
commit e936ee2850
8 changed files with 164 additions and 121 deletions

View File

@ -5,30 +5,43 @@ namespace infini {
class DataType {
public:
// <https://onnx.ai/onnx/intro/concepts.html#element-type>
static const DataType Undefine;
static const DataType Float32;
static const DataType UInt32;
static const DataType UInt8;
static const DataType Int8;
static const DataType UInt16;
static const DataType Int16;
static const DataType Int32;
static const DataType Int64;
static const DataType String;
static const DataType Bool;
static const DataType Float16;
static const DataType Double;
static const DataType UInt32;
static const DataType UInt64;
// "sizePerElement" show the DType to cpu_type
// DataType::Bool -> int8_t DataType::Float16 -> uint8_t
static constexpr size_t sizePerElement[]{
sizeof(float), sizeof(uint32_t), sizeof(uint8_t), sizeof(int8_t),
sizeof(uint16_t), sizeof(int16_t), sizeof(int32_t), sizeof(int64_t),
sizeof(int8_t), sizeof(uint8_t)};
// DataType::Bool -> int8_t DataType::Float16 -> uint16_t
static constexpr size_t sizePerElement[]{0,
sizeof(float),
sizeof(uint8_t),
sizeof(int8_t),
sizeof(uint16_t),
sizeof(int16_t),
sizeof(int32_t),
sizeof(int64_t),
sizeof(std::string),
sizeof(int8_t),
sizeof(uint16_t),
sizeof(double),
sizeof(uint32_t),
sizeof(uint64_t)};
static constexpr std::string_view names[]{
"Float32", "UInt32", "UInt8", "Int8", "UInt16",
"Int16", "Int32", "Int64", "Bool", "Float16"};
"Undefine", "Float32", "UInt8", "Int8", "UInt16",
"Int16", "Int32", "Int64", "String", "Bool",
"Float16", "Double", "UInt32", "UInt64"};
static constexpr std::string_view cpuType[]{
"float", "uint32_t", "uint8_t", "int8_t", "uint16_t",
"int16_t", "int32_t", "int64_t", "int8_t", "uint8_t"};
static constexpr int cpuType[]{-1, 0, 2, 3, 4, 5, 6, 7, -1, 3, 4, 9, 1, 8};
private:
int index;
@ -41,40 +54,58 @@ class DataType {
bool operator==(const DataType &rhs) const { return index == rhs.index; }
bool operator<(const DataType &rhs) const { return index < rhs.index; }
template <typename T> static std::string get() {
template <typename T> static int get() {
IT_TODO_HALT_MSG("Unsupported data type");
}
size_t getSize() const { return sizePerElement[index]; }
string toString() const { return string(names[index]); }
string cpuTypeString() const { return string(cpuType[index]); }
int cpuTypeInt() const { return cpuType[index]; }
int getIndex() const { return index; }
};
inline const DataType DataType::Float32(0);
inline const DataType DataType::UInt32(1);
inline const DataType DataType::UInt8(2), DataType::Int8(3),
DataType::UInt16(4), DataType::Int16(5), DataType::Int32(6),
DataType::Int64(7), DataType::Bool(8), DataType::Float16(9);
// to be consistent with onnx
// https://github.com/onnx/onnx/blob/aeb21329122b96df1d3ef33b500a35ca140b1431/onnx/onnx.proto#L484
inline const DataType DataType::Undefine(0);
inline const DataType DataType::Float32(1);
inline const DataType DataType::UInt8(2);
inline const DataType DataType::Int8(3);
inline const DataType DataType::UInt16(4);
inline const DataType DataType::Int16(5);
inline const DataType DataType::Int32(6);
inline const DataType DataType::Int64(7);
inline const DataType DataType::String(8);
inline const DataType DataType::Bool(9);
inline const DataType DataType::Float16(10);
inline const DataType DataType::Double(11);
inline const DataType DataType::UInt32(12);
inline const DataType DataType::UInt64(13);
// Method definitions are out of the declaration due to GCC bug:
// https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
template <> inline std::string DataType::get<float>() { return "float"; }
template <> inline std::string DataType::get<uint32_t>() { return "uint32_t"; }
template <> inline std::string DataType::get<uint8_t>() { return "uint8_t"; }
template <> inline std::string DataType::get<int8_t>() { return "int8_t"; }
template <> inline std::string DataType::get<uint16_t>() { return "uint16_t"; }
template <> inline std::string DataType::get<int16_t>() { return "int16_t"; }
template <> inline std::string DataType::get<int32_t>() { return "int32_t"; }
template <> inline std::string DataType::get<int64_t>() { return "int64_t"; }
template <> inline int DataType::get<float>() { return 0; }
template <> inline int DataType::get<uint32_t>() { return 1; }
template <> inline int DataType::get<uint8_t>() { return 2; }
template <> inline int DataType::get<int8_t>() { return 3; }
template <> inline int DataType::get<uint16_t>() { return 4; }
template <> inline int DataType::get<int16_t>() { return 5; }
template <> inline int DataType::get<int32_t>() { return 6; }
template <> inline int DataType::get<int64_t>() { return 7; }
template <> inline int DataType::get<uint64_t>() { return 8; }
template <> inline int DataType::get<double>() { return 9; }
template <int index> struct DT {};
template <> struct DT<0> { using t = float; };
template <> struct DT<1> { using t = uint32_t; };
template <> struct DT<0> { using t = bool; };
template <> struct DT<1> { using t = float; };
template <> struct DT<2> { using t = uint8_t; };
template <> struct DT<3> { using t = int8_t; };
template <> struct DT<4> { using t = uint16_t; };
template <> struct DT<5> { using t = int16_t; };
template <> struct DT<6> { using t = int32_t; };
template <> struct DT<7> { using t = int64_t; };
template <> struct DT<8> { using t = int8_t; };
template <> struct DT<9> { using t = uint16_t; };
template <> struct DT<8> { using t = char; };
template <> struct DT<9> { using t = int8_t; };
template <> struct DT<10> { using t = uint16_t; };
template <> struct DT<11> { using t = double; };
template <> struct DT<12> { using t = uint32_t; };
template <> struct DT<13> { using t = uint64_t; };
} // namespace infini

View File

@ -7,30 +7,6 @@
namespace infini {
// Use the indices from onnx to reduce delivery overhead,
// which comes from onnx but may be not only used for onnx.
//
// see https://onnx.ai/onnx/intro/concepts.html#element-type
enum OnnxDType : int {
UNDEFINED = 0,
FLOAT,
UINT8,
INT8,
UINT16,
INT16,
INT32,
INT64,
STRING,
BOOL,
FLOAT16,
DOUBLE,
UINT32,
UINT64,
COMPLEX64,
COMPLEX128,
BFLOAT16,
};
class GraphHandlerObj {
Graph g;

View File

@ -32,10 +32,7 @@ class TensorObj : public TensorBaseObj {
string toString() const override;
size_t size() const { return _size; }
size_t getBytes() const {
size_t usebytes = _size * dtype.getSize();
return dtype == DataType::Float16 ? usebytes * 2 : usebytes;
}
size_t getBytes() const { return _size * dtype.getSize(); }
Shape getDims() const { return shape; }
vector<size_t> getStride() const;
@ -48,24 +45,20 @@ class TensorObj : public TensorBaseObj {
// Copy elements from `data`.
template <typename T> void copyin(const vector<T> &data) {
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeString());
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
IT_ASSERT(data.size() >= _size);
copyin(data.data(), getBytes());
}
// Copy all the elements to a vector.
template <typename T> auto copyout() const {
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeString());
auto sizeofvec = _size;
if (dtype == DataType::Float16) {
sizeofvec = _size * 2;
}
std::vector<T> ans(sizeofvec);
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
std::vector<T> ans(_size);
copyout(ans.data(), getBytes());
return ans;
}
// Copy the element at `pos`.
template <typename T> auto copyOne(const vector<int> &pos) const {
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeString());
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
auto offset = getOffset(pos);
auto bytes = dtype.getSize();
T ans;
@ -105,7 +98,7 @@ class TensorObj : public TensorBaseObj {
bool equalData(const Tensor &rhs, double relativeError = 1e-6) const;
template <typename T> bool equalData(const vector<T> &dataVector) {
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeString());
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
IT_ASSERT(size() == dataVector.size());
return equalDataImpl(getRawDataPtr<T *>(), dataVector.data(), size());
}

View File

@ -1,4 +1,5 @@
import backend
import struct
from onnx import (
ModelProto,
TensorProto,
@ -477,7 +478,10 @@ class OnnxStub:
tensors[node.input[0]],
tensors.get(node.output[0]),
# NOTE(constroy): `axes` is an attribute until opset version 13.
next((attr.ints for attr in node.attribute if attr.name == "axes"), None),
next(
(attr.ints for attr in node.attribute if attr.name == "axes"),
None,
),
next((attr.i for attr in node.attribute if attr.name == "keepdims"))
!= 0,
)
@ -534,17 +538,7 @@ class OnnxStub:
elif tensor.data_type == TensorProto.BOOL:
obj.copyin_int8(_parse_data(tensor))
elif tensor.data_type == TensorProto.FLOAT16:
if len(tensor.int32_data) != 0:
list_int32_data = []
for element_data in tensor.int32_data:
element_byte = element_data.to_bytes(2, "little")
list_int32_data.append(element_byte[0])
list_int32_data.append(element_byte[1])
obj.copyin_uint8(list_int32_data)
elif len(tensor.raw_data) != 0:
obj.copyin_uint8(list(tensor.raw_data))
else :
raise Exception("Tensor have no float16 data!")
obj.copyin_float16(_parse_data_fp16(tensor))
elif tensor.data_type == TensorProto.INT8:
obj.copyin_uint8(_parse_data(tensor))
else:
@ -910,5 +904,22 @@ def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[st
def _parse_data(tensor: TensorProto) -> List[Any]:
return to_array(tensor).flatten().tolist()
def _parse_data_fp16(tensor: TensorProto):
list_ = []
if len(tensor.int32_data) != 0:
for element_data in tensor.int32_data:
element_byte = element_data.to_bytes(2, "little")
list_.append(element_byte[0] + element_byte[1] * 256)
elif len(tensor.raw_data) != 0:
list_raw_data = list(tensor.raw_data)
list_data = [list_raw_data[i : i + 2] for i in range(0, len(list_raw_data), 2)]
for ele in list_data:
list_.append(ele[0] + ele[1] * 256)
else:
raise Exception("Tensor have no float16 data!")
return list_
def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]

View File

@ -292,27 +292,35 @@ Tensor GraphHandlerObj::pad(Tensor input, Tensor output,
}
static DataType dtype_repr_convert(int dtype) {
switch ((OnnxDType)dtype) {
case OnnxDType::FLOAT:
switch (dtype) {
case 0:
return DataType::Undefine;
case 1:
return DataType::Float32;
case OnnxDType::UINT32:
return DataType::UInt32;
case OnnxDType::UINT8:
case 2:
return DataType::UInt8;
case OnnxDType::INT8:
case 3:
return DataType::Int8;
case OnnxDType::UINT16:
case 4:
return DataType::UInt16;
case OnnxDType::INT16:
case 5:
return DataType::Int16;
case OnnxDType::INT32:
case 6:
return DataType::Int32;
case OnnxDType::INT64:
case 7:
return DataType::Int64;
case OnnxDType::BOOL:
case 8:
return DataType::String;
case 9:
return DataType::Bool;
case OnnxDType::FLOAT16:
case 10:
return DataType::Float16;
case 11:
return DataType::Double;
case 12:
return DataType::UInt32;
case 13:
return DataType::UInt64;
default:
IT_ASSERT(false, "Unsupported data type");
}

View File

@ -71,14 +71,20 @@ void TensorObj::printData() const {
if (dtype == DataType(N)) \
std::cout << dataToString<DT<N>::t>() << std::endl;
TRY_PRINT(0) // fmt: new line
else TRY_PRINT(1) //
else TRY_PRINT(2) //
else TRY_PRINT(3) //
else TRY_PRINT(4) //
else TRY_PRINT(5) //
else TRY_PRINT(6) //
else TRY_PRINT(7) //
TRY_PRINT(0) // fmt: new line
else TRY_PRINT(1) //
else TRY_PRINT(2) //
else TRY_PRINT(3) //
else TRY_PRINT(4) //
else TRY_PRINT(5) //
else TRY_PRINT(6) //
else TRY_PRINT(7) //
else TRY_PRINT(8) //
else TRY_PRINT(9) //
else TRY_PRINT(10) //
else TRY_PRINT(11) //
else TRY_PRINT(12) //
else TRY_PRINT(13) //
else IT_TODO_HALT();
#undef TRY_PRINT
@ -98,14 +104,20 @@ bool TensorObj::equalData(const Tensor &rhs, double relativeError) const {
return equalDataImpl(getRawDataPtr<DT<N>::t *>(), \
rhs->getRawDataPtr<DT<N>::t *>(), size());
TEST_EQUAL(0) // fmt: new line
else TEST_EQUAL(1) //
else TEST_EQUAL(2) //
else TEST_EQUAL(3) //
else TEST_EQUAL(4) //
else TEST_EQUAL(5) //
else TEST_EQUAL(6) //
else TEST_EQUAL(7) //
TEST_EQUAL(0) // fmt: new line
else TEST_EQUAL(1) //
else TEST_EQUAL(2) //
else TEST_EQUAL(3) //
else TEST_EQUAL(4) //
else TEST_EQUAL(5) //
else TEST_EQUAL(6) //
else TEST_EQUAL(7) //
else TEST_EQUAL(8) //
else TEST_EQUAL(9) //
else TEST_EQUAL(10) //
else TEST_EQUAL(11) //
else TEST_EQUAL(12) //
else TEST_EQUAL(13) //
else IT_TODO_HALT();
#undef TEST_EQUAL

View File

@ -100,26 +100,34 @@ void export_values(py::module &m) {
}
static int tensor_dtype(Tensor t) {
if (t->getDType() == DataType::Undefine)
return 0;
if (t->getDType() == DataType::Float32)
return OnnxDType::FLOAT;
if (t->getDType() == DataType::UInt32)
return OnnxDType::UINT32;
return 1;
if (t->getDType() == DataType::UInt8)
return OnnxDType::UINT8;
return 2;
if (t->getDType() == DataType::Int8)
return OnnxDType::INT8;
return 3;
if (t->getDType() == DataType::UInt16)
return OnnxDType::UINT16;
return 4;
if (t->getDType() == DataType::Int16)
return OnnxDType::INT16;
return 5;
if (t->getDType() == DataType::Int32)
return OnnxDType::INT32;
return 6;
if (t->getDType() == DataType::Int64)
return OnnxDType::INT64;
return 7;
if (t->getDType() == DataType::String)
return 8;
if (t->getDType() == DataType::Bool)
return OnnxDType::BOOL;
return 9;
if (t->getDType() == DataType::Float16)
return OnnxDType::FLOAT16;
return 10;
if (t->getDType() == DataType::Double)
return 11;
if (t->getDType() == DataType::UInt32)
return 12;
if (t->getDType() == DataType::UInt64)
return 13;
IT_ASSERT(false, "Unsupported data type");
}
@ -288,9 +296,13 @@ void init_graph_builder(py::module &m) {
.def("copyin_int64", &TensorObj::copyin<int64_t>, policy::move)
.def("copyin_int8", &TensorObj::copyin<int8_t>, policy::move)
.def("copyin_uint8", &TensorObj::copyin<uint8_t>, policy::move)
.def("copyin_float16", &TensorObj::copyin<uint16_t>, policy::move)
.def("copyout_float", &TensorObj::copyout<float>, policy::move)
.def("copyout_int32", &TensorObj::copyout<int32_t>, policy::move)
.def("copyout_int64", &TensorObj::copyout<int64_t>, policy::move)
.def("copyout_int8", &TensorObj::copyout<int8_t>, policy::move)
.def("copyout_uint8", &TensorObj::copyout<uint8_t>, policy::move)
.def("copyout_float16", &TensorObj::copyout<uint16_t>, policy::move)
.def("has_target", &TensorObj::hasTarget, policy::automatic)
.def("src", &TensorObj::getSource, policy::move)
.def("printData", &TensorObj::printData, policy::automatic);

View File

@ -7,9 +7,9 @@ namespace infini {
TEST(Handler, matmul) {
auto runtime = NativeCpuRuntimeObj::getInstance();
auto handler = make_ref<GraphHandlerObj>(runtime);
auto i = handler->tensor({1, 2, 3}, OnnxDType::UINT32);
auto w = handler->tensor({1, 3, 4}, OnnxDType::UINT32);
auto o = handler->tensor({1, 2, 4}, OnnxDType::UINT32);
auto i = handler->tensor({1, 2, 3}, DataType::UInt32.getIndex());
auto w = handler->tensor({1, 3, 4}, DataType::UInt32.getIndex());
auto o = handler->tensor({1, 2, 4}, DataType::UInt32.getIndex());
handler->matmul(i, w, o, false, false, nullptr, ActType::None);
}