From 9b10a74788ec08bd58cd9bfa42954addc47606d4 Mon Sep 17 00:00:00 2001 From: zhangyunze <93699316+bitzyz@users.noreply.github.com> Date: Wed, 2 Aug 2023 16:38:16 +0800 Subject: [PATCH] Support fp16 dtype (#96) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add conv_half kernel * Conv Kernel FP16 * dcj: replace "DataType::Float32" with "op->getDType()" to support more DataType * feat: support Float16 dtype * fix: set default clang-format to 14 version * fix: revise according to review comments * fix: add data convert to convfp16 kernel test * test: add conv_fp16 kernel test --------- Co-authored-by: zhangyue207 Co-authored-by: kilinchange --- include/core/data_type.h | 100 ++++++--- include/core/graph_handler.h | 24 -- include/core/tensor.h | 27 ++- include/operators/reshape.h | 1 + include/utils/data_convert.h | 11 + include/utils/data_generator.h | 11 + pyinfinitensor/src/pyinfinitensor/onnx.py | 30 ++- pyinfinitensor/tests/test_onnx.py | 21 +- src/bang/bang_runtime.cc | 3 +- src/core/graph_handler.cc | 32 ++- src/core/operator.cc | 2 - src/core/perf_engine.cc | 6 +- src/core/tensor.cc | 44 ++-- src/cuda/cuda_runtime.cc | 6 +- src/ffi/ffi_infinitensor.cc | 44 +++- src/kernels/cuda/conv_half.cc | 261 ++++++++++++++++++++++ src/utils/data_convert.cc | 30 +++ test/core/test_graph_handler.cc | 6 +- test/kernels/cuda/test_cuda_conv_fp16.cc | 78 +++++++ 19 files changed, 626 insertions(+), 111 deletions(-) create mode 100644 include/utils/data_convert.h create mode 100644 src/kernels/cuda/conv_half.cc create mode 100644 src/utils/data_convert.cc create mode 100644 test/kernels/cuda/test_cuda_conv_fp16.cc diff --git a/include/core/data_type.h b/include/core/data_type.h index 878f4bdb..2fb05a07 100644 --- a/include/core/data_type.h +++ b/include/core/data_type.h @@ -4,19 +4,44 @@ namespace infini { class DataType { public: - // legacy - static const DataType Float32; - static const DataType UInt32; - // These are just aligned with the type and index of onnx: // - static const DataType UInt8, Int8, UInt16, Int16, Int32, Int64; - static constexpr size_t sizePerElement[]{ - sizeof(float), sizeof(uint32_t), sizeof(uint8_t), sizeof(int8_t), - sizeof(uint16_t), sizeof(int16_t), sizeof(int32_t), sizeof(int64_t)}; + static const DataType Undefine; + static const DataType Float32; + static const DataType UInt8; + static const DataType Int8; + static const DataType UInt16; + static const DataType Int16; + static const DataType Int32; + static const DataType Int64; + static const DataType String; + static const DataType Bool; + static const DataType Float16; + static const DataType Double; + static const DataType UInt32; + static const DataType UInt64; + // "sizePerElement" records the size of the CPU type each DType maps to, + // e.g. DataType::Bool -> int8_t, DataType::Float16 -> uint16_t + static constexpr size_t sizePerElement[]{0, + sizeof(float), + sizeof(uint8_t), + sizeof(int8_t), + sizeof(uint16_t), + sizeof(int16_t), + sizeof(int32_t), + sizeof(int64_t), + sizeof(std::string), + sizeof(int8_t), + sizeof(uint16_t), + sizeof(double), + sizeof(uint32_t), + sizeof(uint64_t)}; - static constexpr std::string_view names[]{"Float32", "UInt32", "UInt8", - "Int8", "UInt16", "Int16", - "Int32", "Int64"}; + static constexpr std::string_view names[]{ + "Undefine", "Float32", "UInt8", "Int8", "UInt16", + "Int16", "Int32", "Int64", "String", "Bool", + "Float16", "Double", "UInt32", "UInt64"}; + + static constexpr int cpuType[]{-1, 0, 2, 3, 4, 5, 6, 7, -1, 3, 4, 9, 1, 8}; private: int index; @@ 
-29,37 +54,58 @@ class DataType { bool operator==(const DataType &rhs) const { return index == rhs.index; } bool operator<(const DataType &rhs) const { return index < rhs.index; } - template static DataType get() { + template static int get() { IT_TODO_HALT_MSG("Unsupported data type"); } size_t getSize() const { return sizePerElement[index]; } string toString() const { return string(names[index]); } + int cpuTypeInt() const { return cpuType[index]; } + int getIndex() const { return index; } }; -inline const DataType DataType::Float32(0); -inline const DataType DataType::UInt32(1); -inline const DataType DataType::UInt8(2), DataType::Int8(3), - DataType::UInt16(4), DataType::Int16(5), DataType::Int32(6), - DataType::Int64(7); +// to be consistent with onnx +// https://github.com/onnx/onnx/blob/aeb21329122b96df1d3ef33b500a35ca140b1431/onnx/onnx.proto#L484 +inline const DataType DataType::Undefine(0); +inline const DataType DataType::Float32(1); +inline const DataType DataType::UInt8(2); +inline const DataType DataType::Int8(3); +inline const DataType DataType::UInt16(4); +inline const DataType DataType::Int16(5); +inline const DataType DataType::Int32(6); +inline const DataType DataType::Int64(7); +inline const DataType DataType::String(8); +inline const DataType DataType::Bool(9); +inline const DataType DataType::Float16(10); +inline const DataType DataType::Double(11); +inline const DataType DataType::UInt32(12); +inline const DataType DataType::UInt64(13); // Method definitions are out of the declaration due to GCC bug: // https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc -template <> inline DataType DataType::get() { return Float32; } -template <> inline DataType DataType::get() { return UInt32; } -template <> inline DataType DataType::get() { return UInt8; } -template <> inline DataType DataType::get() { return Int8; } -template <> inline DataType DataType::get() { return UInt16; } -template <> inline DataType DataType::get() { return Int16; } -template <> inline DataType DataType::get() { return Int32; } -template <> inline DataType DataType::get() { return Int64; } +template <> inline int DataType::get() { return 0; } +template <> inline int DataType::get() { return 1; } +template <> inline int DataType::get() { return 2; } +template <> inline int DataType::get() { return 3; } +template <> inline int DataType::get() { return 4; } +template <> inline int DataType::get() { return 5; } +template <> inline int DataType::get() { return 6; } +template <> inline int DataType::get() { return 7; } +template <> inline int DataType::get() { return 8; } +template <> inline int DataType::get() { return 9; } template struct DT {}; -template <> struct DT<0> { using t = float; }; -template <> struct DT<1> { using t = uint32_t; }; +template <> struct DT<0> { using t = bool; }; +template <> struct DT<1> { using t = float; }; template <> struct DT<2> { using t = uint8_t; }; template <> struct DT<3> { using t = int8_t; }; template <> struct DT<4> { using t = uint16_t; }; template <> struct DT<5> { using t = int16_t; }; template <> struct DT<6> { using t = int32_t; }; template <> struct DT<7> { using t = int64_t; }; +template <> struct DT<8> { using t = char; }; +template <> struct DT<9> { using t = int8_t; }; +template <> struct DT<10> { using t = uint16_t; }; +template <> struct DT<11> { using t = double; }; +template <> struct DT<12> { using t = uint32_t; }; +template <> struct DT<13> { using t = uint64_t; }; } // namespace infini diff 
--git a/include/core/graph_handler.h b/include/core/graph_handler.h index 2c9b25b4..ac0be526 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -7,30 +7,6 @@ namespace infini { -// Use the indices from onnx to reduce delivery overhead, -// which comes from onnx but may be not only used for onnx. -// -// see https://onnx.ai/onnx/intro/concepts.html#element-type -enum OnnxDType : int { - UNDEFINED = 0, - FLOAT, - UINT8, - INT8, - UINT16, - INT16, - INT32, - INT64, - STRING, - BOOL, - FLOAT16, - DOUBLE, - UINT32, - UINT64, - COMPLEX64, - COMPLEX128, - BFLOAT16, -}; - class GraphHandlerObj { Graph g; diff --git a/include/core/tensor.h b/include/core/tensor.h index a1081e15..d2fad79e 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -1,5 +1,6 @@ #pragma once #include "core/tensor_base.h" +#include "utils/data_convert.h" #include #include @@ -45,20 +46,20 @@ class TensorObj : public TensorBaseObj { // Copy elements from `data`. template void copyin(const vector &data) { - IT_ASSERT(DataType::get() == dtype); + IT_ASSERT(DataType::get() == dtype.cpuTypeInt()); IT_ASSERT(data.size() >= _size); copyin(data.data(), getBytes()); } // Copy all the elements to a vector. template auto copyout() const { - IT_ASSERT(DataType::get() == dtype); + IT_ASSERT(DataType::get() == dtype.cpuTypeInt()); std::vector ans(_size); copyout(ans.data(), getBytes()); return ans; } // Copy the element at `pos`. template auto copyOne(const vector &pos) const { - IT_ASSERT(DataType::get() == dtype); + IT_ASSERT(DataType::get() == dtype.cpuTypeInt()); auto offset = getOffset(pos); auto bytes = dtype.getSize(); T ans; @@ -98,8 +99,12 @@ class TensorObj : public TensorBaseObj { bool equalData(const Tensor &rhs, double relativeError = 1e-6) const; template bool equalData(const vector &dataVector) { - IT_ASSERT(DataType::get() == dtype); IT_ASSERT(size() == dataVector.size()); + if (dtype == DataType::Float16) { + return equalDataImpl_fp16(getRawDataPtr(), + (float *)dataVector.data(), size()); + } + IT_ASSERT(DataType::get() == dtype.cpuTypeInt()); return equalDataImpl(getRawDataPtr(), dataVector.data(), size()); } @@ -156,6 +161,20 @@ class TensorObj : public TensorBaseObj { return true; } + bool equalDataImpl_fp16(const uint16_t *a, const float *b, + size_t size) const { + for (size_t i = 0; i < size; ++i) { + auto a_fp32 = fp16_to_float(a[i]); + auto b_fp32 = b[i]; + if (fabs(a_fp32 - b_fp32) / std::max(fabs(a_fp32), fabs(b_fp32)) > + 1e-6) { + printf("Error on %lu: %f %f\n", i, a_fp32, b_fp32); + return false; + } + } + return true; + } + Shape getPosByOffset(size_t offset, Shape dim) const; size_t getOffsetByPos(Shape pos, Shape dim) const; diff --git a/include/operators/reshape.h b/include/operators/reshape.h index 907bbcbb..00ae5b0a 100644 --- a/include/operators/reshape.h +++ b/include/operators/reshape.h @@ -60,6 +60,7 @@ class FlattenObj : public OperatorObj { std::string toString() const override; int numInputs() const override { return 1; } int numOutputs() const override { return 1; } + int getAxis() const { return axis; } private: vector getWorkloadVector() const override; diff --git a/include/utils/data_convert.h b/include/utils/data_convert.h new file mode 100644 index 00000000..51dd8501 --- /dev/null +++ b/include/utils/data_convert.h @@ -0,0 +1,11 @@ +#pragma once +#include + +namespace infini { +union Uf32 { + float f32; + uint32_t u32; +}; +uint16_t float_to_fp16(const float x); +float fp16_to_float(const uint16_t x); +} // namespace infini diff --git 
a/include/utils/data_generator.h b/include/utils/data_generator.h index 89d8b84c..982db835 100644 --- a/include/utils/data_generator.h +++ b/include/utils/data_generator.h @@ -1,6 +1,7 @@ #pragma once #include "core/common.h" #include "core/tensor_base.h" +#include "utils/data_convert.h" #include namespace infini { @@ -10,6 +11,7 @@ class DataGenerator { private: virtual void fill(uint32_t *data, size_t size) { IT_TODO_HALT(); } virtual void fill(float *data, size_t size) { IT_TODO_HALT(); } + virtual void fill_fp16(uint16_t *data, size_t size) { IT_TODO_HALT(); } public: virtual ~DataGenerator() {} @@ -18,6 +20,8 @@ class DataGenerator { fill(reinterpret_cast(data), size); else if (dataType == DataType::Float32) fill(reinterpret_cast(data), size); + else if (dataType == DataType::Float16) + fill_fp16(reinterpret_cast(data), size); else IT_TODO_HALT(); } @@ -38,6 +42,13 @@ class IncrementalGenerator : public DataGenerator { fill(data, size); } void fill(float *data, size_t size) override { fill(data, size); } + // FIXME: fix the accuracy standards when dtype is float16 + void fill_fp16(uint16_t *data, size_t size) { + for (size_t i = 0; i < size; i++) { + float x = 2.0f; + data[i] = float_to_fp16(x); + } + } }; class RandomGenerator : public DataGenerator { diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index b57390db..6f686b58 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -477,7 +477,10 @@ class OnnxStub: tensors[node.input[0]], tensors.get(node.output[0]), # NOTE(constroy): `axes` is an attribute until opset version 13. - next((attr.ints for attr in node.attribute if attr.name == "axes"), None), + next( + (attr.ints for attr in node.attribute if attr.name == "axes"), + None, + ), next((attr.i for attr in node.attribute if attr.name == "keepdims")) != 0, ) @@ -531,6 +534,12 @@ class OnnxStub: obj.copyin_int64(_parse_data(tensor)) elif tensor.data_type == TensorProto.FLOAT: obj.copyin_float(_parse_data(tensor)) + elif tensor.data_type == TensorProto.BOOL: + obj.copyin_int8(_parse_data(tensor)) + elif tensor.data_type == TensorProto.FLOAT16: + obj.copyin_float16(_parse_data_fp16(tensor)) + elif tensor.data_type == TensorProto.INT8: + obj.copyin_uint8(_parse_data(tensor)) else: assert False, "Unsupported Tensor Type: {}".format(tensor.data_type) @@ -730,7 +739,8 @@ class OnnxStub: ]: ctx.push_node(make_node(ty.name, inputs, outputs, name)) elif ty == backend.OpType.Flatten: - raise Exception("TODO") + axis = backend.flatten_axis_of(op) + ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis)) elif ty == backend.OpType.Transpose: perm = backend.transpose_permute_of(op) ctx.push_node(make_node(ty.name, inputs, outputs, name, perm=perm)) @@ -894,5 +904,21 @@ def _parse_data(tensor: TensorProto) -> List[Any]: return to_array(tensor).flatten().tolist() +def _parse_data_fp16(tensor: TensorProto): + list_ = [] + if len(tensor.int32_data) != 0: + for element_data in tensor.int32_data: + element_byte = element_data.to_bytes(2, "little") + list_.append(element_byte[0] + element_byte[1] * 256) + elif len(tensor.raw_data) != 0: + list_raw_data = list(tensor.raw_data) + list_data = [list_raw_data[i : i + 2] for i in range(0, len(list_raw_data), 2)] + for ele in list_data: + list_.append(ele[0] + ele[1] * 256) + else: + raise Exception("Tensor have no float16 data!") + return list_ + + def _take_shape_dim(shape: TensorShapeProto) -> List[int]: return [(d.dim_value if d.dim_value > 0 
else 1) for d in shape.dim] diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index fd589eeb..166d0df2 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -64,6 +64,21 @@ class TestStringMethods(unittest.TestCase): ) make_and_import_model(make_graph([conv], "conv", [i, w], [o])) + def test_conv_fp16(self): + i = make_tensor_value_info("i", TensorProto.FLOAT16, [1, 3, 4, 4]) + w = make_tensor_value_info("w", TensorProto.FLOAT16, [2, 3, 3, 3]) + o = make_tensor_value_info("o", TensorProto.FLOAT16, [1, 2, 2, 2]) + conv = make_node( + "Conv", + ["i", "w"], + ["o"], + "conv", + pads=[1, 1, 1, 1], + strides=[2, 1], + dilations=[1, 2], + ) + make_and_import_model(make_graph([conv], "conv_fp16", [i, w], [o])) + def test_matmul(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4]) @@ -211,9 +226,9 @@ class TestStringMethods(unittest.TestCase): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) y = make_tensor_value_info("y", TensorProto.FLOAT, [1 * 3, 5 * 7]) flatten = make_node("Flatten", ["x"], ["y"], axis=2, name="flatten") - # make_and_import_model( - make_graph([flatten], "flatten", [x], [y]) - # ) + make_and_import_model( + make_graph([flatten], "flatten", [x], [y]) + ) def test_reshape(self): data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5]) diff --git a/src/bang/bang_runtime.cc b/src/bang/bang_runtime.cc index 8f71f1b6..66e2b9b0 100644 --- a/src/bang/bang_runtime.cc +++ b/src/bang/bang_runtime.cc @@ -13,8 +13,7 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false, std::map opCnt; for (auto &op : graph->getOperators()) { // HACK: set correct data type - auto kernelAttrs = - KernelAttrs{device, op->getOpType(), DataType::Float32}; + auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 61db36ac..d032edd6 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -292,23 +292,35 @@ Tensor GraphHandlerObj::pad(Tensor input, Tensor output, } static DataType dtype_repr_convert(int dtype) { - switch ((OnnxDType)dtype) { - case OnnxDType::FLOAT: + switch (dtype) { + case 0: + return DataType::Undefine; + case 1: return DataType::Float32; - case OnnxDType::UINT32: - return DataType::UInt32; - case OnnxDType::UINT8: + case 2: return DataType::UInt8; - case OnnxDType::INT8: + case 3: return DataType::Int8; - case OnnxDType::UINT16: + case 4: return DataType::UInt16; - case OnnxDType::INT16: + case 5: return DataType::Int16; - case OnnxDType::INT32: + case 6: return DataType::Int32; - case OnnxDType::INT64: + case 7: return DataType::Int64; + case 8: + return DataType::String; + case 9: + return DataType::Bool; + case 10: + return DataType::Float16; + case 11: + return DataType::Double; + case 12: + return DataType::UInt32; + case 13: + return DataType::UInt64; default: IT_ASSERT(false, "Unsupported data type"); } diff --git a/src/core/operator.cc b/src/core/operator.cc index cea23321..743c49bd 100644 --- a/src/core/operator.cc +++ b/src/core/operator.cc @@ -110,8 +110,6 @@ optional> OperatorObj::inferShape() const { vector OperatorObj::inferDataType(const TensorVec &inputs) const { auto dataType = inputs[0]->getDType(); - for 
(const auto &tensor : inputs) - IT_ASSERT(dataType == tensor->getDType()); return vector(numOutputs(), dataType); } diff --git a/src/core/perf_engine.cc b/src/core/perf_engine.cc index ecf97f66..1a838b3b 100644 --- a/src/core/perf_engine.cc +++ b/src/core/perf_engine.cc @@ -30,9 +30,7 @@ void from_json(const json &j, OpPerfKey &p) { j.at("opType").get_to(p.opType); j.at("attrs").get_to(p.attrs); } -void to_json(json &j, const DataType &p) { - j = p.toString() == "Float32" ? 0 : 1; -} +void to_json(json &j, const DataType &p) { j = p.getIndex(); } void from_json(const json &j, DataType &p) { p = DataType(j.get()); } void to_json(json &j, const PerfRecord &p) { p->to_json(j); } void from_json(const json &j, PerfRecord &p) { @@ -49,4 +47,4 @@ void from_json(const json &j, PerfEngine &p) { p.set_data(tmp); } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/core/tensor.cc b/src/core/tensor.cc index e63039d5..c80ff8f7 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -71,14 +71,20 @@ void TensorObj::printData() const { if (dtype == DataType(N)) \ std::cout << dataToString::t>() << std::endl; - TRY_PRINT(0) // fmt: new line - else TRY_PRINT(1) // - else TRY_PRINT(2) // - else TRY_PRINT(3) // - else TRY_PRINT(4) // - else TRY_PRINT(5) // - else TRY_PRINT(6) // - else TRY_PRINT(7) // + TRY_PRINT(0) // fmt: new line + else TRY_PRINT(1) // + else TRY_PRINT(2) // + else TRY_PRINT(3) // + else TRY_PRINT(4) // + else TRY_PRINT(5) // + else TRY_PRINT(6) // + else TRY_PRINT(7) // + else TRY_PRINT(8) // + else TRY_PRINT(9) // + else TRY_PRINT(10) // + else TRY_PRINT(11) // + else TRY_PRINT(12) // + else TRY_PRINT(13) // else IT_TODO_HALT(); #undef TRY_PRINT @@ -98,14 +104,20 @@ bool TensorObj::equalData(const Tensor &rhs, double relativeError) const { return equalDataImpl(getRawDataPtr::t *>(), \ rhs->getRawDataPtr::t *>(), size()); - TEST_EQUAL(0) // fmt: new line - else TEST_EQUAL(1) // - else TEST_EQUAL(2) // - else TEST_EQUAL(3) // - else TEST_EQUAL(4) // - else TEST_EQUAL(5) // - else TEST_EQUAL(6) // - else TEST_EQUAL(7) // + TEST_EQUAL(0) // fmt: new line + else TEST_EQUAL(1) // + else TEST_EQUAL(2) // + else TEST_EQUAL(3) // + else TEST_EQUAL(4) // + else TEST_EQUAL(5) // + else TEST_EQUAL(6) // + else TEST_EQUAL(7) // + else TEST_EQUAL(8) // + else TEST_EQUAL(9) // + else TEST_EQUAL(10) // + else TEST_EQUAL(11) // + else TEST_EQUAL(12) // + else TEST_EQUAL(13) // else IT_TODO_HALT(); #undef TEST_EQUAL diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc index 37b5e7cf..06a53c4e 100644 --- a/src/cuda/cuda_runtime.cc +++ b/src/cuda/cuda_runtime.cc @@ -11,8 +11,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const { auto &perfEngine = PerfEngine::getInstance(); for (auto &op : graph->getOperators()) { // HACK: set correct data type - auto kernelAttrs = - KernelAttrs{device, op->getOpType(), DataType::Float32}; + auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); @@ -33,8 +32,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const { std::map opCnt; for (auto &op : graph->getOperators()) { // HACK: set correct data type - auto kernelAttrs = - KernelAttrs{device, op->getOpType(), DataType::Float32}; + auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()}; Kernel *kernel = 
kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; auto perfData = perfEngine.getPerfData(perfKey); diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 1da7203a..9ca5d995 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -100,22 +100,34 @@ void export_values(py::module &m) { } static int tensor_dtype(Tensor t) { + if (t->getDType() == DataType::Undefine) + return 0; if (t->getDType() == DataType::Float32) - return OnnxDType::FLOAT; - if (t->getDType() == DataType::UInt32) - return OnnxDType::UINT32; + return 1; if (t->getDType() == DataType::UInt8) - return OnnxDType::UINT8; + return 2; if (t->getDType() == DataType::Int8) - return OnnxDType::INT8; + return 3; if (t->getDType() == DataType::UInt16) - return OnnxDType::UINT16; + return 4; if (t->getDType() == DataType::Int16) - return OnnxDType::INT16; + return 5; if (t->getDType() == DataType::Int32) - return OnnxDType::INT32; + return 6; if (t->getDType() == DataType::Int64) - return OnnxDType::INT64; + return 7; + if (t->getDType() == DataType::String) + return 8; + if (t->getDType() == DataType::Bool) + return 9; + if (t->getDType() == DataType::Float16) + return 10; + if (t->getDType() == DataType::Double) + return 11; + if (t->getDType() == DataType::UInt32) + return 12; + if (t->getDType() == DataType::UInt64) + return 13; IT_ASSERT(false, "Unsupported data type"); } @@ -224,6 +236,11 @@ static vector transpose_permute_of(Operator op) { return dynamic_cast(op.get())->getPermute(); } +static int flatten_axis_of(Operator op) { + IT_ASSERT(op->getOpType() == OpType::Flatten); + return dynamic_cast(op.get())->getAxis(); +} + void export_functions(py::module &m) { #define FUNCTION(NAME) def(#NAME, &NAME) m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance) @@ -252,7 +269,8 @@ void export_functions(py::module &m) { .FUNCTION(transpose_permute_of) .FUNCTION(concat_axis_of) .FUNCTION(split_axis_of) - .FUNCTION(gather_axis_of); + .FUNCTION(gather_axis_of) + .FUNCTION(flatten_axis_of); #undef FUNCTION } @@ -276,9 +294,15 @@ void init_graph_builder(py::module &m) { .def("copyin_float", &TensorObj::copyin, policy::move) .def("copyin_int32", &TensorObj::copyin, policy::move) .def("copyin_int64", &TensorObj::copyin, policy::move) + .def("copyin_int8", &TensorObj::copyin, policy::move) + .def("copyin_uint8", &TensorObj::copyin, policy::move) + .def("copyin_float16", &TensorObj::copyin, policy::move) .def("copyout_float", &TensorObj::copyout, policy::move) .def("copyout_int32", &TensorObj::copyout, policy::move) .def("copyout_int64", &TensorObj::copyout, policy::move) + .def("copyout_int8", &TensorObj::copyout, policy::move) + .def("copyout_uint8", &TensorObj::copyout, policy::move) + .def("copyout_float16", &TensorObj::copyout, policy::move) .def("has_target", &TensorObj::hasTarget, policy::automatic) .def("src", &TensorObj::getSource, policy::move) .def("printData", &TensorObj::printData, policy::automatic); diff --git a/src/kernels/cuda/conv_half.cc b/src/kernels/cuda/conv_half.cc new file mode 100644 index 00000000..1f83b484 --- /dev/null +++ b/src/kernels/cuda/conv_half.cc @@ -0,0 +1,261 @@ +#include "core/kernel.h" +#include "cuda/cuda_runtime.h" +#include "operators/conv.h" +#include +#include +#include +#include + +namespace infini { + +struct ConvCuDnnPerfRecordObj : public PerfRecordObj { + int algo = 0; // cudnnConvolutionFwdAlgo_t + int mode = 1; + size_t workspaceSize = 100000; + bool fuseAct = false; + void to_json(json &j) 
override { + j["type"] = 1; + j["data"] = std::make_tuple(algo, mode, fuseAct, time, workspaceSize); + } + static PerfRecord from_json(const json &j) { + ConvCuDnnPerfRecordObj tmp; + auto [Algo, Mode, FuseAct, Time, WorkspaceSize] = + j["data"].get>(); + tmp.algo = Algo; + tmp.mode = Mode; + tmp.fuseAct = FuseAct; + tmp.time = Time; + tmp.workspaceSize = WorkspaceSize; + return make_ref(tmp); + } +}; + +using ConvCuDnnPerfRecord = Ref; + +class convCudnnFP16 : public Kernel { + + static constexpr int N_ALGO = 8; + static constexpr int N_MODE = 2; + static constexpr cudnnConvolutionFwdAlgo_t ALGOS[8] = { + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED}; + + static constexpr cudnnConvolutionMode_t MODES[2] = { + CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION}; + + std::tuple + createCuDNNDescriptor(const Ref &op, + const ConvCuDnnPerfRecord &record) const { + void *const inData = (op->getInputs(0)->getRawDataPtr()); + void *const knData = (op->getInputs(1)->getRawDataPtr()); + // Bias is not supported yet + if (op->getInputs().size() > 2) { + IT_TODO_HALT(); + } + // void *const biasData = (op->getInputs(2)->getRawDataPtr()); + void *const outData = (op->getOutput()->getRawDataPtr()); + + const auto [n, c, h, w, f, r, s] = op->getNCHWFRS(); + const int cpg = op->getChannelPerGroup(); + const int g = c / cpg; + const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + + int channelsPerGrp = cpg, channels = c; + + // get inputs + cudnnTensorDescriptor_t inDesc; + checkCudnnError(cudnnCreateTensorDescriptor(&inDesc)); + checkCudnnError(cudnnSetTensor4dDescriptor(inDesc, CUDNN_TENSOR_NCHW, + CUDNN_DATA_HALF, n, channels, + h, w)); /*fp16 type*/ + + // get kernels + cudnnFilterDescriptor_t knDesc; + checkCudnnError(cudnnCreateFilterDescriptor(&knDesc)); + checkCudnnError(cudnnSetFilter4dDescriptor( + knDesc, CUDNN_DATA_HALF, /*fp16 type*/ + CUDNN_TENSOR_NCHW, f, channelsPerGrp, r, s)); + // get bias + cudnnTensorDescriptor_t biasDesc; + checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc)); + checkCudnnError(cudnnSetTensor4dDescriptor(biasDesc, CUDNN_TENSOR_NCHW, + CUDNN_DATA_HALF, 1, f, 1, + 1)); /*fp16 type*/ + + // get convolution descriptor + cudnnConvolutionDescriptor_t convDesc; + checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc)); + // TODO: CUDNN_CONVOLUTION is a tunable argument + checkCudnnError(cudnnSetConvolution2dDescriptor( + convDesc, ph, pw, sh, sw, dh, dw, MODES[record->mode], + CUDNN_DATA_HALF)); /*fp16 type*/ + if (g > 1) { + checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g)); + } + + // get activation descriptor + cudnnActivationDescriptor_t actDesc; + checkCudnnError(cudnnCreateActivationDescriptor(&actDesc)); + // NOT_PROPAGATE_NAN is requierd by + // cudnnConvolotionBiasActivationForward + switch (op->getAct()) { + case ActType::Relu: + checkCudnnError(cudnnSetActivationDescriptor( + actDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0)); + break; + case ActType::Sigmoid: + checkCudnnError(cudnnSetActivationDescriptor( + actDesc, CUDNN_ACTIVATION_SIGMOID, CUDNN_NOT_PROPAGATE_NAN, 0)); + break; + case ActType::None: + checkCudnnError( + cudnnSetActivationDescriptor(actDesc, CUDNN_ACTIVATION_IDENTITY, + CUDNN_NOT_PROPAGATE_NAN, 0)); + break; + default: + 
assert(false); + } + + // get output descriptor + int outn, outc, outh, outw; + checkCudnnError(cudnnGetConvolution2dForwardOutputDim( + convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw)); + cudnnTensorDescriptor_t outDesc; + checkCudnnError(cudnnCreateTensorDescriptor(&outDesc)); + checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW, + CUDNN_DATA_HALF, outn, outc, + outh, outw)); + IT_ASSERT((vector{outn, outc, outh, outw}) == + op->getOutput()->getDims(), + "cuDNN output shape mismatches with OP output shape"); + + return tuple(inData, knData, outData, inDesc, knDesc, biasDesc, + convDesc, actDesc, outDesc); + } + + bool cuDNNUnfused(const Ref &op, const ConvCuDnnPerfRecord &record, + const CudaRuntimeObj *context) const { + cudnnStatus_t stat; + + const auto &[inData, knData, outData, inDesc, knDesc, biasDesc, + convDesc, actDesc, outDesc] = + createCuDNNDescriptor(op, record); + size_t wsSize = record->workspaceSize; + CudaPtr wsData = context->getWorkspace(wsSize); + float alpha = 1.f, beta = 0.f; + + stat = cudnnConvolutionForward(context->cudnnHandle(), &alpha, inDesc, + inData, knDesc, knData, convDesc, + ALGOS[record->algo], wsData, wsSize, + &beta, outDesc, outData); + if (stat != CUDNN_STATUS_SUCCESS) { + return false; + } + checkCudnnError(cudnnDestroyTensorDescriptor(outDesc)); + checkCudnnError(cudnnDestroyActivationDescriptor(actDesc)); + checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc)); + checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc)); + checkCudnnError(cudnnDestroyFilterDescriptor(knDesc)); + checkCudnnError(cudnnDestroyTensorDescriptor(inDesc)); + return true; + } + + void compute(const Operator &op, const RuntimeObj *context) const override { + auto record = make_ref(); // with paramters in + // default ctor + compute(op, record, context); + } + + PerfRecord tune(const Operator &_op, + const RuntimeObj *_context) const override { + ConvCuDnnPerfRecordObj ret; + ret.time = std::numeric_limits::max(); + auto context = dynamic_cast(_context); + auto op = as(_op); + // Both modes have the same performance. Only run cross-correlation. 
+ for (int mode = 1; mode < 2; mode++) { + // Try every possible algorithm of convolution + for (int algo = 0; algo < N_ALGO; algo++) { + auto recordRef = make_ref(); + auto &record = *recordRef; + record.mode = mode; + record.algo = algo; + cudnnStatus_t stat; + const auto &[inData, knData, outData, inDesc, knDesc, biasDesc, + convDesc, actDesc, outDesc] = + createCuDNNDescriptor(op, recordRef); + + // get workspace + stat = cudnnGetConvolutionForwardWorkspaceSize( + context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc, + ALGOS[record.algo], &record.workspaceSize); + if (stat != CUDNN_STATUS_SUCCESS) { + continue; + } + if (record.workspaceSize > context->getWorkspaceSize()) { + continue; + } + CudaPtr wsData = context->getWorkspace(record.workspaceSize); + float alpha = 1.f, beta = 0.f; + + stat = cudnnConvolutionForward( + context->cudnnHandle(), &alpha, inDesc, inData, knDesc, + knData, convDesc, ALGOS[record.algo], wsData, + record.workspaceSize, &beta, outDesc, outData); + if (stat != CUDNN_STATUS_SUCCESS) { + continue; + } + record.time = timeit( + [&]() { + cudnnConvolutionForward(context->cudnnHandle(), &alpha, + inDesc, inData, knDesc, knData, + convDesc, ALGOS[record.algo], + wsData, record.workspaceSize, + &beta, outDesc, outData); + }, + [&]() { context->sync(); }); + // printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time); + + // Update the tune result + if (ret.time > record.time) { + ret = record; + } + checkCudnnError(cudnnDestroyTensorDescriptor(outDesc)); + checkCudnnError(cudnnDestroyActivationDescriptor(actDesc)); + checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc)); + checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc)); + checkCudnnError(cudnnDestroyFilterDescriptor(knDesc)); + checkCudnnError(cudnnDestroyTensorDescriptor(inDesc)); + } + } + // printf("the best algo is %d, the best conv mode is %d\n", ret.algo, + // ret.mode); + IT_ASSERT(ret.time < std::numeric_limits::max(), "No valid " + "algorithm " + "found"); + return make_ref(ret); + } + + void compute(const Operator &_op, const PerfRecord &_record, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto record = as(_record); + auto context = dynamic_cast(_context); + bool success = cuDNNUnfused(op, record, context); + IT_ASSERT(success); + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float16, convCudnnFP16, + "Conv_cuDNN_CUDA_Float16"); + +} // namespace infini diff --git a/src/utils/data_convert.cc b/src/utils/data_convert.cc new file mode 100644 index 00000000..28b0c923 --- /dev/null +++ b/src/utils/data_convert.cc @@ -0,0 +1,30 @@ +#include "utils/data_convert.h" + +namespace infini { + +uint16_t float_to_fp16(const float x) { + Uf32 u; + u.f32 = x; + const uint32_t b = u.u32 + 0x00001000; + const uint32_t e = (b & 0x7F800000) >> 23; + const uint32_t m = b & 0x007FFFFF; + return (b & 0x80000000) >> 16 | + (e > 112) * ((((e - 112) << 10) & 0x7C00) | m >> 13) | + ((e < 113) & (e > 101)) * + ((((0x007FF000 + m) >> (125 - e)) + 1) >> 1) | + (e > 143) * 0x7FFF; +} + +float fp16_to_float(const uint16_t x) { + Uf32 u; + const uint32_t e = (x & 0x7C00) >> 10; + const uint32_t m = (x & 0x03FF) << 13; + u.f32 = (float)m; + const uint32_t v = u.u32 >> 23; + const uint32_t r = (x & 0x8000) << 16 | (e != 0) * ((e + 112) << 23 | m) | + ((e == 0) & (m != 0)) * + ((v - 37) << 23 | ((m << (150 - v)) & 0x007FE000)); + u.u32 = r; + return u.f32; +} +} // namespace infini diff --git a/test/core/test_graph_handler.cc b/test/core/test_graph_handler.cc index 
b5dce89b..c25ce5d2 100644 --- a/test/core/test_graph_handler.cc +++ b/test/core/test_graph_handler.cc @@ -7,9 +7,9 @@ namespace infini { TEST(Handler, matmul) { auto runtime = NativeCpuRuntimeObj::getInstance(); auto handler = make_ref(runtime); - auto i = handler->tensor({1, 2, 3}, OnnxDType::UINT32); - auto w = handler->tensor({1, 3, 4}, OnnxDType::UINT32); - auto o = handler->tensor({1, 2, 4}, OnnxDType::UINT32); + auto i = handler->tensor({1, 2, 3}, DataType::UInt32.getIndex()); + auto w = handler->tensor({1, 3, 4}, DataType::UInt32.getIndex()); + auto o = handler->tensor({1, 2, 4}, DataType::UInt32.getIndex()); handler->matmul(i, w, o, false, false, nullptr, ActType::None); } diff --git a/test/kernels/cuda/test_cuda_conv_fp16.cc b/test/kernels/cuda/test_cuda_conv_fp16.cc new file mode 100644 index 00000000..994e2dee --- /dev/null +++ b/test/kernels/cuda/test_cuda_conv_fp16.cc @@ -0,0 +1,78 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/conv.h" +#include + +#include "test.h" + +namespace infini { + +void testConvCudnnFP16( + const std::function &generator, + vector ansVec) { + + // Construct Runtime and graph for CPU and CUDA + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime cuda = make_ref(); + Graph gCuda = make_ref(cuda); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({1, 3, 4, 4}, DataType::Float16); + Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float16); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(generator); + w0Cpu->setData(generator); + + // Copy input tensors from CPU to CUDA + Tensor i0Cuda = gCuda->cloneTensor(i0Cpu); + Tensor w0Cuda = gCuda->cloneTensor(w0Cpu); + // Build CUDA graph + auto conv = + gCuda->addOp(i0Cuda, w0Cuda, nullptr, 1, 1, 2, 1, 1, 2); + // allocate CUDA memory + gCuda->dataMalloc(); + // Execute on CUDA + cuda->run(gCuda); + // copy output from CUDA to CPU + auto o0Cpu = gCpu->cloneTensor(conv->getOutput()); + // check results on CPU + EXPECT_TRUE(o0Cpu->equalData(ansVec)); + // print a tensor/operator/graph by print() + gCuda->print(); +} + +TEST(cuDNN_Conv_FP16, run) { + testConvCudnnFP16(IncrementalGenerator(), + vector{48, 48, 72, 72, 48, 48, 72, 72}); +} + +TEST(cuDNN_Conv_FP16, tune) { + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime cuda = make_ref(); + Graph gCuda = make_ref(cuda); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({1, 3, 224, 224}, DataType::Float16); + Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float16); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(IncrementalGenerator()); + w0Cpu->setData(IncrementalGenerator()); + + // Copy input tensors from CPU to CUDA + Tensor i0Cuda = gCuda->cloneTensor(i0Cpu); + Tensor w0Cuda = gCuda->cloneTensor(w0Cpu); + // Build CUDA graph + auto conv = + gCuda->addOp(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1); + // allocate CUDA memory + gCuda->dataMalloc(); + // Execute on CUDA + bool tune = true; + cuda->run(gCuda, tune); +} +} // namespace infini
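
Note: the bit-level float <-> fp16 conversion added in src/utils/data_convert.cc is easiest to sanity-check with a small round trip. The snippet below is a minimal sketch, not part of the patch; it only assumes the declarations in include/utils/data_convert.h and that values exactly representable in fp16 (such as 2.0f, the constant used by IncrementalGenerator::fill_fp16) survive the round trip unchanged, while other values pick up normal fp16 quantization error.

#include "utils/data_convert.h"
#include <cstdint>
#include <cstdio>

int main() {
    // Values chosen for illustration only: 2.0f matches fill_fp16's constant,
    // 65504.0f is the largest normal fp16 value, 3.14159f shows rounding.
    const float samples[] = {1.0f, 2.0f, 3.14159f, 65504.0f, -2.5f};
    for (float x : samples) {
        uint16_t h = infini::float_to_fp16(x);
        float back = infini::fp16_to_float(h);
        std::printf("%g -> 0x%04x -> %g\n", x, h, back);
    }
    return 0;
}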
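
Note: with the new copyin/copyout assertions comparing DataType::get<T>() against dtype.cpuTypeInt(), a Float16 tensor is populated with raw uint16_t words rather than floats, which is also what the copyin_float16 binding and _parse_data_fp16 do on the Python side. The sketch below shows one way to do that from C++; fillFp16 is a hypothetical helper, not an API added by this patch, and the includes mirror the ones used in the new tests.

#include "core/graph.h"
#include "core/runtime.h"
#include "utils/data_convert.h"
#include <cstdint>
#include <vector>

using namespace infini;

// Hypothetical helper: convert float values to fp16 words and copy them into
// a DataType::Float16 tensor, which passes the cpuTypeInt()-based dtype check.
void fillFp16(Tensor t, const std::vector<float> &values) {
    std::vector<uint16_t> raw(values.size());
    for (size_t i = 0; i < values.size(); ++i)
        raw[i] = float_to_fp16(values[i]);
    t->copyin(raw);
}

int main() {
    Runtime cpu = NativeCpuRuntimeObj::getInstance();
    Graph g = make_ref<GraphObj>(cpu);
    Tensor t = g->addTensor({2, 2}, DataType::Float16);
    g->dataMalloc();
    fillFp16(t, {1.0f, 2.0f, 3.0f, 4.0f});
    t->printData(); // prints the raw 16-bit words, since DT<10>::t is uint16_t
    return 0;
}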