diff --git a/include/core/data_type.h b/include/core/data_type.h
index 190368ab..2fb05a07 100644
--- a/include/core/data_type.h
+++ b/include/core/data_type.h
@@ -5,30 +5,43 @@ namespace infini {
 class DataType {
   public:
     //
+    static const DataType Undefine;
     static const DataType Float32;
-    static const DataType UInt32;
     static const DataType UInt8;
     static const DataType Int8;
     static const DataType UInt16;
     static const DataType Int16;
     static const DataType Int32;
     static const DataType Int64;
+    static const DataType String;
     static const DataType Bool;
     static const DataType Float16;
+    static const DataType Double;
+    static const DataType UInt32;
+    static const DataType UInt64;
     // "sizePerElement" show the DType to cpu_type
-    // DataType::Bool -> int8_t   DataType::Float16 -> uint8_t
-    static constexpr size_t sizePerElement[]{
-        sizeof(float),    sizeof(uint32_t), sizeof(uint8_t), sizeof(int8_t),
-        sizeof(uint16_t), sizeof(int16_t),  sizeof(int32_t), sizeof(int64_t),
-        sizeof(int8_t),   sizeof(uint8_t)};
+    // DataType::Bool -> int8_t   DataType::Float16 -> uint16_t
+    static constexpr size_t sizePerElement[]{0,
+                                             sizeof(float),
+                                             sizeof(uint8_t),
+                                             sizeof(int8_t),
+                                             sizeof(uint16_t),
+                                             sizeof(int16_t),
+                                             sizeof(int32_t),
+                                             sizeof(int64_t),
+                                             sizeof(std::string),
+                                             sizeof(int8_t),
+                                             sizeof(uint16_t),
+                                             sizeof(double),
+                                             sizeof(uint32_t),
+                                             sizeof(uint64_t)};
 
     static constexpr std::string_view names[]{
-        "Float32", "UInt32", "UInt8", "Int8", "UInt16",
-        "Int16",   "Int32",  "Int64", "Bool", "Float16"};
+        "Undefine", "Float32", "UInt8",  "Int8",   "UInt16",
+        "Int16",    "Int32",   "Int64",  "String", "Bool",
+        "Float16",  "Double",  "UInt32", "UInt64"};
 
-    static constexpr std::string_view cpuType[]{
-        "float",   "uint32_t", "uint8_t", "int8_t", "uint16_t",
-        "int16_t", "int32_t",  "int64_t", "int8_t", "uint8_t"};
+    static constexpr int cpuType[]{-1, 0, 2, 3, 4, 5, 6, 7, -1, 3, 4, 9, 1, 8};
 
   private:
     int index;
@@ -41,40 +54,58 @@ class DataType {
     bool operator==(const DataType &rhs) const { return index == rhs.index; }
     bool operator<(const DataType &rhs) const { return index < rhs.index; }
 
-    template <typename T> static std::string get() {
+    template <typename T> static int get() {
         IT_TODO_HALT_MSG("Unsupported data type");
     }
     size_t getSize() const { return sizePerElement[index]; }
     string toString() const { return string(names[index]); }
-    string cpuTypeString() const { return string(cpuType[index]); }
+    int cpuTypeInt() const { return cpuType[index]; }
+    int getIndex() const { return index; }
 };
 
-inline const DataType DataType::Float32(0);
-inline const DataType DataType::UInt32(1);
-inline const DataType DataType::UInt8(2), DataType::Int8(3),
-    DataType::UInt16(4), DataType::Int16(5), DataType::Int32(6),
-    DataType::Int64(7), DataType::Bool(8), DataType::Float16(9);
+// to be consistent with onnx
+// https://github.com/onnx/onnx/blob/aeb21329122b96df1d3ef33b500a35ca140b1431/onnx/onnx.proto#L484
+inline const DataType DataType::Undefine(0);
+inline const DataType DataType::Float32(1);
+inline const DataType DataType::UInt8(2);
+inline const DataType DataType::Int8(3);
+inline const DataType DataType::UInt16(4);
+inline const DataType DataType::Int16(5);
+inline const DataType DataType::Int32(6);
+inline const DataType DataType::Int64(7);
+inline const DataType DataType::String(8);
+inline const DataType DataType::Bool(9);
+inline const DataType DataType::Float16(10);
+inline const DataType DataType::Double(11);
+inline const DataType DataType::UInt32(12);
+inline const DataType DataType::UInt64(13);
 // Method definitions are out of the declaration due to GCC bug:
 // https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
-template <> inline std::string DataType::get<float>() { return "float"; }
-template <> inline std::string DataType::get<uint32_t>() { return "uint32_t"; }
-template <> inline std::string DataType::get<uint8_t>() { return "uint8_t"; }
-template <> inline std::string DataType::get<int8_t>() { return "int8_t"; }
-template <> inline std::string DataType::get<uint16_t>() { return "uint16_t"; }
-template <> inline std::string DataType::get<int16_t>() { return "int16_t"; }
-template <> inline std::string DataType::get<int32_t>() { return "int32_t"; }
-template <> inline std::string DataType::get<int64_t>() { return "int64_t"; }
+template <> inline int DataType::get<float>() { return 0; }
+template <> inline int DataType::get<uint32_t>() { return 1; }
+template <> inline int DataType::get<uint8_t>() { return 2; }
+template <> inline int DataType::get<int8_t>() { return 3; }
+template <> inline int DataType::get<uint16_t>() { return 4; }
+template <> inline int DataType::get<int16_t>() { return 5; }
+template <> inline int DataType::get<int32_t>() { return 6; }
+template <> inline int DataType::get<int64_t>() { return 7; }
+template <> inline int DataType::get<uint64_t>() { return 8; }
+template <> inline int DataType::get<double>() { return 9; }
 
 template <int index> struct DT {};
-template <> struct DT<0> { using t = float; };
-template <> struct DT<1> { using t = uint32_t; };
+template <> struct DT<0> { using t = bool; };
+template <> struct DT<1> { using t = float; };
 template <> struct DT<2> { using t = uint8_t; };
 template <> struct DT<3> { using t = int8_t; };
 template <> struct DT<4> { using t = uint16_t; };
 template <> struct DT<5> { using t = int16_t; };
 template <> struct DT<6> { using t = int32_t; };
 template <> struct DT<7> { using t = int64_t; };
-template <> struct DT<8> { using t = int8_t; };
-template <> struct DT<9> { using t = uint16_t; };
+template <> struct DT<8> { using t = char; };
+template <> struct DT<9> { using t = int8_t; };
+template <> struct DT<10> { using t = uint16_t; };
+template <> struct DT<11> { using t = double; };
+template <> struct DT<12> { using t = uint32_t; };
+template <> struct DT<13> { using t = uint64_t; };
 
 } // namespace infini
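The new `DataType` indices are chosen to mirror onnx's element-type enum one-to-one, so a dtype code can cross the Python/C++ boundary without a translation table. A quick sanity check against the `onnx` Python package (illustrative only, assuming onnx is installed):

```python
# The DataType indices above (Undefine=0 ... UInt64=13) are meant to match
# onnx's TensorProto element-type enum exactly.
from onnx import TensorProto

expected = {
    "UNDEFINED": 0, "FLOAT": 1, "UINT8": 2, "INT8": 3,
    "UINT16": 4, "INT16": 5, "INT32": 6, "INT64": 7,
    "STRING": 8, "BOOL": 9, "FLOAT16": 10, "DOUBLE": 11,
    "UINT32": 12, "UINT64": 13,
}
for name, index in expected.items():
    assert getattr(TensorProto, name) == index, name
```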
diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 2c9b25b4..ac0be526 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -7,30 +7,6 @@
 
 namespace infini {
 
-// Use the indices from onnx to reduce delivery overhead,
-// which comes from onnx but may be not only used for onnx.
-//
-// see https://onnx.ai/onnx/intro/concepts.html#element-type
-enum OnnxDType : int {
-    UNDEFINED = 0,
-    FLOAT,
-    UINT8,
-    INT8,
-    UINT16,
-    INT16,
-    INT32,
-    INT64,
-    STRING,
-    BOOL,
-    FLOAT16,
-    DOUBLE,
-    UINT32,
-    UINT64,
-    COMPLEX64,
-    COMPLEX128,
-    BFLOAT16,
-};
-
 class GraphHandlerObj {
     Graph g;
 
diff --git a/include/core/tensor.h b/include/core/tensor.h
index bb592bff..4707f3e9 100644
--- a/include/core/tensor.h
+++ b/include/core/tensor.h
@@ -32,10 +32,7 @@ class TensorObj : public TensorBaseObj {
     string toString() const override;
 
     size_t size() const { return _size; }
-    size_t getBytes() const {
-        size_t usebytes = _size * dtype.getSize();
-        return dtype == DataType::Float16 ? usebytes * 2 : usebytes;
-    }
+    size_t getBytes() const { return _size * dtype.getSize(); }
 
     Shape getDims() const { return shape; }
     vector<size_t> getStride() const;
@@ -48,24 +45,20 @@ class TensorObj : public TensorBaseObj {
 
     // Copy elements from `data`.
     template <typename T> void copyin(const vector<T> &data) {
-        IT_ASSERT(DataType::get<T>() == dtype.cpuTypeString());
+        IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
         IT_ASSERT(data.size() >= _size);
         copyin(data.data(), getBytes());
     }
     // Copy all the elements to a vector.
     template <typename T> auto copyout() const {
-        IT_ASSERT(DataType::get<T>() == dtype.cpuTypeString());
-        auto sizeofvec = _size;
-        if (dtype == DataType::Float16) {
-            sizeofvec = _size * 2;
-        }
-        std::vector<T> ans(sizeofvec);
+        IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
+        std::vector<T> ans(_size);
         copyout(ans.data(), getBytes());
         return ans;
     }
     // Copy the element at `pos`.
     template <typename T> auto copyOne(const vector<int> &pos) const {
-        IT_ASSERT(DataType::get<T>() == dtype.cpuTypeString());
+        IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
         auto offset = getOffset(pos);
         auto bytes = dtype.getSize();
         T ans;
@@ -105,7 +98,7 @@ class TensorObj : public TensorBaseObj {
     bool equalData(const Tensor &rhs, double relativeError = 1e-6) const;
 
     template <typename T> bool equalData(const vector<T> &dataVector) {
-        IT_ASSERT(DataType::get<T>() == dtype.cpuTypeString());
+        IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
         IT_ASSERT(size() == dataVector.size());
         return equalDataImpl(getRawDataPtr<T *>(), dataVector.data(), size());
     }
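With `sizePerElement[Float16]` now reporting `sizeof(uint16_t)`, `getBytes()` and `copyout()` no longer need the fp16-specific doubling: total bytes are simply element count times element size. The invariant being relied on, shown with numpy purely for illustration (not project code):

```python
# A float16 element occupies exactly 2 bytes, so bytes = count * element size
# holds with no special case, matching the new getBytes() above.
import numpy as np

x = np.zeros(10, dtype=np.float16)
assert x.itemsize == 2
assert x.nbytes == x.size * x.itemsize == 20
```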
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index a5977878..7b88b83c 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -1,4 +1,5 @@
 import backend
+import struct
 from onnx import (
     ModelProto,
     TensorProto,
@@ -477,7 +478,10 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     # NOTE(constroy): `axes` is an attribute until opset version 13.
-                    next((attr.ints for attr in node.attribute if attr.name == "axes"), None),
+                    next(
+                        (attr.ints for attr in node.attribute if attr.name == "axes"),
+                        None,
+                    ),
                     next((attr.i for attr in node.attribute if attr.name == "keepdims"))
                     != 0,
                 )
@@ -534,17 +538,7 @@ class OnnxStub:
            elif tensor.data_type == TensorProto.BOOL:
                obj.copyin_int8(_parse_data(tensor))
            elif tensor.data_type == TensorProto.FLOAT16:
-               if len(tensor.int32_data) != 0:
-                   list_int32_data = []
-                   for element_data in tensor.int32_data:
-                       element_byte = element_data.to_bytes(2, "little")
-                       list_int32_data.append(element_byte[0])
-                       list_int32_data.append(element_byte[1])
-                   obj.copyin_uint8(list_int32_data)
-               elif len(tensor.raw_data) != 0:
-                   obj.copyin_uint8(list(tensor.raw_data))
-               else :
-                   raise Exception("Tensor have no float16 data!")
+               obj.copyin_float16(_parse_data_fp16(tensor))
            elif tensor.data_type == TensorProto.INT8:
                obj.copyin_uint8(_parse_data(tensor))
            else:
@@ -910,5 +904,22 @@ def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[st
 def _parse_data(tensor: TensorProto) -> List[Any]:
     return to_array(tensor).flatten().tolist()
 
+
+def _parse_data_fp16(tensor: TensorProto):
+    list_ = []
+    if len(tensor.int32_data) != 0:
+        for element_data in tensor.int32_data:
+            element_byte = element_data.to_bytes(2, "little")
+            list_.append(element_byte[0] + element_byte[1] * 256)
+    elif len(tensor.raw_data) != 0:
+        list_raw_data = list(tensor.raw_data)
+        list_data = [list_raw_data[i : i + 2] for i in range(0, len(list_raw_data), 2)]
+        for ele in list_data:
+            list_.append(ele[0] + ele[1] * 256)
+    else:
+        raise Exception("Tensor have no float16 data!")
+    return list_
+
+
 def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
     return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
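`_parse_data_fp16` reassembles each little-endian byte pair into the uint16 bit pattern of one half-precision value via `lo + hi * 256`, which is the same decoding the newly imported `struct` module performs with the `"<H"` format. A quick equivalence check (illustrative, not part of the patch):

```python
# lo + hi * 256 is exactly struct's little-endian uint16 decoding of the two
# bytes backing one IEEE-754 half-precision value.
import struct

import numpy as np

raw = np.float16(1.5).tobytes()  # b'\x00>' -> bit pattern 0x3E00
lo, hi = raw[0], raw[1]
assert lo + hi * 256 == struct.unpack("<H", raw)[0] == 0x3E00
assert np.frombuffer(raw, dtype=np.float16)[0] == np.float16(1.5)
```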
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index 3dc29877..d032edd6 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -292,27 +292,35 @@ Tensor GraphHandlerObj::pad(Tensor input, Tensor output,
 }
 
 static DataType dtype_repr_convert(int dtype) {
-    switch ((OnnxDType)dtype) {
-    case OnnxDType::FLOAT:
+    switch (dtype) {
+    case 0:
+        return DataType::Undefine;
+    case 1:
         return DataType::Float32;
-    case OnnxDType::UINT32:
-        return DataType::UInt32;
-    case OnnxDType::UINT8:
+    case 2:
         return DataType::UInt8;
-    case OnnxDType::INT8:
+    case 3:
         return DataType::Int8;
-    case OnnxDType::UINT16:
+    case 4:
         return DataType::UInt16;
-    case OnnxDType::INT16:
+    case 5:
         return DataType::Int16;
-    case OnnxDType::INT32:
+    case 6:
         return DataType::Int32;
-    case OnnxDType::INT64:
+    case 7:
         return DataType::Int64;
-    case OnnxDType::BOOL:
+    case 8:
+        return DataType::String;
+    case 9:
         return DataType::Bool;
-    case OnnxDType::FLOAT16:
+    case 10:
         return DataType::Float16;
+    case 11:
+        return DataType::Double;
+    case 12:
+        return DataType::UInt32;
+    case 13:
+        return DataType::UInt64;
     default:
         IT_ASSERT(false, "Unsupported data type");
     }
diff --git a/src/core/tensor.cc b/src/core/tensor.cc
index e63039d5..c80ff8f7 100644
--- a/src/core/tensor.cc
+++ b/src/core/tensor.cc
@@ -71,14 +71,20 @@ void TensorObj::printData() const {
     if (dtype == DataType(N))                                                  \
         std::cout << dataToString<DT<N>::t>() << std::endl;
 
-    TRY_PRINT(0)      // fmt: new line
-    else TRY_PRINT(1) //
-    else TRY_PRINT(2) //
-    else TRY_PRINT(3) //
-    else TRY_PRINT(4) //
-    else TRY_PRINT(5) //
-    else TRY_PRINT(6) //
-    else TRY_PRINT(7) //
+    TRY_PRINT(0)       // fmt: new line
+    else TRY_PRINT(1)  //
+    else TRY_PRINT(2)  //
+    else TRY_PRINT(3)  //
+    else TRY_PRINT(4)  //
+    else TRY_PRINT(5)  //
+    else TRY_PRINT(6)  //
+    else TRY_PRINT(7)  //
+    else TRY_PRINT(8)  //
+    else TRY_PRINT(9)  //
+    else TRY_PRINT(10) //
+    else TRY_PRINT(11) //
+    else TRY_PRINT(12) //
+    else TRY_PRINT(13) //
     else IT_TODO_HALT();
 
 #undef TRY_PRINT
@@ -98,14 +104,20 @@ bool TensorObj::equalData(const Tensor &rhs, double relativeError) const {
         return equalDataImpl(getRawDataPtr<DT<N>::t *>(),                      \
                              rhs->getRawDataPtr<DT<N>::t *>(), size());
 
-    TEST_EQUAL(0)      // fmt: new line
-    else TEST_EQUAL(1) //
-    else TEST_EQUAL(2) //
-    else TEST_EQUAL(3) //
-    else TEST_EQUAL(4) //
-    else TEST_EQUAL(5) //
-    else TEST_EQUAL(6) //
-    else TEST_EQUAL(7) //
+    TEST_EQUAL(0)       // fmt: new line
+    else TEST_EQUAL(1)  //
+    else TEST_EQUAL(2)  //
+    else TEST_EQUAL(3)  //
+    else TEST_EQUAL(4)  //
+    else TEST_EQUAL(5)  //
+    else TEST_EQUAL(6)  //
+    else TEST_EQUAL(7)  //
+    else TEST_EQUAL(8)  //
+    else TEST_EQUAL(9)  //
+    else TEST_EQUAL(10) //
+    else TEST_EQUAL(11) //
+    else TEST_EQUAL(12) //
+    else TEST_EQUAL(13) //
     else IT_TODO_HALT();
 
 #undef TEST_EQUAL
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 82eb406b..9ca5d995 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -100,26 +100,34 @@ void export_values(py::module &m) {
 }
 
 static int tensor_dtype(Tensor t) {
+    if (t->getDType() == DataType::Undefine)
+        return 0;
     if (t->getDType() == DataType::Float32)
-        return OnnxDType::FLOAT;
-    if (t->getDType() == DataType::UInt32)
-        return OnnxDType::UINT32;
+        return 1;
     if (t->getDType() == DataType::UInt8)
-        return OnnxDType::UINT8;
+        return 2;
     if (t->getDType() == DataType::Int8)
-        return OnnxDType::INT8;
+        return 3;
     if (t->getDType() == DataType::UInt16)
-        return OnnxDType::UINT16;
+        return 4;
     if (t->getDType() == DataType::Int16)
-        return OnnxDType::INT16;
+        return 5;
     if (t->getDType() == DataType::Int32)
-        return OnnxDType::INT32;
+        return 6;
     if (t->getDType() == DataType::Int64)
-        return OnnxDType::INT64;
+        return 7;
+    if (t->getDType() == DataType::String)
+        return 8;
     if (t->getDType() == DataType::Bool)
-        return OnnxDType::BOOL;
+        return 9;
     if (t->getDType() == DataType::Float16)
-        return OnnxDType::FLOAT16;
+        return 10;
+    if (t->getDType() == DataType::Double)
+        return 11;
+    if (t->getDType() == DataType::UInt32)
+        return 12;
+    if (t->getDType() == DataType::UInt64)
+        return 13;
     IT_ASSERT(false, "Unsupported data type");
 }
 
@@ -288,9 +296,13 @@ void init_graph_builder(py::module &m) {
         .def("copyin_int64", &TensorObj::copyin<int64_t>, policy::move)
         .def("copyin_int8", &TensorObj::copyin<int8_t>, policy::move)
         .def("copyin_uint8", &TensorObj::copyin<uint8_t>, policy::move)
+        .def("copyin_float16", &TensorObj::copyin<uint16_t>, policy::move)
         .def("copyout_float", &TensorObj::copyout<float>, policy::move)
         .def("copyout_int32", &TensorObj::copyout<int32_t>, policy::move)
         .def("copyout_int64", &TensorObj::copyout<int64_t>, policy::move)
+        .def("copyout_int8", &TensorObj::copyout<int8_t>, policy::move)
+        .def("copyout_uint8", &TensorObj::copyout<uint8_t>, policy::move)
+        .def("copyout_float16", &TensorObj::copyout<uint16_t>, policy::move)
         .def("has_target", &TensorObj::hasTarget, policy::automatic)
         .def("src", &TensorObj::getSource, policy::move)
         .def("printData", &TensorObj::printData, policy::automatic);
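The new bindings move fp16 buffers across the FFI as raw uint16 code units (`copyin_float16`/`copyout_float16` are instantiated on `uint16_t`). A hedged usage sketch from the Python side, where `obj` stands for a float16 tensor already created through the stub (hypothetical setup, not part of the patch):

```python
# `obj` is assumed to be a float16 tensor obtained from the Python stub.
# copyin_float16 takes the uint16 bit patterns of the halves, not Python floats.
import numpy as np

values = np.array([0.5, 1.5, -2.0], dtype=np.float16)
obj.copyin_float16(values.view(np.uint16).tolist())

bits = np.array(obj.copyout_float16(), dtype=np.uint16)
assert (bits.view(np.float16) == values).all()
```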
diff --git a/test/core/test_graph_handler.cc b/test/core/test_graph_handler.cc
index b5dce89b..c25ce5d2 100644
--- a/test/core/test_graph_handler.cc
+++ b/test/core/test_graph_handler.cc
@@ -7,9 +7,9 @@ namespace infini {
 TEST(Handler, matmul) {
     auto runtime = NativeCpuRuntimeObj::getInstance();
     auto handler = make_ref<GraphHandlerObj>(runtime);
-    auto i = handler->tensor({1, 2, 3}, OnnxDType::UINT32);
-    auto w = handler->tensor({1, 3, 4}, OnnxDType::UINT32);
-    auto o = handler->tensor({1, 2, 4}, OnnxDType::UINT32);
+    auto i = handler->tensor({1, 2, 3}, DataType::UInt32.getIndex());
+    auto w = handler->tensor({1, 3, 4}, DataType::UInt32.getIndex());
+    auto o = handler->tensor({1, 2, 4}, DataType::UInt32.getIndex());
     handler->matmul(i, w, o, false, false, nullptr, ActType::None);
 }