支持fp16 dtype (#96)

* add conv_half kernel

* Conv Kernel FP16

* dcj:
replace "DataType::Float32" with "op->getDType()" to support more DataType

* feat: support Float16 dtype

* fix: set default clang-format to 14 version

* fix: 按照review意见修改

* fix: add data convert to convfp16 kernel test

* test: add conv_fp16 kernel test

---------

Co-authored-by: zhangyue207 <zhangyue@qiyuanlab.com>
Co-authored-by: kilinchange <kilinchange@163.com>
This commit is contained in:
zhangyunze 2023-08-02 16:38:16 +08:00 committed by GitHub
parent 1dc65e2788
commit 9b10a74788
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 626 additions and 111 deletions

View File

@ -4,19 +4,44 @@ namespace infini {
class DataType {
public:
// legacy
static const DataType Float32;
static const DataType UInt32;
// These are just aligned with the type and index of onnx:
// <https://onnx.ai/onnx/intro/concepts.html#element-type>
static const DataType UInt8, Int8, UInt16, Int16, Int32, Int64;
static constexpr size_t sizePerElement[]{
sizeof(float), sizeof(uint32_t), sizeof(uint8_t), sizeof(int8_t),
sizeof(uint16_t), sizeof(int16_t), sizeof(int32_t), sizeof(int64_t)};
static const DataType Undefine;
static const DataType Float32;
static const DataType UInt8;
static const DataType Int8;
static const DataType UInt16;
static const DataType Int16;
static const DataType Int32;
static const DataType Int64;
static const DataType String;
static const DataType Bool;
static const DataType Float16;
static const DataType Double;
static const DataType UInt32;
static const DataType UInt64;
// "sizePerElement" show the DType to cpu_type
// DataType::Bool -> int8_t DataType::Float16 -> uint16_t
static constexpr size_t sizePerElement[]{0,
sizeof(float),
sizeof(uint8_t),
sizeof(int8_t),
sizeof(uint16_t),
sizeof(int16_t),
sizeof(int32_t),
sizeof(int64_t),
sizeof(std::string),
sizeof(int8_t),
sizeof(uint16_t),
sizeof(double),
sizeof(uint32_t),
sizeof(uint64_t)};
static constexpr std::string_view names[]{"Float32", "UInt32", "UInt8",
"Int8", "UInt16", "Int16",
"Int32", "Int64"};
static constexpr std::string_view names[]{
"Undefine", "Float32", "UInt8", "Int8", "UInt16",
"Int16", "Int32", "Int64", "String", "Bool",
"Float16", "Double", "UInt32", "UInt64"};
static constexpr int cpuType[]{-1, 0, 2, 3, 4, 5, 6, 7, -1, 3, 4, 9, 1, 8};
private:
int index;
@ -29,37 +54,58 @@ class DataType {
bool operator==(const DataType &rhs) const { return index == rhs.index; }
bool operator<(const DataType &rhs) const { return index < rhs.index; }
template <typename T> static DataType get() {
template <typename T> static int get() {
IT_TODO_HALT_MSG("Unsupported data type");
}
size_t getSize() const { return sizePerElement[index]; }
string toString() const { return string(names[index]); }
int cpuTypeInt() const { return cpuType[index]; }
int getIndex() const { return index; }
};
inline const DataType DataType::Float32(0);
inline const DataType DataType::UInt32(1);
inline const DataType DataType::UInt8(2), DataType::Int8(3),
DataType::UInt16(4), DataType::Int16(5), DataType::Int32(6),
DataType::Int64(7);
// to be consistent with onnx
// https://github.com/onnx/onnx/blob/aeb21329122b96df1d3ef33b500a35ca140b1431/onnx/onnx.proto#L484
inline const DataType DataType::Undefine(0);
inline const DataType DataType::Float32(1);
inline const DataType DataType::UInt8(2);
inline const DataType DataType::Int8(3);
inline const DataType DataType::UInt16(4);
inline const DataType DataType::Int16(5);
inline const DataType DataType::Int32(6);
inline const DataType DataType::Int64(7);
inline const DataType DataType::String(8);
inline const DataType DataType::Bool(9);
inline const DataType DataType::Float16(10);
inline const DataType DataType::Double(11);
inline const DataType DataType::UInt32(12);
inline const DataType DataType::UInt64(13);
// Method definitions are out of the declaration due to GCC bug:
// https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
template <> inline DataType DataType::get<float>() { return Float32; }
template <> inline DataType DataType::get<uint32_t>() { return UInt32; }
template <> inline DataType DataType::get<uint8_t>() { return UInt8; }
template <> inline DataType DataType::get<int8_t>() { return Int8; }
template <> inline DataType DataType::get<uint16_t>() { return UInt16; }
template <> inline DataType DataType::get<int16_t>() { return Int16; }
template <> inline DataType DataType::get<int32_t>() { return Int32; }
template <> inline DataType DataType::get<int64_t>() { return Int64; }
template <> inline int DataType::get<float>() { return 0; }
template <> inline int DataType::get<uint32_t>() { return 1; }
template <> inline int DataType::get<uint8_t>() { return 2; }
template <> inline int DataType::get<int8_t>() { return 3; }
template <> inline int DataType::get<uint16_t>() { return 4; }
template <> inline int DataType::get<int16_t>() { return 5; }
template <> inline int DataType::get<int32_t>() { return 6; }
template <> inline int DataType::get<int64_t>() { return 7; }
template <> inline int DataType::get<uint64_t>() { return 8; }
template <> inline int DataType::get<double>() { return 9; }
template <int index> struct DT {};
template <> struct DT<0> { using t = float; };
template <> struct DT<1> { using t = uint32_t; };
template <> struct DT<0> { using t = bool; };
template <> struct DT<1> { using t = float; };
template <> struct DT<2> { using t = uint8_t; };
template <> struct DT<3> { using t = int8_t; };
template <> struct DT<4> { using t = uint16_t; };
template <> struct DT<5> { using t = int16_t; };
template <> struct DT<6> { using t = int32_t; };
template <> struct DT<7> { using t = int64_t; };
template <> struct DT<8> { using t = char; };
template <> struct DT<9> { using t = int8_t; };
template <> struct DT<10> { using t = uint16_t; };
template <> struct DT<11> { using t = double; };
template <> struct DT<12> { using t = uint32_t; };
template <> struct DT<13> { using t = uint64_t; };
} // namespace infini

View File

@ -7,30 +7,6 @@
namespace infini {
// Use the indices from onnx to reduce delivery overhead,
// which comes from onnx but may be not only used for onnx.
//
// see https://onnx.ai/onnx/intro/concepts.html#element-type
enum OnnxDType : int {
UNDEFINED = 0,
FLOAT,
UINT8,
INT8,
UINT16,
INT16,
INT32,
INT64,
STRING,
BOOL,
FLOAT16,
DOUBLE,
UINT32,
UINT64,
COMPLEX64,
COMPLEX128,
BFLOAT16,
};
class GraphHandlerObj {
Graph g;

View File

@ -1,5 +1,6 @@
#pragma once
#include "core/tensor_base.h"
#include "utils/data_convert.h"
#include <cmath>
#include <cstring>
@ -45,20 +46,20 @@ class TensorObj : public TensorBaseObj {
// Copy elements from `data`.
template <typename T> void copyin(const vector<T> &data) {
IT_ASSERT(DataType::get<T>() == dtype);
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
IT_ASSERT(data.size() >= _size);
copyin(data.data(), getBytes());
}
// Copy all the elements to a vector.
template <typename T> auto copyout() const {
IT_ASSERT(DataType::get<T>() == dtype);
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
std::vector<T> ans(_size);
copyout(ans.data(), getBytes());
return ans;
}
// Copy the element at `pos`.
template <typename T> auto copyOne(const vector<int> &pos) const {
IT_ASSERT(DataType::get<T>() == dtype);
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
auto offset = getOffset(pos);
auto bytes = dtype.getSize();
T ans;
@ -98,8 +99,12 @@ class TensorObj : public TensorBaseObj {
bool equalData(const Tensor &rhs, double relativeError = 1e-6) const;
template <typename T> bool equalData(const vector<T> &dataVector) {
IT_ASSERT(DataType::get<T>() == dtype);
IT_ASSERT(size() == dataVector.size());
if (dtype == DataType::Float16) {
return equalDataImpl_fp16(getRawDataPtr<uint16_t *>(),
(float *)dataVector.data(), size());
}
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
return equalDataImpl(getRawDataPtr<T *>(), dataVector.data(), size());
}
@ -156,6 +161,20 @@ class TensorObj : public TensorBaseObj {
return true;
}
bool equalDataImpl_fp16(const uint16_t *a, const float *b,
size_t size) const {
for (size_t i = 0; i < size; ++i) {
auto a_fp32 = fp16_to_float(a[i]);
auto b_fp32 = b[i];
if (fabs(a_fp32 - b_fp32) / std::max(fabs(a_fp32), fabs(b_fp32)) >
1e-6) {
printf("Error on %lu: %f %f\n", i, a_fp32, b_fp32);
return false;
}
}
return true;
}
Shape getPosByOffset(size_t offset, Shape dim) const;
size_t getOffsetByPos(Shape pos, Shape dim) const;

View File

@ -60,6 +60,7 @@ class FlattenObj : public OperatorObj {
std::string toString() const override;
int numInputs() const override { return 1; }
int numOutputs() const override { return 1; }
int getAxis() const { return axis; }
private:
vector<int> getWorkloadVector() const override;

View File

@ -0,0 +1,11 @@
#pragma once
#include <iostream>
namespace infini {
union Uf32 {
float f32;
uint32_t u32;
};
uint16_t float_to_fp16(const float x);
float fp16_to_float(const uint16_t x);
} // namespace infini

View File

@ -1,6 +1,7 @@
#pragma once
#include "core/common.h"
#include "core/tensor_base.h"
#include "utils/data_convert.h"
#include <random>
namespace infini {
@ -10,6 +11,7 @@ class DataGenerator {
private:
virtual void fill(uint32_t *data, size_t size) { IT_TODO_HALT(); }
virtual void fill(float *data, size_t size) { IT_TODO_HALT(); }
virtual void fill_fp16(uint16_t *data, size_t size) { IT_TODO_HALT(); }
public:
virtual ~DataGenerator() {}
@ -18,6 +20,8 @@ class DataGenerator {
fill(reinterpret_cast<uint32_t *>(data), size);
else if (dataType == DataType::Float32)
fill(reinterpret_cast<float *>(data), size);
else if (dataType == DataType::Float16)
fill_fp16(reinterpret_cast<uint16_t *>(data), size);
else
IT_TODO_HALT();
}
@ -38,6 +42,13 @@ class IncrementalGenerator : public DataGenerator {
fill<uint32_t>(data, size);
}
void fill(float *data, size_t size) override { fill<float>(data, size); }
// FIXME: fix the accuracy standards when dtype is float16
void fill_fp16(uint16_t *data, size_t size) {
for (size_t i = 0; i < size; i++) {
float x = 2.0f;
data[i] = float_to_fp16(x);
}
}
};
class RandomGenerator : public DataGenerator {

View File

@ -477,7 +477,10 @@ class OnnxStub:
tensors[node.input[0]],
tensors.get(node.output[0]),
# NOTE(constroy): `axes` is an attribute until opset version 13.
next((attr.ints for attr in node.attribute if attr.name == "axes"), None),
next(
(attr.ints for attr in node.attribute if attr.name == "axes"),
None,
),
next((attr.i for attr in node.attribute if attr.name == "keepdims"))
!= 0,
)
@ -531,6 +534,12 @@ class OnnxStub:
obj.copyin_int64(_parse_data(tensor))
elif tensor.data_type == TensorProto.FLOAT:
obj.copyin_float(_parse_data(tensor))
elif tensor.data_type == TensorProto.BOOL:
obj.copyin_int8(_parse_data(tensor))
elif tensor.data_type == TensorProto.FLOAT16:
obj.copyin_float16(_parse_data_fp16(tensor))
elif tensor.data_type == TensorProto.INT8:
obj.copyin_uint8(_parse_data(tensor))
else:
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
@ -730,7 +739,8 @@ class OnnxStub:
]:
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpType.Flatten:
raise Exception("TODO")
axis = backend.flatten_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
elif ty == backend.OpType.Transpose:
perm = backend.transpose_permute_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, perm=perm))
@ -894,5 +904,21 @@ def _parse_data(tensor: TensorProto) -> List[Any]:
return to_array(tensor).flatten().tolist()
def _parse_data_fp16(tensor: TensorProto):
list_ = []
if len(tensor.int32_data) != 0:
for element_data in tensor.int32_data:
element_byte = element_data.to_bytes(2, "little")
list_.append(element_byte[0] + element_byte[1] * 256)
elif len(tensor.raw_data) != 0:
list_raw_data = list(tensor.raw_data)
list_data = [list_raw_data[i : i + 2] for i in range(0, len(list_raw_data), 2)]
for ele in list_data:
list_.append(ele[0] + ele[1] * 256)
else:
raise Exception("Tensor have no float16 data!")
return list_
def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]

View File

@ -64,6 +64,21 @@ class TestStringMethods(unittest.TestCase):
)
make_and_import_model(make_graph([conv], "conv", [i, w], [o]))
def test_conv_fp16(self):
i = make_tensor_value_info("i", TensorProto.FLOAT16, [1, 3, 4, 4])
w = make_tensor_value_info("w", TensorProto.FLOAT16, [2, 3, 3, 3])
o = make_tensor_value_info("o", TensorProto.FLOAT16, [1, 2, 2, 2])
conv = make_node(
"Conv",
["i", "w"],
["o"],
"conv",
pads=[1, 1, 1, 1],
strides=[2, 1],
dilations=[1, 2],
)
make_and_import_model(make_graph([conv], "conv_fp16", [i, w], [o]))
def test_matmul(self):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
@ -211,9 +226,9 @@ class TestStringMethods(unittest.TestCase):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
y = make_tensor_value_info("y", TensorProto.FLOAT, [1 * 3, 5 * 7])
flatten = make_node("Flatten", ["x"], ["y"], axis=2, name="flatten")
# make_and_import_model(
make_and_import_model(
make_graph([flatten], "flatten", [x], [y])
# )
)
def test_reshape(self):
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])

View File

@ -13,8 +13,7 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs =
KernelAttrs{device, op->getOpType(), DataType::Float32};
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);

View File

@ -292,23 +292,35 @@ Tensor GraphHandlerObj::pad(Tensor input, Tensor output,
}
static DataType dtype_repr_convert(int dtype) {
switch ((OnnxDType)dtype) {
case OnnxDType::FLOAT:
switch (dtype) {
case 0:
return DataType::Undefine;
case 1:
return DataType::Float32;
case OnnxDType::UINT32:
return DataType::UInt32;
case OnnxDType::UINT8:
case 2:
return DataType::UInt8;
case OnnxDType::INT8:
case 3:
return DataType::Int8;
case OnnxDType::UINT16:
case 4:
return DataType::UInt16;
case OnnxDType::INT16:
case 5:
return DataType::Int16;
case OnnxDType::INT32:
case 6:
return DataType::Int32;
case OnnxDType::INT64:
case 7:
return DataType::Int64;
case 8:
return DataType::String;
case 9:
return DataType::Bool;
case 10:
return DataType::Float16;
case 11:
return DataType::Double;
case 12:
return DataType::UInt32;
case 13:
return DataType::UInt64;
default:
IT_ASSERT(false, "Unsupported data type");
}

View File

@ -110,8 +110,6 @@ optional<vector<Shape>> OperatorObj::inferShape() const {
vector<DataType> OperatorObj::inferDataType(const TensorVec &inputs) const {
auto dataType = inputs[0]->getDType();
for (const auto &tensor : inputs)
IT_ASSERT(dataType == tensor->getDType());
return vector(numOutputs(), dataType);
}

View File

@ -30,9 +30,7 @@ void from_json(const json &j, OpPerfKey &p) {
j.at("opType").get_to(p.opType);
j.at("attrs").get_to(p.attrs);
}
void to_json(json &j, const DataType &p) {
j = p.toString() == "Float32" ? 0 : 1;
}
void to_json(json &j, const DataType &p) { j = p.getIndex(); }
void from_json(const json &j, DataType &p) { p = DataType(j.get<int>()); }
void to_json(json &j, const PerfRecord &p) { p->to_json(j); }
void from_json(const json &j, PerfRecord &p) {

View File

@ -79,6 +79,12 @@ void TensorObj::printData() const {
else TRY_PRINT(5) //
else TRY_PRINT(6) //
else TRY_PRINT(7) //
else TRY_PRINT(8) //
else TRY_PRINT(9) //
else TRY_PRINT(10) //
else TRY_PRINT(11) //
else TRY_PRINT(12) //
else TRY_PRINT(13) //
else IT_TODO_HALT();
#undef TRY_PRINT
@ -106,6 +112,12 @@ bool TensorObj::equalData(const Tensor &rhs, double relativeError) const {
else TEST_EQUAL(5) //
else TEST_EQUAL(6) //
else TEST_EQUAL(7) //
else TEST_EQUAL(8) //
else TEST_EQUAL(9) //
else TEST_EQUAL(10) //
else TEST_EQUAL(11) //
else TEST_EQUAL(12) //
else TEST_EQUAL(13) //
else IT_TODO_HALT();
#undef TEST_EQUAL

View File

@ -11,8 +11,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
auto &perfEngine = PerfEngine::getInstance();
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs =
KernelAttrs{device, op->getOpType(), DataType::Float32};
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);
@ -33,8 +32,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs =
KernelAttrs{device, op->getOpType(), DataType::Float32};
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);

View File

@ -100,22 +100,34 @@ void export_values(py::module &m) {
}
static int tensor_dtype(Tensor t) {
if (t->getDType() == DataType::Undefine)
return 0;
if (t->getDType() == DataType::Float32)
return OnnxDType::FLOAT;
if (t->getDType() == DataType::UInt32)
return OnnxDType::UINT32;
return 1;
if (t->getDType() == DataType::UInt8)
return OnnxDType::UINT8;
return 2;
if (t->getDType() == DataType::Int8)
return OnnxDType::INT8;
return 3;
if (t->getDType() == DataType::UInt16)
return OnnxDType::UINT16;
return 4;
if (t->getDType() == DataType::Int16)
return OnnxDType::INT16;
return 5;
if (t->getDType() == DataType::Int32)
return OnnxDType::INT32;
return 6;
if (t->getDType() == DataType::Int64)
return OnnxDType::INT64;
return 7;
if (t->getDType() == DataType::String)
return 8;
if (t->getDType() == DataType::Bool)
return 9;
if (t->getDType() == DataType::Float16)
return 10;
if (t->getDType() == DataType::Double)
return 11;
if (t->getDType() == DataType::UInt32)
return 12;
if (t->getDType() == DataType::UInt64)
return 13;
IT_ASSERT(false, "Unsupported data type");
}
@ -224,6 +236,11 @@ static vector<int> transpose_permute_of(Operator op) {
return dynamic_cast<const TransposeObj *>(op.get())->getPermute();
}
static int flatten_axis_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Flatten);
return dynamic_cast<const FlattenObj *>(op.get())->getAxis();
}
void export_functions(py::module &m) {
#define FUNCTION(NAME) def(#NAME, &NAME)
m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance)
@ -252,7 +269,8 @@ void export_functions(py::module &m) {
.FUNCTION(transpose_permute_of)
.FUNCTION(concat_axis_of)
.FUNCTION(split_axis_of)
.FUNCTION(gather_axis_of);
.FUNCTION(gather_axis_of)
.FUNCTION(flatten_axis_of);
#undef FUNCTION
}
@ -276,9 +294,15 @@ void init_graph_builder(py::module &m) {
.def("copyin_float", &TensorObj::copyin<float>, policy::move)
.def("copyin_int32", &TensorObj::copyin<int32_t>, policy::move)
.def("copyin_int64", &TensorObj::copyin<int64_t>, policy::move)
.def("copyin_int8", &TensorObj::copyin<int8_t>, policy::move)
.def("copyin_uint8", &TensorObj::copyin<uint8_t>, policy::move)
.def("copyin_float16", &TensorObj::copyin<uint16_t>, policy::move)
.def("copyout_float", &TensorObj::copyout<float>, policy::move)
.def("copyout_int32", &TensorObj::copyout<int32_t>, policy::move)
.def("copyout_int64", &TensorObj::copyout<int64_t>, policy::move)
.def("copyout_int8", &TensorObj::copyout<int8_t>, policy::move)
.def("copyout_uint8", &TensorObj::copyout<uint8_t>, policy::move)
.def("copyout_float16", &TensorObj::copyout<uint16_t>, policy::move)
.def("has_target", &TensorObj::hasTarget, policy::automatic)
.def("src", &TensorObj::getSource, policy::move)
.def("printData", &TensorObj::printData, policy::automatic);

View File

@ -0,0 +1,261 @@
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include "operators/conv.h"
#include <chrono>
#include <functional>
#include <limits>
#include <tuple>
namespace infini {
struct ConvCuDnnPerfRecordObj : public PerfRecordObj {
int algo = 0; // cudnnConvolutionFwdAlgo_t
int mode = 1;
size_t workspaceSize = 100000;
bool fuseAct = false;
void to_json(json &j) override {
j["type"] = 1;
j["data"] = std::make_tuple(algo, mode, fuseAct, time, workspaceSize);
}
static PerfRecord from_json(const json &j) {
ConvCuDnnPerfRecordObj tmp;
auto [Algo, Mode, FuseAct, Time, WorkspaceSize] =
j["data"].get<tuple<int, int, bool, double, size_t>>();
tmp.algo = Algo;
tmp.mode = Mode;
tmp.fuseAct = FuseAct;
tmp.time = Time;
tmp.workspaceSize = WorkspaceSize;
return make_ref<ConvCuDnnPerfRecordObj>(tmp);
}
};
using ConvCuDnnPerfRecord = Ref<ConvCuDnnPerfRecordObj>;
class convCudnnFP16 : public Kernel {
static constexpr int N_ALGO = 8;
static constexpr int N_MODE = 2;
static constexpr cudnnConvolutionFwdAlgo_t ALGOS[8] = {
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED};
static constexpr cudnnConvolutionMode_t MODES[2] = {
CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION};
std::tuple<void *, void *, void *, cudnnTensorDescriptor_t,
cudnnFilterDescriptor_t, cudnnTensorDescriptor_t,
cudnnConvolutionDescriptor_t, cudnnActivationDescriptor_t,
cudnnTensorDescriptor_t>
createCuDNNDescriptor(const Ref<ConvObj> &op,
const ConvCuDnnPerfRecord &record) const {
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
// Bias is not supported yet
if (op->getInputs().size() > 2) {
IT_TODO_HALT();
}
// void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
const auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
const int cpg = op->getChannelPerGroup();
const int g = c / cpg;
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
int channelsPerGrp = cpg, channels = c;
// get inputs
cudnnTensorDescriptor_t inDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(inDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_HALF, n, channels,
h, w)); /*fp16 type*/
// get kernels
cudnnFilterDescriptor_t knDesc;
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
checkCudnnError(cudnnSetFilter4dDescriptor(
knDesc, CUDNN_DATA_HALF, /*fp16 type*/
CUDNN_TENSOR_NCHW, f, channelsPerGrp, r, s));
// get bias
cudnnTensorDescriptor_t biasDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(biasDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_HALF, 1, f, 1,
1)); /*fp16 type*/
// get convolution descriptor
cudnnConvolutionDescriptor_t convDesc;
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
// TODO: CUDNN_CONVOLUTION is a tunable argument
checkCudnnError(cudnnSetConvolution2dDescriptor(
convDesc, ph, pw, sh, sw, dh, dw, MODES[record->mode],
CUDNN_DATA_HALF)); /*fp16 type*/
if (g > 1) {
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
}
// get activation descriptor
cudnnActivationDescriptor_t actDesc;
checkCudnnError(cudnnCreateActivationDescriptor(&actDesc));
// NOT_PROPAGATE_NAN is requierd by
// cudnnConvolotionBiasActivationForward
switch (op->getAct()) {
case ActType::Relu:
checkCudnnError(cudnnSetActivationDescriptor(
actDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0));
break;
case ActType::Sigmoid:
checkCudnnError(cudnnSetActivationDescriptor(
actDesc, CUDNN_ACTIVATION_SIGMOID, CUDNN_NOT_PROPAGATE_NAN, 0));
break;
case ActType::None:
checkCudnnError(
cudnnSetActivationDescriptor(actDesc, CUDNN_ACTIVATION_IDENTITY,
CUDNN_NOT_PROPAGATE_NAN, 0));
break;
default:
assert(false);
}
// get output descriptor
int outn, outc, outh, outw;
checkCudnnError(cudnnGetConvolution2dForwardOutputDim(
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
cudnnTensorDescriptor_t outDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_HALF, outn, outc,
outh, outw));
IT_ASSERT((vector{outn, outc, outh, outw}) ==
op->getOutput()->getDims(),
"cuDNN output shape mismatches with OP output shape");
return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc);
}
bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
const CudaRuntimeObj *context) const {
cudnnStatus_t stat;
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc] =
createCuDNNDescriptor(op, record);
size_t wsSize = record->workspaceSize;
CudaPtr wsData = context->getWorkspace(wsSize);
float alpha = 1.f, beta = 0.f;
stat = cudnnConvolutionForward(context->cudnnHandle(), &alpha, inDesc,
inData, knDesc, knData, convDesc,
ALGOS[record->algo], wsData, wsSize,
&beta, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS) {
return false;
}
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
return true;
}
void compute(const Operator &op, const RuntimeObj *context) const override {
auto record = make_ref<ConvCuDnnPerfRecordObj>(); // with paramters in
// default ctor
compute(op, record, context);
}
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
ConvCuDnnPerfRecordObj ret;
ret.time = std::numeric_limits<double>::max();
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
auto op = as<ConvObj>(_op);
// Both modes have the same performance. Only run cross-correlation.
for (int mode = 1; mode < 2; mode++) {
// Try every possible algorithm of convolution
for (int algo = 0; algo < N_ALGO; algo++) {
auto recordRef = make_ref<ConvCuDnnPerfRecordObj>();
auto &record = *recordRef;
record.mode = mode;
record.algo = algo;
cudnnStatus_t stat;
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc] =
createCuDNNDescriptor(op, recordRef);
// get workspace
stat = cudnnGetConvolutionForwardWorkspaceSize(
context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
ALGOS[record.algo], &record.workspaceSize);
if (stat != CUDNN_STATUS_SUCCESS) {
continue;
}
if (record.workspaceSize > context->getWorkspaceSize()) {
continue;
}
CudaPtr wsData = context->getWorkspace(record.workspaceSize);
float alpha = 1.f, beta = 0.f;
stat = cudnnConvolutionForward(
context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
knData, convDesc, ALGOS[record.algo], wsData,
record.workspaceSize, &beta, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS) {
continue;
}
record.time = timeit(
[&]() {
cudnnConvolutionForward(context->cudnnHandle(), &alpha,
inDesc, inData, knDesc, knData,
convDesc, ALGOS[record.algo],
wsData, record.workspaceSize,
&beta, outDesc, outData);
},
[&]() { context->sync(); });
// printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time);
// Update the tune result
if (ret.time > record.time) {
ret = record;
}
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
}
}
// printf("the best algo is %d, the best conv mode is %d\n", ret.algo,
// ret.mode);
IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
"algorithm "
"found");
return make_ref<ConvCuDnnPerfRecordObj>(ret);
}
void compute(const Operator &_op, const PerfRecord &_record,
const RuntimeObj *_context) const override {
auto op = as<ConvObj>(_op);
auto record = as<ConvCuDnnPerfRecordObj>(_record);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
bool success = cuDNNUnfused(op, record, context);
IT_ASSERT(success);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float16, convCudnnFP16,
"Conv_cuDNN_CUDA_Float16");
} // namespace infini

30
src/utils/data_convert.cc Normal file
View File

@ -0,0 +1,30 @@
#include "utils/data_convert.h"
namespace infini {
uint16_t float_to_fp16(const float x) {
Uf32 u;
u.f32 = x;
const uint32_t b = u.u32 + 0x00001000;
const uint32_t e = (b & 0x7F800000) >> 23;
const uint32_t m = b & 0x007FFFFF;
return (b & 0x80000000) >> 16 |
(e > 112) * ((((e - 112) << 10) & 0x7C00) | m >> 13) |
((e < 113) & (e > 101)) *
((((0x007FF000 + m) >> (125 - e)) + 1) >> 1) |
(e > 143) * 0x7FFF;
}
float fp16_to_float(const uint16_t x) {
Uf32 u;
const uint32_t e = (x & 0x7C00) >> 10;
const uint32_t m = (x & 0x03FF) << 13;
u.f32 = (float)m;
const uint32_t v = u.u32 >> 23;
const uint32_t r = (x & 0x8000) << 16 | (e != 0) * ((e + 112) << 23 | m) |
((e == 0) & (m != 0)) *
((v - 37) << 23 | ((m << (150 - v)) & 0x007FE000));
u.u32 = r;
return u.f32;
}
} // namespace infini

View File

@ -7,9 +7,9 @@ namespace infini {
TEST(Handler, matmul) {
auto runtime = NativeCpuRuntimeObj::getInstance();
auto handler = make_ref<GraphHandlerObj>(runtime);
auto i = handler->tensor({1, 2, 3}, OnnxDType::UINT32);
auto w = handler->tensor({1, 3, 4}, OnnxDType::UINT32);
auto o = handler->tensor({1, 2, 4}, OnnxDType::UINT32);
auto i = handler->tensor({1, 2, 3}, DataType::UInt32.getIndex());
auto w = handler->tensor({1, 3, 4}, DataType::UInt32.getIndex());
auto o = handler->tensor({1, 2, 4}, DataType::UInt32.getIndex());
handler->matmul(i, w, o, false, false, nullptr, ActType::None);
}

View File

@ -0,0 +1,78 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/conv.h"
#include <bitset>
#include "test.h"
namespace infini {
void testConvCudnnFP16(
const std::function<void(void *, size_t, DataType)> &generator,
vector<float> ansVec) {
// Construct Runtime and graph for CPU and CUDA
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
Graph gCpu = make_ref<GraphObj>(cpu);
Runtime cuda = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cuda);
// Set input data on CPU in a CPU Graph
Tensor i0Cpu = gCpu->addTensor({1, 3, 4, 4}, DataType::Float16);
Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float16);
// Malloc data for all tensors in a graph. Do we need implicit allocation?
gCpu->dataMalloc();
i0Cpu->setData(generator);
w0Cpu->setData(generator);
// Copy input tensors from CPU to CUDA
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
// Build CUDA graph
auto conv =
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 2, 1, 1, 2);
// allocate CUDA memory
gCuda->dataMalloc();
// Execute on CUDA
cuda->run(gCuda);
// copy output from CUDA to CPU
auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
// check results on CPU
EXPECT_TRUE(o0Cpu->equalData(ansVec));
// print a tensor/operator/graph by print()
gCuda->print();
}
TEST(cuDNN_Conv_FP16, run) {
testConvCudnnFP16(IncrementalGenerator(),
vector<float>{48, 48, 72, 72, 48, 48, 72, 72});
}
TEST(cuDNN_Conv_FP16, tune) {
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
Graph gCpu = make_ref<GraphObj>(cpu);
Runtime cuda = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cuda);
// Set input data on CPU in a CPU Graph
Tensor i0Cpu = gCpu->addTensor({1, 3, 224, 224}, DataType::Float16);
Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float16);
// Malloc data for all tensors in a graph. Do we need implicit allocation?
gCpu->dataMalloc();
i0Cpu->setData(IncrementalGenerator());
w0Cpu->setData(IncrementalGenerator());
// Copy input tensors from CPU to CUDA
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
// Build CUDA graph
auto conv =
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1);
// allocate CUDA memory
gCuda->dataMalloc();
// Execute on CUDA
bool tune = true;
cuda->run(gCuda, tune);
}
} // namespace infini