forked from jiuyuan/InfiniTensor
支持fp16 dtype (#96)
* add conv_half kernel * Conv Kernel FP16 * dcj: replace "DataType::Float32" with "op->getDType()" to support more DataType * feat: support Float16 dtype * fix: set default clang-format to 14 version * fix: 按照review意见修改 * fix: add data convert to convfp16 kernel test * test: add conv_fp16 kernel test --------- Co-authored-by: zhangyue207 <zhangyue@qiyuanlab.com> Co-authored-by: kilinchange <kilinchange@163.com>
This commit is contained in:
parent
1dc65e2788
commit
9b10a74788
|
@ -4,19 +4,44 @@ namespace infini {
|
|||
|
||||
class DataType {
|
||||
public:
|
||||
// legacy
|
||||
static const DataType Float32;
|
||||
static const DataType UInt32;
|
||||
// These are just aligned with the type and index of onnx:
|
||||
// <https://onnx.ai/onnx/intro/concepts.html#element-type>
|
||||
static const DataType UInt8, Int8, UInt16, Int16, Int32, Int64;
|
||||
static constexpr size_t sizePerElement[]{
|
||||
sizeof(float), sizeof(uint32_t), sizeof(uint8_t), sizeof(int8_t),
|
||||
sizeof(uint16_t), sizeof(int16_t), sizeof(int32_t), sizeof(int64_t)};
|
||||
static const DataType Undefine;
|
||||
static const DataType Float32;
|
||||
static const DataType UInt8;
|
||||
static const DataType Int8;
|
||||
static const DataType UInt16;
|
||||
static const DataType Int16;
|
||||
static const DataType Int32;
|
||||
static const DataType Int64;
|
||||
static const DataType String;
|
||||
static const DataType Bool;
|
||||
static const DataType Float16;
|
||||
static const DataType Double;
|
||||
static const DataType UInt32;
|
||||
static const DataType UInt64;
|
||||
// "sizePerElement" show the DType to cpu_type
|
||||
// DataType::Bool -> int8_t DataType::Float16 -> uint16_t
|
||||
static constexpr size_t sizePerElement[]{0,
|
||||
sizeof(float),
|
||||
sizeof(uint8_t),
|
||||
sizeof(int8_t),
|
||||
sizeof(uint16_t),
|
||||
sizeof(int16_t),
|
||||
sizeof(int32_t),
|
||||
sizeof(int64_t),
|
||||
sizeof(std::string),
|
||||
sizeof(int8_t),
|
||||
sizeof(uint16_t),
|
||||
sizeof(double),
|
||||
sizeof(uint32_t),
|
||||
sizeof(uint64_t)};
|
||||
|
||||
static constexpr std::string_view names[]{"Float32", "UInt32", "UInt8",
|
||||
"Int8", "UInt16", "Int16",
|
||||
"Int32", "Int64"};
|
||||
static constexpr std::string_view names[]{
|
||||
"Undefine", "Float32", "UInt8", "Int8", "UInt16",
|
||||
"Int16", "Int32", "Int64", "String", "Bool",
|
||||
"Float16", "Double", "UInt32", "UInt64"};
|
||||
|
||||
static constexpr int cpuType[]{-1, 0, 2, 3, 4, 5, 6, 7, -1, 3, 4, 9, 1, 8};
|
||||
|
||||
private:
|
||||
int index;
|
||||
|
@ -29,37 +54,58 @@ class DataType {
|
|||
bool operator==(const DataType &rhs) const { return index == rhs.index; }
|
||||
bool operator<(const DataType &rhs) const { return index < rhs.index; }
|
||||
|
||||
template <typename T> static DataType get() {
|
||||
template <typename T> static int get() {
|
||||
IT_TODO_HALT_MSG("Unsupported data type");
|
||||
}
|
||||
size_t getSize() const { return sizePerElement[index]; }
|
||||
string toString() const { return string(names[index]); }
|
||||
int cpuTypeInt() const { return cpuType[index]; }
|
||||
int getIndex() const { return index; }
|
||||
};
|
||||
|
||||
inline const DataType DataType::Float32(0);
|
||||
inline const DataType DataType::UInt32(1);
|
||||
inline const DataType DataType::UInt8(2), DataType::Int8(3),
|
||||
DataType::UInt16(4), DataType::Int16(5), DataType::Int32(6),
|
||||
DataType::Int64(7);
|
||||
// to be consistent with onnx
|
||||
// https://github.com/onnx/onnx/blob/aeb21329122b96df1d3ef33b500a35ca140b1431/onnx/onnx.proto#L484
|
||||
inline const DataType DataType::Undefine(0);
|
||||
inline const DataType DataType::Float32(1);
|
||||
inline const DataType DataType::UInt8(2);
|
||||
inline const DataType DataType::Int8(3);
|
||||
inline const DataType DataType::UInt16(4);
|
||||
inline const DataType DataType::Int16(5);
|
||||
inline const DataType DataType::Int32(6);
|
||||
inline const DataType DataType::Int64(7);
|
||||
inline const DataType DataType::String(8);
|
||||
inline const DataType DataType::Bool(9);
|
||||
inline const DataType DataType::Float16(10);
|
||||
inline const DataType DataType::Double(11);
|
||||
inline const DataType DataType::UInt32(12);
|
||||
inline const DataType DataType::UInt64(13);
|
||||
// Method definitions are out of the declaration due to GCC bug:
|
||||
// https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
|
||||
template <> inline DataType DataType::get<float>() { return Float32; }
|
||||
template <> inline DataType DataType::get<uint32_t>() { return UInt32; }
|
||||
template <> inline DataType DataType::get<uint8_t>() { return UInt8; }
|
||||
template <> inline DataType DataType::get<int8_t>() { return Int8; }
|
||||
template <> inline DataType DataType::get<uint16_t>() { return UInt16; }
|
||||
template <> inline DataType DataType::get<int16_t>() { return Int16; }
|
||||
template <> inline DataType DataType::get<int32_t>() { return Int32; }
|
||||
template <> inline DataType DataType::get<int64_t>() { return Int64; }
|
||||
template <> inline int DataType::get<float>() { return 0; }
|
||||
template <> inline int DataType::get<uint32_t>() { return 1; }
|
||||
template <> inline int DataType::get<uint8_t>() { return 2; }
|
||||
template <> inline int DataType::get<int8_t>() { return 3; }
|
||||
template <> inline int DataType::get<uint16_t>() { return 4; }
|
||||
template <> inline int DataType::get<int16_t>() { return 5; }
|
||||
template <> inline int DataType::get<int32_t>() { return 6; }
|
||||
template <> inline int DataType::get<int64_t>() { return 7; }
|
||||
template <> inline int DataType::get<uint64_t>() { return 8; }
|
||||
template <> inline int DataType::get<double>() { return 9; }
|
||||
|
||||
template <int index> struct DT {};
|
||||
template <> struct DT<0> { using t = float; };
|
||||
template <> struct DT<1> { using t = uint32_t; };
|
||||
template <> struct DT<0> { using t = bool; };
|
||||
template <> struct DT<1> { using t = float; };
|
||||
template <> struct DT<2> { using t = uint8_t; };
|
||||
template <> struct DT<3> { using t = int8_t; };
|
||||
template <> struct DT<4> { using t = uint16_t; };
|
||||
template <> struct DT<5> { using t = int16_t; };
|
||||
template <> struct DT<6> { using t = int32_t; };
|
||||
template <> struct DT<7> { using t = int64_t; };
|
||||
template <> struct DT<8> { using t = char; };
|
||||
template <> struct DT<9> { using t = int8_t; };
|
||||
template <> struct DT<10> { using t = uint16_t; };
|
||||
template <> struct DT<11> { using t = double; };
|
||||
template <> struct DT<12> { using t = uint32_t; };
|
||||
template <> struct DT<13> { using t = uint64_t; };
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -7,30 +7,6 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
// Use the indices from onnx to reduce delivery overhead,
|
||||
// which comes from onnx but may be not only used for onnx.
|
||||
//
|
||||
// see https://onnx.ai/onnx/intro/concepts.html#element-type
|
||||
enum OnnxDType : int {
|
||||
UNDEFINED = 0,
|
||||
FLOAT,
|
||||
UINT8,
|
||||
INT8,
|
||||
UINT16,
|
||||
INT16,
|
||||
INT32,
|
||||
INT64,
|
||||
STRING,
|
||||
BOOL,
|
||||
FLOAT16,
|
||||
DOUBLE,
|
||||
UINT32,
|
||||
UINT64,
|
||||
COMPLEX64,
|
||||
COMPLEX128,
|
||||
BFLOAT16,
|
||||
};
|
||||
|
||||
class GraphHandlerObj {
|
||||
Graph g;
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
#include "core/tensor_base.h"
|
||||
#include "utils/data_convert.h"
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
|
||||
|
@ -45,20 +46,20 @@ class TensorObj : public TensorBaseObj {
|
|||
|
||||
// Copy elements from `data`.
|
||||
template <typename T> void copyin(const vector<T> &data) {
|
||||
IT_ASSERT(DataType::get<T>() == dtype);
|
||||
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
|
||||
IT_ASSERT(data.size() >= _size);
|
||||
copyin(data.data(), getBytes());
|
||||
}
|
||||
// Copy all the elements to a vector.
|
||||
template <typename T> auto copyout() const {
|
||||
IT_ASSERT(DataType::get<T>() == dtype);
|
||||
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
|
||||
std::vector<T> ans(_size);
|
||||
copyout(ans.data(), getBytes());
|
||||
return ans;
|
||||
}
|
||||
// Copy the element at `pos`.
|
||||
template <typename T> auto copyOne(const vector<int> &pos) const {
|
||||
IT_ASSERT(DataType::get<T>() == dtype);
|
||||
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
|
||||
auto offset = getOffset(pos);
|
||||
auto bytes = dtype.getSize();
|
||||
T ans;
|
||||
|
@ -98,8 +99,12 @@ class TensorObj : public TensorBaseObj {
|
|||
bool equalData(const Tensor &rhs, double relativeError = 1e-6) const;
|
||||
|
||||
template <typename T> bool equalData(const vector<T> &dataVector) {
|
||||
IT_ASSERT(DataType::get<T>() == dtype);
|
||||
IT_ASSERT(size() == dataVector.size());
|
||||
if (dtype == DataType::Float16) {
|
||||
return equalDataImpl_fp16(getRawDataPtr<uint16_t *>(),
|
||||
(float *)dataVector.data(), size());
|
||||
}
|
||||
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
|
||||
return equalDataImpl(getRawDataPtr<T *>(), dataVector.data(), size());
|
||||
}
|
||||
|
||||
|
@ -156,6 +161,20 @@ class TensorObj : public TensorBaseObj {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool equalDataImpl_fp16(const uint16_t *a, const float *b,
|
||||
size_t size) const {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
auto a_fp32 = fp16_to_float(a[i]);
|
||||
auto b_fp32 = b[i];
|
||||
if (fabs(a_fp32 - b_fp32) / std::max(fabs(a_fp32), fabs(b_fp32)) >
|
||||
1e-6) {
|
||||
printf("Error on %lu: %f %f\n", i, a_fp32, b_fp32);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Shape getPosByOffset(size_t offset, Shape dim) const;
|
||||
size_t getOffsetByPos(Shape pos, Shape dim) const;
|
||||
|
||||
|
|
|
@ -60,6 +60,7 @@ class FlattenObj : public OperatorObj {
|
|||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
int getAxis() const { return axis; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
#pragma once
|
||||
#include <iostream>
|
||||
|
||||
namespace infini {
|
||||
union Uf32 {
|
||||
float f32;
|
||||
uint32_t u32;
|
||||
};
|
||||
uint16_t float_to_fp16(const float x);
|
||||
float fp16_to_float(const uint16_t x);
|
||||
} // namespace infini
|
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
#include "core/common.h"
|
||||
#include "core/tensor_base.h"
|
||||
#include "utils/data_convert.h"
|
||||
#include <random>
|
||||
|
||||
namespace infini {
|
||||
|
@ -10,6 +11,7 @@ class DataGenerator {
|
|||
private:
|
||||
virtual void fill(uint32_t *data, size_t size) { IT_TODO_HALT(); }
|
||||
virtual void fill(float *data, size_t size) { IT_TODO_HALT(); }
|
||||
virtual void fill_fp16(uint16_t *data, size_t size) { IT_TODO_HALT(); }
|
||||
|
||||
public:
|
||||
virtual ~DataGenerator() {}
|
||||
|
@ -18,6 +20,8 @@ class DataGenerator {
|
|||
fill(reinterpret_cast<uint32_t *>(data), size);
|
||||
else if (dataType == DataType::Float32)
|
||||
fill(reinterpret_cast<float *>(data), size);
|
||||
else if (dataType == DataType::Float16)
|
||||
fill_fp16(reinterpret_cast<uint16_t *>(data), size);
|
||||
else
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
@ -38,6 +42,13 @@ class IncrementalGenerator : public DataGenerator {
|
|||
fill<uint32_t>(data, size);
|
||||
}
|
||||
void fill(float *data, size_t size) override { fill<float>(data, size); }
|
||||
// FIXME: fix the accuracy standards when dtype is float16
|
||||
void fill_fp16(uint16_t *data, size_t size) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
float x = 2.0f;
|
||||
data[i] = float_to_fp16(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class RandomGenerator : public DataGenerator {
|
||||
|
|
|
@ -477,7 +477,10 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
# NOTE(constroy): `axes` is an attribute until opset version 13.
|
||||
next((attr.ints for attr in node.attribute if attr.name == "axes"), None),
|
||||
next(
|
||||
(attr.ints for attr in node.attribute if attr.name == "axes"),
|
||||
None,
|
||||
),
|
||||
next((attr.i for attr in node.attribute if attr.name == "keepdims"))
|
||||
!= 0,
|
||||
)
|
||||
|
@ -531,6 +534,12 @@ class OnnxStub:
|
|||
obj.copyin_int64(_parse_data(tensor))
|
||||
elif tensor.data_type == TensorProto.FLOAT:
|
||||
obj.copyin_float(_parse_data(tensor))
|
||||
elif tensor.data_type == TensorProto.BOOL:
|
||||
obj.copyin_int8(_parse_data(tensor))
|
||||
elif tensor.data_type == TensorProto.FLOAT16:
|
||||
obj.copyin_float16(_parse_data_fp16(tensor))
|
||||
elif tensor.data_type == TensorProto.INT8:
|
||||
obj.copyin_uint8(_parse_data(tensor))
|
||||
else:
|
||||
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||
|
||||
|
@ -730,7 +739,8 @@ class OnnxStub:
|
|||
]:
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
elif ty == backend.OpType.Flatten:
|
||||
raise Exception("TODO")
|
||||
axis = backend.flatten_axis_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||
elif ty == backend.OpType.Transpose:
|
||||
perm = backend.transpose_permute_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, perm=perm))
|
||||
|
@ -894,5 +904,21 @@ def _parse_data(tensor: TensorProto) -> List[Any]:
|
|||
return to_array(tensor).flatten().tolist()
|
||||
|
||||
|
||||
def _parse_data_fp16(tensor: TensorProto):
|
||||
list_ = []
|
||||
if len(tensor.int32_data) != 0:
|
||||
for element_data in tensor.int32_data:
|
||||
element_byte = element_data.to_bytes(2, "little")
|
||||
list_.append(element_byte[0] + element_byte[1] * 256)
|
||||
elif len(tensor.raw_data) != 0:
|
||||
list_raw_data = list(tensor.raw_data)
|
||||
list_data = [list_raw_data[i : i + 2] for i in range(0, len(list_raw_data), 2)]
|
||||
for ele in list_data:
|
||||
list_.append(ele[0] + ele[1] * 256)
|
||||
else:
|
||||
raise Exception("Tensor have no float16 data!")
|
||||
return list_
|
||||
|
||||
|
||||
def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
|
||||
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
|
||||
|
|
|
@ -64,6 +64,21 @@ class TestStringMethods(unittest.TestCase):
|
|||
)
|
||||
make_and_import_model(make_graph([conv], "conv", [i, w], [o]))
|
||||
|
||||
def test_conv_fp16(self):
|
||||
i = make_tensor_value_info("i", TensorProto.FLOAT16, [1, 3, 4, 4])
|
||||
w = make_tensor_value_info("w", TensorProto.FLOAT16, [2, 3, 3, 3])
|
||||
o = make_tensor_value_info("o", TensorProto.FLOAT16, [1, 2, 2, 2])
|
||||
conv = make_node(
|
||||
"Conv",
|
||||
["i", "w"],
|
||||
["o"],
|
||||
"conv",
|
||||
pads=[1, 1, 1, 1],
|
||||
strides=[2, 1],
|
||||
dilations=[1, 2],
|
||||
)
|
||||
make_and_import_model(make_graph([conv], "conv_fp16", [i, w], [o]))
|
||||
|
||||
def test_matmul(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
|
||||
|
@ -211,9 +226,9 @@ class TestStringMethods(unittest.TestCase):
|
|||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1 * 3, 5 * 7])
|
||||
flatten = make_node("Flatten", ["x"], ["y"], axis=2, name="flatten")
|
||||
# make_and_import_model(
|
||||
make_graph([flatten], "flatten", [x], [y])
|
||||
# )
|
||||
make_and_import_model(
|
||||
make_graph([flatten], "flatten", [x], [y])
|
||||
)
|
||||
|
||||
def test_reshape(self):
|
||||
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
|
||||
|
|
|
@ -13,8 +13,7 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
|||
std::map<OpType, int> opCnt;
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType(), DataType::Float32};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
|
|
@ -292,23 +292,35 @@ Tensor GraphHandlerObj::pad(Tensor input, Tensor output,
|
|||
}
|
||||
|
||||
static DataType dtype_repr_convert(int dtype) {
|
||||
switch ((OnnxDType)dtype) {
|
||||
case OnnxDType::FLOAT:
|
||||
switch (dtype) {
|
||||
case 0:
|
||||
return DataType::Undefine;
|
||||
case 1:
|
||||
return DataType::Float32;
|
||||
case OnnxDType::UINT32:
|
||||
return DataType::UInt32;
|
||||
case OnnxDType::UINT8:
|
||||
case 2:
|
||||
return DataType::UInt8;
|
||||
case OnnxDType::INT8:
|
||||
case 3:
|
||||
return DataType::Int8;
|
||||
case OnnxDType::UINT16:
|
||||
case 4:
|
||||
return DataType::UInt16;
|
||||
case OnnxDType::INT16:
|
||||
case 5:
|
||||
return DataType::Int16;
|
||||
case OnnxDType::INT32:
|
||||
case 6:
|
||||
return DataType::Int32;
|
||||
case OnnxDType::INT64:
|
||||
case 7:
|
||||
return DataType::Int64;
|
||||
case 8:
|
||||
return DataType::String;
|
||||
case 9:
|
||||
return DataType::Bool;
|
||||
case 10:
|
||||
return DataType::Float16;
|
||||
case 11:
|
||||
return DataType::Double;
|
||||
case 12:
|
||||
return DataType::UInt32;
|
||||
case 13:
|
||||
return DataType::UInt64;
|
||||
default:
|
||||
IT_ASSERT(false, "Unsupported data type");
|
||||
}
|
||||
|
|
|
@ -110,8 +110,6 @@ optional<vector<Shape>> OperatorObj::inferShape() const {
|
|||
|
||||
vector<DataType> OperatorObj::inferDataType(const TensorVec &inputs) const {
|
||||
auto dataType = inputs[0]->getDType();
|
||||
for (const auto &tensor : inputs)
|
||||
IT_ASSERT(dataType == tensor->getDType());
|
||||
return vector(numOutputs(), dataType);
|
||||
}
|
||||
|
||||
|
|
|
@ -30,9 +30,7 @@ void from_json(const json &j, OpPerfKey &p) {
|
|||
j.at("opType").get_to(p.opType);
|
||||
j.at("attrs").get_to(p.attrs);
|
||||
}
|
||||
void to_json(json &j, const DataType &p) {
|
||||
j = p.toString() == "Float32" ? 0 : 1;
|
||||
}
|
||||
void to_json(json &j, const DataType &p) { j = p.getIndex(); }
|
||||
void from_json(const json &j, DataType &p) { p = DataType(j.get<int>()); }
|
||||
void to_json(json &j, const PerfRecord &p) { p->to_json(j); }
|
||||
void from_json(const json &j, PerfRecord &p) {
|
||||
|
@ -49,4 +47,4 @@ void from_json(const json &j, PerfEngine &p) {
|
|||
p.set_data(tmp);
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
} // namespace infini
|
||||
|
|
|
@ -71,14 +71,20 @@ void TensorObj::printData() const {
|
|||
if (dtype == DataType(N)) \
|
||||
std::cout << dataToString<DT<N>::t>() << std::endl;
|
||||
|
||||
TRY_PRINT(0) // fmt: new line
|
||||
else TRY_PRINT(1) //
|
||||
else TRY_PRINT(2) //
|
||||
else TRY_PRINT(3) //
|
||||
else TRY_PRINT(4) //
|
||||
else TRY_PRINT(5) //
|
||||
else TRY_PRINT(6) //
|
||||
else TRY_PRINT(7) //
|
||||
TRY_PRINT(0) // fmt: new line
|
||||
else TRY_PRINT(1) //
|
||||
else TRY_PRINT(2) //
|
||||
else TRY_PRINT(3) //
|
||||
else TRY_PRINT(4) //
|
||||
else TRY_PRINT(5) //
|
||||
else TRY_PRINT(6) //
|
||||
else TRY_PRINT(7) //
|
||||
else TRY_PRINT(8) //
|
||||
else TRY_PRINT(9) //
|
||||
else TRY_PRINT(10) //
|
||||
else TRY_PRINT(11) //
|
||||
else TRY_PRINT(12) //
|
||||
else TRY_PRINT(13) //
|
||||
else IT_TODO_HALT();
|
||||
|
||||
#undef TRY_PRINT
|
||||
|
@ -98,14 +104,20 @@ bool TensorObj::equalData(const Tensor &rhs, double relativeError) const {
|
|||
return equalDataImpl(getRawDataPtr<DT<N>::t *>(), \
|
||||
rhs->getRawDataPtr<DT<N>::t *>(), size());
|
||||
|
||||
TEST_EQUAL(0) // fmt: new line
|
||||
else TEST_EQUAL(1) //
|
||||
else TEST_EQUAL(2) //
|
||||
else TEST_EQUAL(3) //
|
||||
else TEST_EQUAL(4) //
|
||||
else TEST_EQUAL(5) //
|
||||
else TEST_EQUAL(6) //
|
||||
else TEST_EQUAL(7) //
|
||||
TEST_EQUAL(0) // fmt: new line
|
||||
else TEST_EQUAL(1) //
|
||||
else TEST_EQUAL(2) //
|
||||
else TEST_EQUAL(3) //
|
||||
else TEST_EQUAL(4) //
|
||||
else TEST_EQUAL(5) //
|
||||
else TEST_EQUAL(6) //
|
||||
else TEST_EQUAL(7) //
|
||||
else TEST_EQUAL(8) //
|
||||
else TEST_EQUAL(9) //
|
||||
else TEST_EQUAL(10) //
|
||||
else TEST_EQUAL(11) //
|
||||
else TEST_EQUAL(12) //
|
||||
else TEST_EQUAL(13) //
|
||||
else IT_TODO_HALT();
|
||||
|
||||
#undef TEST_EQUAL
|
||||
|
|
|
@ -11,8 +11,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
|
|||
auto &perfEngine = PerfEngine::getInstance();
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType(), DataType::Float32};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
@ -33,8 +32,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
|
|||
std::map<OpType, int> opCnt;
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType(), DataType::Float32};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
|
|
@ -100,22 +100,34 @@ void export_values(py::module &m) {
|
|||
}
|
||||
|
||||
static int tensor_dtype(Tensor t) {
|
||||
if (t->getDType() == DataType::Undefine)
|
||||
return 0;
|
||||
if (t->getDType() == DataType::Float32)
|
||||
return OnnxDType::FLOAT;
|
||||
if (t->getDType() == DataType::UInt32)
|
||||
return OnnxDType::UINT32;
|
||||
return 1;
|
||||
if (t->getDType() == DataType::UInt8)
|
||||
return OnnxDType::UINT8;
|
||||
return 2;
|
||||
if (t->getDType() == DataType::Int8)
|
||||
return OnnxDType::INT8;
|
||||
return 3;
|
||||
if (t->getDType() == DataType::UInt16)
|
||||
return OnnxDType::UINT16;
|
||||
return 4;
|
||||
if (t->getDType() == DataType::Int16)
|
||||
return OnnxDType::INT16;
|
||||
return 5;
|
||||
if (t->getDType() == DataType::Int32)
|
||||
return OnnxDType::INT32;
|
||||
return 6;
|
||||
if (t->getDType() == DataType::Int64)
|
||||
return OnnxDType::INT64;
|
||||
return 7;
|
||||
if (t->getDType() == DataType::String)
|
||||
return 8;
|
||||
if (t->getDType() == DataType::Bool)
|
||||
return 9;
|
||||
if (t->getDType() == DataType::Float16)
|
||||
return 10;
|
||||
if (t->getDType() == DataType::Double)
|
||||
return 11;
|
||||
if (t->getDType() == DataType::UInt32)
|
||||
return 12;
|
||||
if (t->getDType() == DataType::UInt64)
|
||||
return 13;
|
||||
IT_ASSERT(false, "Unsupported data type");
|
||||
}
|
||||
|
||||
|
@ -224,6 +236,11 @@ static vector<int> transpose_permute_of(Operator op) {
|
|||
return dynamic_cast<const TransposeObj *>(op.get())->getPermute();
|
||||
}
|
||||
|
||||
static int flatten_axis_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Flatten);
|
||||
return dynamic_cast<const FlattenObj *>(op.get())->getAxis();
|
||||
}
|
||||
|
||||
void export_functions(py::module &m) {
|
||||
#define FUNCTION(NAME) def(#NAME, &NAME)
|
||||
m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance)
|
||||
|
@ -252,7 +269,8 @@ void export_functions(py::module &m) {
|
|||
.FUNCTION(transpose_permute_of)
|
||||
.FUNCTION(concat_axis_of)
|
||||
.FUNCTION(split_axis_of)
|
||||
.FUNCTION(gather_axis_of);
|
||||
.FUNCTION(gather_axis_of)
|
||||
.FUNCTION(flatten_axis_of);
|
||||
#undef FUNCTION
|
||||
}
|
||||
|
||||
|
@ -276,9 +294,15 @@ void init_graph_builder(py::module &m) {
|
|||
.def("copyin_float", &TensorObj::copyin<float>, policy::move)
|
||||
.def("copyin_int32", &TensorObj::copyin<int32_t>, policy::move)
|
||||
.def("copyin_int64", &TensorObj::copyin<int64_t>, policy::move)
|
||||
.def("copyin_int8", &TensorObj::copyin<int8_t>, policy::move)
|
||||
.def("copyin_uint8", &TensorObj::copyin<uint8_t>, policy::move)
|
||||
.def("copyin_float16", &TensorObj::copyin<uint16_t>, policy::move)
|
||||
.def("copyout_float", &TensorObj::copyout<float>, policy::move)
|
||||
.def("copyout_int32", &TensorObj::copyout<int32_t>, policy::move)
|
||||
.def("copyout_int64", &TensorObj::copyout<int64_t>, policy::move)
|
||||
.def("copyout_int8", &TensorObj::copyout<int8_t>, policy::move)
|
||||
.def("copyout_uint8", &TensorObj::copyout<uint8_t>, policy::move)
|
||||
.def("copyout_float16", &TensorObj::copyout<uint16_t>, policy::move)
|
||||
.def("has_target", &TensorObj::hasTarget, policy::automatic)
|
||||
.def("src", &TensorObj::getSource, policy::move)
|
||||
.def("printData", &TensorObj::printData, policy::automatic);
|
||||
|
|
|
@ -0,0 +1,261 @@
|
|||
#include "core/kernel.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "operators/conv.h"
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <tuple>
|
||||
|
||||
namespace infini {
|
||||
|
||||
struct ConvCuDnnPerfRecordObj : public PerfRecordObj {
|
||||
int algo = 0; // cudnnConvolutionFwdAlgo_t
|
||||
int mode = 1;
|
||||
size_t workspaceSize = 100000;
|
||||
bool fuseAct = false;
|
||||
void to_json(json &j) override {
|
||||
j["type"] = 1;
|
||||
j["data"] = std::make_tuple(algo, mode, fuseAct, time, workspaceSize);
|
||||
}
|
||||
static PerfRecord from_json(const json &j) {
|
||||
ConvCuDnnPerfRecordObj tmp;
|
||||
auto [Algo, Mode, FuseAct, Time, WorkspaceSize] =
|
||||
j["data"].get<tuple<int, int, bool, double, size_t>>();
|
||||
tmp.algo = Algo;
|
||||
tmp.mode = Mode;
|
||||
tmp.fuseAct = FuseAct;
|
||||
tmp.time = Time;
|
||||
tmp.workspaceSize = WorkspaceSize;
|
||||
return make_ref<ConvCuDnnPerfRecordObj>(tmp);
|
||||
}
|
||||
};
|
||||
|
||||
using ConvCuDnnPerfRecord = Ref<ConvCuDnnPerfRecordObj>;
|
||||
|
||||
class convCudnnFP16 : public Kernel {
|
||||
|
||||
static constexpr int N_ALGO = 8;
|
||||
static constexpr int N_MODE = 2;
|
||||
static constexpr cudnnConvolutionFwdAlgo_t ALGOS[8] = {
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED};
|
||||
|
||||
static constexpr cudnnConvolutionMode_t MODES[2] = {
|
||||
CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION};
|
||||
|
||||
std::tuple<void *, void *, void *, cudnnTensorDescriptor_t,
|
||||
cudnnFilterDescriptor_t, cudnnTensorDescriptor_t,
|
||||
cudnnConvolutionDescriptor_t, cudnnActivationDescriptor_t,
|
||||
cudnnTensorDescriptor_t>
|
||||
createCuDNNDescriptor(const Ref<ConvObj> &op,
|
||||
const ConvCuDnnPerfRecord &record) const {
|
||||
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
// Bias is not supported yet
|
||||
if (op->getInputs().size() > 2) {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
// void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
const auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
|
||||
const int cpg = op->getChannelPerGroup();
|
||||
const int g = c / cpg;
|
||||
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
|
||||
int channelsPerGrp = cpg, channels = c;
|
||||
|
||||
// get inputs
|
||||
cudnnTensorDescriptor_t inDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(inDesc, CUDNN_TENSOR_NCHW,
|
||||
CUDNN_DATA_HALF, n, channels,
|
||||
h, w)); /*fp16 type*/
|
||||
|
||||
// get kernels
|
||||
cudnnFilterDescriptor_t knDesc;
|
||||
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
|
||||
checkCudnnError(cudnnSetFilter4dDescriptor(
|
||||
knDesc, CUDNN_DATA_HALF, /*fp16 type*/
|
||||
CUDNN_TENSOR_NCHW, f, channelsPerGrp, r, s));
|
||||
// get bias
|
||||
cudnnTensorDescriptor_t biasDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(biasDesc, CUDNN_TENSOR_NCHW,
|
||||
CUDNN_DATA_HALF, 1, f, 1,
|
||||
1)); /*fp16 type*/
|
||||
|
||||
// get convolution descriptor
|
||||
cudnnConvolutionDescriptor_t convDesc;
|
||||
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
|
||||
// TODO: CUDNN_CONVOLUTION is a tunable argument
|
||||
checkCudnnError(cudnnSetConvolution2dDescriptor(
|
||||
convDesc, ph, pw, sh, sw, dh, dw, MODES[record->mode],
|
||||
CUDNN_DATA_HALF)); /*fp16 type*/
|
||||
if (g > 1) {
|
||||
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
|
||||
}
|
||||
|
||||
// get activation descriptor
|
||||
cudnnActivationDescriptor_t actDesc;
|
||||
checkCudnnError(cudnnCreateActivationDescriptor(&actDesc));
|
||||
// NOT_PROPAGATE_NAN is requierd by
|
||||
// cudnnConvolotionBiasActivationForward
|
||||
switch (op->getAct()) {
|
||||
case ActType::Relu:
|
||||
checkCudnnError(cudnnSetActivationDescriptor(
|
||||
actDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0));
|
||||
break;
|
||||
case ActType::Sigmoid:
|
||||
checkCudnnError(cudnnSetActivationDescriptor(
|
||||
actDesc, CUDNN_ACTIVATION_SIGMOID, CUDNN_NOT_PROPAGATE_NAN, 0));
|
||||
break;
|
||||
case ActType::None:
|
||||
checkCudnnError(
|
||||
cudnnSetActivationDescriptor(actDesc, CUDNN_ACTIVATION_IDENTITY,
|
||||
CUDNN_NOT_PROPAGATE_NAN, 0));
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
|
||||
// get output descriptor
|
||||
int outn, outc, outh, outw;
|
||||
checkCudnnError(cudnnGetConvolution2dForwardOutputDim(
|
||||
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
|
||||
cudnnTensorDescriptor_t outDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
|
||||
CUDNN_DATA_HALF, outn, outc,
|
||||
outh, outw));
|
||||
IT_ASSERT((vector{outn, outc, outh, outw}) ==
|
||||
op->getOutput()->getDims(),
|
||||
"cuDNN output shape mismatches with OP output shape");
|
||||
|
||||
return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
|
||||
convDesc, actDesc, outDesc);
|
||||
}
|
||||
|
||||
bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
|
||||
const CudaRuntimeObj *context) const {
|
||||
cudnnStatus_t stat;
|
||||
|
||||
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
|
||||
convDesc, actDesc, outDesc] =
|
||||
createCuDNNDescriptor(op, record);
|
||||
size_t wsSize = record->workspaceSize;
|
||||
CudaPtr wsData = context->getWorkspace(wsSize);
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
|
||||
stat = cudnnConvolutionForward(context->cudnnHandle(), &alpha, inDesc,
|
||||
inData, knDesc, knData, convDesc,
|
||||
ALGOS[record->algo], wsData, wsSize,
|
||||
&beta, outDesc, outData);
|
||||
if (stat != CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
||||
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
|
||||
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
|
||||
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
|
||||
return true;
|
||||
}
|
||||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
auto record = make_ref<ConvCuDnnPerfRecordObj>(); // with paramters in
|
||||
// default ctor
|
||||
compute(op, record, context);
|
||||
}
|
||||
|
||||
PerfRecord tune(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
ConvCuDnnPerfRecordObj ret;
|
||||
ret.time = std::numeric_limits<double>::max();
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
auto op = as<ConvObj>(_op);
|
||||
// Both modes have the same performance. Only run cross-correlation.
|
||||
for (int mode = 1; mode < 2; mode++) {
|
||||
// Try every possible algorithm of convolution
|
||||
for (int algo = 0; algo < N_ALGO; algo++) {
|
||||
auto recordRef = make_ref<ConvCuDnnPerfRecordObj>();
|
||||
auto &record = *recordRef;
|
||||
record.mode = mode;
|
||||
record.algo = algo;
|
||||
cudnnStatus_t stat;
|
||||
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
|
||||
convDesc, actDesc, outDesc] =
|
||||
createCuDNNDescriptor(op, recordRef);
|
||||
|
||||
// get workspace
|
||||
stat = cudnnGetConvolutionForwardWorkspaceSize(
|
||||
context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
|
||||
ALGOS[record.algo], &record.workspaceSize);
|
||||
if (stat != CUDNN_STATUS_SUCCESS) {
|
||||
continue;
|
||||
}
|
||||
if (record.workspaceSize > context->getWorkspaceSize()) {
|
||||
continue;
|
||||
}
|
||||
CudaPtr wsData = context->getWorkspace(record.workspaceSize);
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
|
||||
stat = cudnnConvolutionForward(
|
||||
context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
|
||||
knData, convDesc, ALGOS[record.algo], wsData,
|
||||
record.workspaceSize, &beta, outDesc, outData);
|
||||
if (stat != CUDNN_STATUS_SUCCESS) {
|
||||
continue;
|
||||
}
|
||||
record.time = timeit(
|
||||
[&]() {
|
||||
cudnnConvolutionForward(context->cudnnHandle(), &alpha,
|
||||
inDesc, inData, knDesc, knData,
|
||||
convDesc, ALGOS[record.algo],
|
||||
wsData, record.workspaceSize,
|
||||
&beta, outDesc, outData);
|
||||
},
|
||||
[&]() { context->sync(); });
|
||||
// printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time);
|
||||
|
||||
// Update the tune result
|
||||
if (ret.time > record.time) {
|
||||
ret = record;
|
||||
}
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
||||
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
|
||||
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
|
||||
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
|
||||
}
|
||||
}
|
||||
// printf("the best algo is %d, the best conv mode is %d\n", ret.algo,
|
||||
// ret.mode);
|
||||
IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
|
||||
"algorithm "
|
||||
"found");
|
||||
return make_ref<ConvCuDnnPerfRecordObj>(ret);
|
||||
}
|
||||
|
||||
void compute(const Operator &_op, const PerfRecord &_record,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConvObj>(_op);
|
||||
auto record = as<ConvCuDnnPerfRecordObj>(_record);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
bool success = cuDNNUnfused(op, record, context);
|
||||
IT_ASSERT(success);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float16, convCudnnFP16,
|
||||
"Conv_cuDNN_CUDA_Float16");
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,30 @@
|
|||
#include "utils/data_convert.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
uint16_t float_to_fp16(const float x) {
|
||||
Uf32 u;
|
||||
u.f32 = x;
|
||||
const uint32_t b = u.u32 + 0x00001000;
|
||||
const uint32_t e = (b & 0x7F800000) >> 23;
|
||||
const uint32_t m = b & 0x007FFFFF;
|
||||
return (b & 0x80000000) >> 16 |
|
||||
(e > 112) * ((((e - 112) << 10) & 0x7C00) | m >> 13) |
|
||||
((e < 113) & (e > 101)) *
|
||||
((((0x007FF000 + m) >> (125 - e)) + 1) >> 1) |
|
||||
(e > 143) * 0x7FFF;
|
||||
}
|
||||
|
||||
float fp16_to_float(const uint16_t x) {
|
||||
Uf32 u;
|
||||
const uint32_t e = (x & 0x7C00) >> 10;
|
||||
const uint32_t m = (x & 0x03FF) << 13;
|
||||
u.f32 = (float)m;
|
||||
const uint32_t v = u.u32 >> 23;
|
||||
const uint32_t r = (x & 0x8000) << 16 | (e != 0) * ((e + 112) << 23 | m) |
|
||||
((e == 0) & (m != 0)) *
|
||||
((v - 37) << 23 | ((m << (150 - v)) & 0x007FE000));
|
||||
u.u32 = r;
|
||||
return u.f32;
|
||||
}
|
||||
} // namespace infini
|
|
@ -7,9 +7,9 @@ namespace infini {
|
|||
TEST(Handler, matmul) {
|
||||
auto runtime = NativeCpuRuntimeObj::getInstance();
|
||||
auto handler = make_ref<GraphHandlerObj>(runtime);
|
||||
auto i = handler->tensor({1, 2, 3}, OnnxDType::UINT32);
|
||||
auto w = handler->tensor({1, 3, 4}, OnnxDType::UINT32);
|
||||
auto o = handler->tensor({1, 2, 4}, OnnxDType::UINT32);
|
||||
auto i = handler->tensor({1, 2, 3}, DataType::UInt32.getIndex());
|
||||
auto w = handler->tensor({1, 3, 4}, DataType::UInt32.getIndex());
|
||||
auto o = handler->tensor({1, 2, 4}, DataType::UInt32.getIndex());
|
||||
handler->matmul(i, w, o, false, false, nullptr, ActType::None);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/conv.h"
|
||||
#include <bitset>
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void testConvCudnnFP16(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
vector<float> ansVec) {
|
||||
|
||||
// Construct Runtime and graph for CPU and CUDA
|
||||
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({1, 3, 4, 4}, DataType::Float16);
|
||||
Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float16);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(generator);
|
||||
w0Cpu->setData(generator);
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv =
|
||||
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 2, 1, 1, 2);
|
||||
// allocate CUDA memory
|
||||
gCuda->dataMalloc();
|
||||
// Execute on CUDA
|
||||
cuda->run(gCuda);
|
||||
// copy output from CUDA to CPU
|
||||
auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(o0Cpu->equalData(ansVec));
|
||||
// print a tensor/operator/graph by print()
|
||||
gCuda->print();
|
||||
}
|
||||
|
||||
TEST(cuDNN_Conv_FP16, run) {
|
||||
testConvCudnnFP16(IncrementalGenerator(),
|
||||
vector<float>{48, 48, 72, 72, 48, 48, 72, 72});
|
||||
}
|
||||
|
||||
TEST(cuDNN_Conv_FP16, tune) {
|
||||
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({1, 3, 224, 224}, DataType::Float16);
|
||||
Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float16);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(IncrementalGenerator());
|
||||
w0Cpu->setData(IncrementalGenerator());
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv =
|
||||
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1);
|
||||
// allocate CUDA memory
|
||||
gCuda->dataMalloc();
|
||||
// Execute on CUDA
|
||||
bool tune = true;
|
||||
cuda->run(gCuda, tune);
|
||||
}
|
||||
} // namespace infini
|
Loading…
Reference in New Issue