Modify kernel registration & support fp16 (#205)

* - Remove dataType from the kernel registration.

* - support fp16 for conv

* - cpu kernel: adapt the new registration mechanism

* modified all kernel registrations

* add where fp16

* add layernorm fp16

* add split_concat fp16

* - element_wise support fp16

* feat: support transpose fp16

* feat: support sliceOp fp16

* - unary support fp16

* - feat: support reduceOp fp16

* feat: support matmulOp/expandOp fp16

* feat: support powOp int8

* add cuda cast & support half-precision for gather

* style: fix style

* feat: support int8 for gather

* style: fix style

* modified test_cuda_conv_transposed

* fix: fix dist code to support fp16

* fix(graph.cc): fix topo_sort

* fix: fix recv and send kernel registration

* feat: add field tensors for stub

* refactor(frontend): sort nodes first, then build the graph

Signed-off-by: YdrMaster <ydrml@hotmail.com>

* fix: provide a tensor-to-node mapping for intermediate results

* fix (slice): add guard for area out of range

* fix: fix matmul fp16

* fix: fix re-dataMalloc for weight tensor and use of naive allocator

* feat: add dataType filter for cuda kernel

* feat: bang kernel adapt the new registration mechanism

* fix: fix some error on mlu

* feat: intelcpu kernel adapt the new registration mechanism

* feat: modify kernel registration on kunlun

* fix intelcpu compiler bug

* feat: bang reshape support all dataType

* fix: fix bang reduce

* fix(all_reduce.cc): fix as reviewer suggested

* fix: fix style and restore unary test codes

---------

Signed-off-by: YdrMaster <ydrml@hotmail.com>
Co-authored-by: xgqdut2016 <kenan_gewei@163.com>
Co-authored-by: xgqdut2016 <140036308+xgqdut2016@users.noreply.github.com>
Co-authored-by: zhangyunze <z13785159769@163.com>
Co-authored-by: OdinaryWord <sx-hz@163.com>
Co-authored-by: YdrMaster <ydrml@hotmail.com>
Co-authored-by: panzezhong <panzezhong@qiyuanlab.com>
Chenjie Duan 2024-01-15 11:02:13 +08:00 committed by GitHub
parent 58993d4339
commit 51086d2b8d
157 changed files with 3627 additions and 2575 deletions


@ -137,7 +137,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
place[node.output[0]] = Shard(list(perm).index(plc.dim))
def shard_node(node: NodeProto):
if node.op_type in ["Relu", "Tanh", "Softmax"]:
if node.op_type in ["Relu", "Tanh", "Softmax", "Cast"]:
place[node.output[0]] = place[node.input[0]]
elif node.op_type in ["Where"]:
place[node.output[0]] = place[node.input[1]]
@ -177,7 +177,14 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
input in data for input in node.input
):
# FIXME(constroy): the last MatMul should not be sharded as TP.
if node.output[0] in output:
if (
node.output[0] in output
or (
index + 1 < len(model.graph.node)
and model.graph.node[index + 1].output[0]
)
in output
):
continue
groups = 1
# If the Gemm or Matmul is followed by a split, then the inputs are concatenated by groups


@ -30,7 +30,6 @@ class Kernel {
public:
Kernel() {}
virtual ~Kernel() {}
/**
* @param op The operator to be executed.
* @param record The parameters for kernel execution. If extra parameters
@ -130,15 +129,16 @@ class CpuKernelWithoutConfig : public Kernel {
} // namespace infini
#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, cnt) \
#define _REGISTER_KERNEL_1(device, opType, kernel, name, cnt) \
namespace infini { \
static const bool _CAT(_register_kernel_, cnt) = \
KernelRegistry::getInstance().registerKernel( \
KernelAttrs{device, opType, dataType}, new kernel(), name); \
KernelRegistry::getInstance().registerKernel(KernelAttrs{device, \
opType}, \
new kernel(), name); \
}
#define REGISTER_KERNEL(device, opType, dataType, kernel, name) \
_REGISTER_KERNEL_1(device, opType, dataType, kernel, name, __COUNTER__)
#define REGISTER_KERNEL(device, opType, kernel, name) \
_REGISTER_KERNEL_1(device, opType, kernel, name, __COUNTER__)
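For comparison, call sites drop the dtype argument; a before/after sketch based on the Relu BANG registration that appears later in this diff:

    // before: one registration per (device, op, dtype)
    REGISTER_KERNEL(Device::BANG, OpType::Relu, DataType::Float32, ReluCnnl,
                    "Relu_cnnl_BANG_Float32");
    // after: one registration per (device, op); the kernel handles dtypes itself
    REGISTER_KERNEL(Device::BANG, OpType::Relu, ReluCnnl, "Relu_cnnl_BANG");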
#define _REGISTER_CONSTRUCTOR_1(type, constructor, cnt) \
namespace infini { \


@ -4,7 +4,7 @@
#include "core/tensor.h"
namespace infini {
using KernelAttrs = std::tuple<Device, OpType::underlying_t, DataType>;
using KernelAttrs = std::tuple<Device, OpType::underlying_t>;
struct OpPerfKey {
HashType hash;
@ -90,6 +90,7 @@ class OperatorObj : public Object {
OpType getOpType() const { return type; }
// HACK: set correct data type
DataType getDType() const { return getInputs(0)->getDType(); }
DataType getOutDType() const { return getOutput()->getDType(); }
virtual int numInputs() const = 0;
virtual int numOutputs() const = 0;
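Since the registry key no longer carries a data type, kernels that implement only one precision guard it at run time instead; the pattern added throughout this commit looks like this (a sketch, not any specific file):

    void compute(const Operator &_op, const RuntimeObj *_context) const override {
        auto op = as<UnaryObj>(_op);
        // dtype is validated here rather than encoded in the registration key
        IT_ASSERT(op->getDType() == DataType::Float32);
        // ... kernel body ...
    }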


@ -44,8 +44,16 @@ class TensorObj : public TensorBaseObj {
bool isOutput() const { return tensorType == TensorType::output; }
bool isOthers() const { return tensorType == TensorType::others; }
void setWeight() { tensorType = TensorType::weight; }
void setInput() { tensorType = TensorType::input; }
void setOutput() { tensorType = TensorType::output; }
void setInput() {
if (!this->isWeight()) {
tensorType = TensorType::input;
}
}
void setOutput() {
if (!this->isWeight()) {
tensorType = TensorType::output;
}
}
string tensorTypeToString() const {
switch (tensorType) {
case TensorType::weight:


@ -1,13 +1,16 @@
#pragma once
namespace infini {
void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
void less_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3);
void div_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
int c2, int c3);
void add_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
int c2, int c3);
void pow_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
int c2, int c3);
void less_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
int c2, int c3);
}; // namespace infini
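A hedged sketch of a call site: the new leading dtypeIndex would come from the operator's runtime data type (DataType::getIndex() is used the same way in cublasDataTypeConvert later in this diff); the variable names here are illustrative:

    int dtypeIndex = op->getDType().getIndex();
    add_kernel(dtypeIndex, aData, bData, cData,
               a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);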


@ -3,7 +3,8 @@
#include "operators/unary.h"
#include "utils/small_array.h"
namespace infini {
void expandKernel(float *input, float *output, int nDims, int outputsize,
SmallArray inputShape, SmallArray outputShape);
void expandKernel(int dType, void *input, void *output, int nDims,
int outputsize, SmallArray inputShape,
SmallArray outputShape);
}; // namespace infini


@ -8,4 +8,10 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
void LaynormKernel(const float *input, const float *scale, const float eps,
int size, int scaleSize, const int dimsize, const int stride,
float *output);
void LaynormKernel(const half *input, const half *scale, const half eps,
int size, int scaleSize, const int dimsize, const int stride,
half *output, const half *bias, int biasSize);
void LaynormKernel(const half *input, const half *scale, const half eps,
int size, int scaleSize, const int dimsize, const int stride,
half *output);
}; // namespace infini


@ -3,4 +3,6 @@
namespace infini {
void softmax_kernel(int num_blocks, float *input, float *output, int size,
int dimsize, int stride);
}
void softmax_kernel(int num_blocks, half *input, half *output, int size,
int dimsize, int stride);
} // namespace infini


@ -8,8 +8,8 @@ const int DIM_MAX_SIZE = 8;
// The Concat operator composes element tensors into one big tensor, and
// the Split operator decomposes one big tensor into element
// tensors.
struct ElementTensorMetadata {
float *data[BATCH_SIZE];
template <typename T> struct ElementTensorMetadata {
T *data[BATCH_SIZE];
int dimBgNo[BATCH_SIZE]; // the dimension begin index of the element tensor
// in the composed tensor.
int dimSize[BATCH_SIZE]; // the dimension size of the element tensor.
@ -20,16 +20,17 @@ struct ElementTensorMetadata {
data[i], dimBgNo[i], dimSize[i], nElements[i]);
}
};
struct ComposedTensorMetadata {
template <typename T> struct ComposedTensorMetadata {
int dimSize[DIM_MAX_SIZE];
int stride[DIM_MAX_SIZE];
float *data;
T *data;
};
namespace infini {
void split_concat_kernel(const ElementTensorMetadata &eleMeta,
const ComposedTensorMetadata &compMeta, int dim,
void split_concat_kernel(const ElementTensorMetadata<float> &eleMeta,
const ComposedTensorMetadata<float> &compMeta, int dim,
int batchSize, int nDims, bool isSplit);
void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta,
const ComposedTensorMetadata<half> &compMeta, int dim,
int batchSize, int nDims, bool isSplit);
} // namespace infini


@ -5,7 +5,7 @@
namespace infini {
void transpose_kernel(float *input, float *output, int nDims, int size,
void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
SmallArray strides, SmallArray outputShape);
}; // namespace infini


@ -3,48 +3,21 @@
#include "operators/unary.h"
namespace infini {
void softmax_kernel(float *input, float *output, size_t num);
void relu_kernel(float *input, float *output, size_t num);
void sigmoid_kernel(float *input, float *output, size_t num);
void tanh_kernel(float *input, float *output, size_t num);
void abs_kernel(float *input, float *output, size_t num);
void sqrt_kernel(float *input, float *output, size_t num);
void neg_kernel(float *input, float *output, size_t num);
void gelu_kernel(float *input, float *output, size_t num);
void erf_kernel(float *input, float *output, size_t num);
void hard_sigmoid_kernel(float *input, float *output, size_t num);
void hard_swish_kernel(float *input, float *output, size_t num);
template <typename T> void softmax_kernel(T *input, T *output, size_t num);
template <typename T> void relu_kernel(T *input, T *output, size_t num);
template <typename T> void sigmoid_kernel(T *input, T *output, size_t num);
template <typename T> void tanh_kernel(T *input, T *output, size_t num);
template <typename T> void abs_kernel(T *input, T *output, size_t num);
template <typename T> void sqrt_kernel(T *input, T *output, size_t num);
template <typename T> void neg_kernel(T *input, T *output, size_t num);
template <typename T> void gelu_kernel(T *input, T *output, size_t num);
template <typename T> void erf_kernel(T *input, T *output, size_t num);
template <typename T> void hard_sigmoid_kernel(T *input, T *output, size_t num);
template <typename T> void hard_swish_kernel(T *input, T *output, size_t num);
void unary_kernel(const Operator &_op) {
auto op = as<UnaryObj>(_op);
float *const inputData = (op->getInputs(0)->getRawDataPtr<float *>());
float *const outputData = (op->getOutput()->getRawDataPtr<float *>());
template <typename INPUT, typename OUTPUT>
void cast_kernel(INPUT *input, OUTPUT *output, size_t num);
size_t num = op->getOutput()->size();
if (op->getOpType() == OpType::Softmax)
softmax_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Relu)
relu_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Sigmoid)
sigmoid_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::HardSigmoid)
hard_sigmoid_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::HardSwish)
hard_swish_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Tanh)
tanh_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Abs)
abs_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Sqrt)
sqrt_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Gelu)
gelu_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Neg)
neg_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Erf)
erf_kernel(inputData, outputData, num);
else
IT_TODO_HALT();
}
void unary_kernel(const Operator &_op);
}; // namespace infini
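With the launchers templated, unary_kernel moves out of the header; a minimal sketch of the dispatch it now performs, showing one op and assuming the same structure as the removed inline version:

    void unary_kernel(const Operator &_op) {
        auto op = as<UnaryObj>(_op);
        size_t num = op->getOutput()->size();
        if (op->getOpType() == OpType::Relu) {
            if (op->getDType() == DataType::Float32)
                relu_kernel<float>(op->getInputs(0)->getRawDataPtr<float *>(),
                                   op->getOutput()->getRawDataPtr<float *>(), num);
            else if (op->getDType() == DataType::Float16)
                relu_kernel<half>(op->getInputs(0)->getRawDataPtr<half *>(),
                                  op->getOutput()->getRawDataPtr<half *>(), num);
            else
                IT_TODO_HALT();
        } else {
            IT_TODO_HALT();
        }
    }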


@ -1,11 +1,29 @@
#pragma once
#include "core/tensor.h"
#include "cuda/cuda_common.h"
namespace infini {
void cudaPrintFloat(float *x, int len);
void cudaPrintTensor(const Tensor &tensor) {
cudaPrintFloat(tensor->getRawDataPtr<float *>(), tensor->size());
}
void cudaPrintTensor(const Tensor &tensor);
cudnnDataType_t cudnnDataTypeConvert(DataType dataType);
cudaDataType cublasDataTypeConvert(DataType);
template <int index> struct DT_CUDA {};
template <> struct DT_CUDA<0> { using t = bool; };
template <> struct DT_CUDA<1> { using t = float; };
template <> struct DT_CUDA<2> { using t = unsigned char; };
template <> struct DT_CUDA<3> { using t = char; };
template <> struct DT_CUDA<4> { using t = unsigned short; };
template <> struct DT_CUDA<5> { using t = short; };
template <> struct DT_CUDA<6> { using t = int; };
template <> struct DT_CUDA<7> { using t = long long; };
template <> struct DT_CUDA<9> { using t = bool; };
template <> struct DT_CUDA<10> { using t = half; };
template <> struct DT_CUDA<11> { using t = double; };
template <> struct DT_CUDA<12> { using t = unsigned int; };
template <> struct DT_CUDA<13> { using t = unsigned long long; };
template <> struct DT_CUDA<16> { using t = nv_bfloat16; };
} // namespace infini
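DT_CUDA turns a runtime dtype index into a static type, so a switch can instantiate templated kernels; an illustrative (not verbatim) dispatch:

    switch (dtypeIndex) {
    case 1: // float
        relu_kernel<DT_CUDA<1>::t>((float *)in, (float *)out, num);
        break;
    case 10: // half
        relu_kernel<DT_CUDA<10>::t>((half *)in, (half *)out, num);
        break;
    default:
        IT_TODO_HALT();
    }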


@ -3,10 +3,15 @@
#include "utils/small_array.h"
namespace infini {
void whereKernel(const float *inputX, const float *inputY,
const uint8_t *condition, float *output, int nDims,
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape, int xSize,
int ySize, int cSize);
void whereKernel(const half *inputX, const half *inputY,
const uint8_t *condition, half *output, int nDims,
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
SmallArray conditionShape, SmallArray outputShape, int xSize,
int ySize, int cSize);
}; // namespace infini


@ -53,7 +53,8 @@ inline void initGatherMetaData(GatherMetaData &metaData,
metaData.inStride[i] = in->getStride()[i];
}
}
void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num);
template <typename T>
void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num);
void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,
size_t num);


@ -91,6 +91,12 @@ template <int val> class ValGenerator : public DataGenerator {
fill<uint32_t>(data, size);
}
void fill(float *data, size_t size) override { fill<float>(data, size); }
void fill_fp16(uint16_t *data, size_t size) {
for (size_t i = 0; i < size; i++) {
float x = 1.0f * val;
data[i] = float_to_fp16(x);
}
}
};
typedef ValGenerator<1> OneGenerator;
typedef ValGenerator<0> ZeroGenerator;
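fill_fp16 stores raw half bits via a float_to_fp16 helper; a minimal truncating conversion sketch (an assumption — the project's actual utility may round to nearest and handle NaN):

    #include <cstdint>
    #include <cstring>

    inline uint16_t float_to_fp16(float x) {
        uint32_t bits;
        std::memcpy(&bits, &x, sizeof(bits));
        uint32_t sign = (bits >> 16) & 0x8000u;
        int32_t exp = int32_t((bits >> 23) & 0xFFu) - 127 + 15; // rebias 8-bit exp to 5-bit
        uint32_t mant = bits & 0x7FFFFFu;
        if (exp <= 0)
            return uint16_t(sign); // flush subnormals to signed zero
        if (exp >= 31)
            return uint16_t(sign | 0x7C00u); // overflow (and NaN) become infinity
        return uint16_t(sign | (uint32_t(exp) << 10) | (mant >> 13)); // truncate mantissa
    }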


@ -37,7 +37,7 @@ class OnnxStub:
It can be generated from an Onnx model object.
"""
def __init__(self, model: ModelProto, runtime):
def __init__(self, model: ModelProto, runtime, use_naive_allocator: bool = False):
# We use some user-defined operators for distributed inference
try:
# onnx simplifier performs inplace simplify
@ -51,13 +51,43 @@ class OnnxStub:
self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {}
self.tensors: Dict[str, backend.Tensor] = {}
self.tensor_node_map: Dict[str, str] = {}
self.initializer: Dict[int, TensorProto] = {}
self.use_naive_allocator: bool = use_naive_allocator
# try:
# model = infer_shapes(model)
# except:
# warnings.warn("infer_shapes failed.")
self.handler = backend.GraphHandler(runtime)
# Handle duplicate and anonymous operator names
names = {}
for node in model.graph.node:
if node.name == "":
node.name = "missing_name(" + node.op_type + ")"
if node.name in names:
names[node.name] += 1
node.name += "_" + str(names[node.name])
else:
names[node.name] = 0
# Topological sort
sorted_nodes = []
known_edge = set(t.name for t in model.graph.input)
known_edge.update(t.name for t in model.graph.initializer)
while len(sorted_nodes) < len(model.graph.node):
updated = False
for i, node in enumerate(model.graph.node):
if all(t in known_edge for t in node.input):
node.name = str(len(sorted_nodes)) + "_" + node.name
sorted_nodes.append(i)
known_edge.update(node.output)
for t_ in node.output:
self.tensor_node_map[t_] = node.name
updated = True
if not updated:
raise Exception("Graph has cycle")
tensors: Dict[str, backend.Tensor] = dict()
data: Dict[str, TensorProto] = dict()
@ -82,17 +112,8 @@ class OnnxStub:
)
tensors[output.name].set_output()
node_name = []
new_node_name = []
for node in model.graph.node:
node_name.append(node.name)
node_list = model.graph.node
while len(node_list) != 0:
for node in model.graph.node:
if node.name not in node_list:
continue
if _analyse_node(node, tensors):
continue
for node_idx in sorted_nodes:
node = model.graph.node[node_idx]
if node.op_type == "Conv":
attributes = _parse_attribute(
node,
@ -200,8 +221,7 @@ class OnnxStub:
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
)
(alpha, beta, transA, transB) = (
attributes[name]
for name in ["alpha", "beta", "transA", "transB"]
attributes[name] for name in ["alpha", "beta", "transA", "transB"]
)
# FIXME unsupported attributes: `alpha` `beta`
assert alpha == 1.0
@ -637,9 +657,7 @@ class OnnxStub:
tensors[node.output[0]] = self.handler.concat(
[tensors[name] for name in node.input],
tensors.get(node.output[0]),
next(
(attr.i for attr in node.attribute if attr.name == "axis")
),
next((attr.i for attr in node.attribute if attr.name == "axis")),
)
elif node.op_type == "AttentionKVCache":
tensors[node.output[0]] = self.handler.attentionKVCache(
@ -709,19 +727,11 @@ class OnnxStub:
tensors.get(node.output[0]),
# NOTE(constroy): `axes` is an attribute until opset version 13.
next(
(
attr.ints
for attr in node.attribute
if attr.name == "axes"
),
(attr.ints for attr in node.attribute if attr.name == "axes"),
None,
),
next(
(
attr.i
for attr in node.attribute
if attr.name == "keepdims"
),
(attr.i for attr in node.attribute if attr.name == "keepdims"),
1,
)
!= 0,
@ -749,9 +759,7 @@ class OnnxStub:
tensors[node.input[0]],
tensors.get(node.output[0]),
_parse_data(data[node.input[1]]),
_parse_data(data[node.input[3]])
if len(node.input) > 3
else None,
_parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
)
elif node.op_type == "Dropout":
for name, tensor in zip(
@ -759,9 +767,7 @@ class OnnxStub:
self.handler.dropout(
tensors[node.input[0]],
tensors.get(node.output[0]),
tensors.get(node.output[1])
if len(node.output) > 1
else None,
tensors.get(node.output[1]) if len(node.output) > 1 else None,
_parse_data(data[node.input[1]])[0]
if len(node.input) > 1
else 0.5,
@ -865,11 +871,7 @@ class OnnxStub:
0,
)
destination = next(
(
attr.i
for attr in node.attribute
if attr.name == "destination"
),
(attr.i for attr in node.attribute if attr.name == "destination"),
0,
)
@ -885,11 +887,7 @@ class OnnxStub:
0,
)
destination = next(
(
attr.i
for attr in node.attribute
if attr.name == "destination"
),
(attr.i for attr in node.attribute if attr.name == "destination"),
0,
)
@ -943,7 +941,8 @@ class OnnxStub:
node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1}
)
(alpha, beta, bias, size) = (
attributes[name] for name in ["alpha", "beta", "bias", "size"]
attributes[name]
for name in ["alpha", "beta", "bias", "size"]
)
tensors[node.output[0]] = self.handler.lrn(
tensors[node.input[0]],
@ -955,14 +954,11 @@ class OnnxStub:
)
else:
raise Exception('Unsupported operator "{}"'.format(node.op_type))
new_node_name.append(node.name)
# update the node_list
node_list = list(set(node_name) - set(new_node_name))
################################
# Allocate memory space for data
################################
self.handler.data_malloc()
self.handler.data_malloc(self.use_naive_allocator)
#################################
# Copy in data to tensor objects
@ -993,6 +989,9 @@ class OnnxStub:
# assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
obj.copyin_numpy(to_array(tensor))
for name, obj in tensors.items():
self.tensors[name] = obj
for output in model.graph.output:
self.outputs[output.name] = tensors[output.name]
@ -1335,7 +1334,7 @@ class OnnxStub:
return ctx.build(name)
def init(self) -> None:
self.handler.data_malloc()
self.handler.data_malloc(self.use_naive_allocator)
def optimize(self) -> None:
self.handler.optimize()
@ -1351,7 +1350,7 @@ class OnnxStub:
oldTensor = self.inputs[oldInput]
self.handler.change_shape(newInput, oldTensor.fuid())
self.handler.shape_infer()
self.handler.data_malloc()
self.handler.data_malloc(self.use_naive_allocator)
def getShape(self, name: str) -> List[int]:
if name in self.inputs:
@ -1414,10 +1413,3 @@ def _parse_data_fp16(tensor: TensorProto):
def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
def _analyse_node(node: NodeProto, tensors) -> bool:
for i in node.input:
if i not in tensors:
return True
return False


@ -16,8 +16,7 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs =
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);


@ -87,48 +87,33 @@ string GraphObj::toString() const {
}
bool GraphObj::topo_sort() {
if (this->sorted)
if (this->sorted) {
return true;
// std::unordered_set<Tensor> inputs;
std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
}
std::vector<Operator> sorted;
while (!waiting.empty()) {
std::unordered_set<OperatorObj *> flags;
sorted.reserve(ops.size());
flags.reserve(ops.size());
while (sorted.size() < ops.size()) {
// Whether any node was moved into sorted during this pass.
auto modified = false;
// Find head nodes.
for (auto it = waiting.begin(); it != waiting.end();) {
const auto &this_inputs = (*it)->getInputs();
// If none of the input tensors is in waiting list,
// this node is a head node.
const auto is_head = std::all_of(
this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
auto src = input->getSource();
return src // If the source node is in the waiting
// list, means that this node is not the
// head node.
? waiting.find(src) == waiting.end()
// This tensor has no source node,
// it must be a input tensor.
: (/*inputs.insert(input),*/ true);
});
// Moves head node to sorted.
if (is_head) {
for (auto const &op : ops) {
if (auto const &inputs = op->getInputs();
flags.find(op.get()) == flags.end() &&
std::all_of(inputs.begin(), inputs.end(),
[&flags](auto const &input) {
auto ptr = input->getSource().get();
return !ptr || flags.find(ptr) != flags.end();
})) {
modified = true;
sorted.emplace_back(std::move(*it));
it = waiting.erase(it);
} else {
++it;
sorted.emplace_back(op);
flags.insert(op.get());
}
}
// If no node can be moved in a whole pass, the graph
// contains a cycle and sorting fails.
if (!modified) {
return false;
}
}
// Done.
this->ops = std::move(sorted);
return this->sorted = true;
}
@ -182,8 +167,11 @@ void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) {
// note: behavior may not match running in non-naive mode, and it may
// not reproduce the bug
for (auto &tensor : tensors) {
if (!tensor->isWeight() ||
(tensor->isWeight() && !weightAllocated)) {
tensor->dataMalloc();
}
}
return;
}
if (memPoolSize > 0) {
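On the caller side, the naive path is now opt-in through the same entry point; a usage sketch consistent with the signature in the hunk header above:

    // weights keep their buffers; other tensors are (re)allocated naively
    graph->dataMalloc(/*useNaiveAllocator=*/true, /*memPoolSize=*/0);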


@ -17,8 +17,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
auto kernelAttrs =
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);
@ -66,8 +65,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
auto kernelAttrs =
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);


@ -25,8 +25,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
auto &perfEngine = PerfEngine::getInstance();
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs =
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);
@ -48,8 +47,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(),
DataType::Float32};
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
auto perfData = perfEngine.getPerfData(perfKey);


@ -1,4 +1,6 @@
#include "core/data_type.h"
#include "cuda/cuda_common.h"
#include "cuda/cuda_utility.h"
#include <cstdio>
__global__ void cudaPrintFloatImpl(float *x, int len) {
@ -18,4 +20,55 @@ void cudaPrintFloat(float *x, int len) {
cudaDeviceSynchronize();
}
void cudaPrintTensor(const Tensor &tensor) {
cudaPrintFloat(tensor->getRawDataPtr<float *>(), tensor->size());
}
cudnnDataType_t cudnnDataTypeConvert(DataType dataType) {
if (dataType == DataType::Float32) {
return CUDNN_DATA_FLOAT;
}
if (dataType == DataType::Double) {
return CUDNN_DATA_DOUBLE;
}
if (dataType == DataType::Float16) {
return CUDNN_DATA_HALF;
}
if (dataType == DataType::Int8) {
return CUDNN_DATA_INT8;
}
if (dataType == DataType::Int32) {
return CUDNN_DATA_INT32;
}
if (dataType == DataType::UInt8) {
return CUDNN_DATA_UINT8;
}
if (dataType == DataType::BFloat16) {
return CUDNN_DATA_BFLOAT16;
}
if (dataType == DataType::Int64) {
return CUDNN_DATA_INT64;
}
if (dataType == DataType::Bool) {
return CUDNN_DATA_BOOLEAN;
}
IT_ASSERT(false, "Unsupported data type");
}
cudaDataType cublasDataTypeConvert(DataType dataType) {
switch (dataType.getIndex()) {
case 1:
return CUDA_R_32F;
// case 3:
// return CUDA_R_8I;
case 10:
return CUDA_R_16F;
case 11:
return CUDA_R_64F;
// case 16:
// return CUDA_R_16BF;
default:
IT_ASSERT(false, "MatMul Unsupported data type");
}
}
} // namespace infini
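A hedged example of where cudnnDataTypeConvert slots in (cudnnSetTensor4dDescriptor is the standard cuDNN call; the surrounding names are illustrative):

    cudnnTensorDescriptor_t desc;
    checkCudnnError(cudnnCreateTensorDescriptor(&desc));
    checkCudnnError(cudnnSetTensor4dDescriptor(
        desc, CUDNN_TENSOR_NCHW, cudnnDataTypeConvert(op->getDType()),
        n, c, h, w));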


@ -11,6 +11,7 @@ class UnaryCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -50,6 +51,7 @@ class RoundCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -80,6 +82,7 @@ class PReluCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<PReluObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -119,6 +122,7 @@ class SoftmaxCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<SoftmaxObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -215,15 +219,12 @@ class SigmoidCnnl : public UnaryCnnl {
float getCoef() const override { return 0.0; }
};
REGISTER_KERNEL(Device::BANG, OpType::Relu, DataType::Float32, ReluCnnl,
"Relu_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::PRelu, DataType::Float32, PReluCnnl,
"PRelu_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, DataType::Float32, SigmoidCnnl,
"Sigmoid_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Round, DataType::Float32, RoundCnnl,
"Round_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Softmax, DataType::Float32, SoftmaxCnnl,
"Softmax_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Relu, ReluCnnl, "Relu_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::PRelu, PReluCnnl, "PRelu_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
"Sigmoid_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Round, RoundCnnl, "Round_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Softmax, SoftmaxCnnl,
"Softmax_cnnl_BANG");
}; // namespace infini


@ -10,6 +10,7 @@ class ActivationBackwardCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ActivationBackwardObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const yData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -81,11 +82,11 @@ class TanhBackwardCnnl : public ActivationBackwardCnnl {
float getCoef() const override { return 0.0; }
};
REGISTER_KERNEL(Device::BANG, OpType::ReluBackward, DataType::Float32,
ReluBackwardCnnl, "ReluBackward_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::SigmoidBackward, DataType::Float32,
SigmoidBackwardCnnl, "SigmoidBackward_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::TanhBackward, DataType::Float32,
TanhBackwardCnnl, "TanhBackward_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::ReluBackward, ReluBackwardCnnl,
"ReluBackward_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::SigmoidBackward, SigmoidBackwardCnnl,
"SigmoidBackward_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::TanhBackward, TanhBackwardCnnl,
"TanhBackward_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<BatchNormObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const input = (op->getInputs(0)->getRawDataPtr<void *>());
@ -101,7 +102,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::BatchNormalization, DataType::Float32,
BatchNormCnnl, "BatchNorm_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::BatchNormalization, BatchNormCnnl,
"BatchNorm_cnnl_BANG");
}; // namespace infini


@ -212,7 +212,6 @@ class CastCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Cast, DataType::Float32, CastCnnl,
"Cast_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Cast, CastCnnl, "Cast_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class CeilCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -35,7 +36,6 @@ class CeilCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Ceil, DataType::Float32, CeilCnnl,
"Ceil_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Ceil, CeilCnnl, "Ceil_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class ClipCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ClipObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -30,7 +31,6 @@ class ClipCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Clip, DataType::Float32, ClipCnnl,
"Clip_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Clip, ClipCnnl, "Clip_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class ConcatCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ConcatObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
int num = op->numInputs();
int axis = op->getDim();
@ -50,6 +51,5 @@ class ConcatCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Concat, DataType::Float32, ConcatCnnl,
"Concat_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Concat, ConcatCnnl, "Concat_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class ConvCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ConvObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@ -151,6 +152,5 @@ class ConvCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Conv, DataType::Float32, ConvCnnl,
"Conv_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Conv, ConvCnnl, "Conv_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ConvBaseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@ -76,6 +77,6 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::ConvTranspose, DataType::Float32,
ConvTransCnnl, "ConvTrans_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::ConvTranspose, ConvTransCnnl,
"ConvTrans_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ConvBackwardFilterObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
@ -154,6 +155,6 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::ConvBackwardFilter, DataType::Float32,
ConvBackwardFilterCnnl, "ConvBackwardFilter_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::ConvBackwardFilter,
ConvBackwardFilterCnnl, "ConvBackwardFilter_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class DetCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<DetObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -42,6 +43,5 @@ class DetCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Det, DataType::Float32, DetCnnl,
"Det_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Det, DetCnnl, "Det_cnnl_BANG");
}; // namespace infini


@ -11,6 +11,7 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -74,6 +75,7 @@ class LogicOpCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -127,6 +129,7 @@ class BitComputeCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -179,6 +182,7 @@ class DivCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -231,6 +235,7 @@ class MaximumCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -282,6 +287,7 @@ class MinimumCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -333,6 +339,7 @@ class MSELossCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<MSELossObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -389,6 +396,7 @@ class PowerCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -442,6 +450,7 @@ class FloorDivCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -494,6 +503,7 @@ class FloorModCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -546,6 +556,7 @@ class SquaredDifferenceCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -658,62 +669,48 @@ class BitNotCnnl : public BitComputeCnnl {
// CNNL_BLEFT_SHIFT_OP_V2; }
// };
REGISTER_KERNEL(Device::BANG, OpType::Add, DataType::Float32, AddCnnl,
"Add_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Sub, DataType::Float32, SubCnnl,
"Sub_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Mul, DataType::Float32, MulCnnl,
"Mul_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Add, AddCnnl, "Add_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Sub, SubCnnl, "Sub_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Mul, MulCnnl, "Mul_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Div, DataType::Float32, DivCnnl,
"Div_cnnl_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Max, DataType::Float32, MaximumCnnl,
"Maximum_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Min, DataType::Float32, MinimumCnnl,
"Minimum_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::MSELoss, DataType::Float32, MSELossCnnl,
"MSELoss_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Pow, DataType::Float32, PowerCnnl,
"Power_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::FloorDiv, DataType::Float32, FloorDivCnnl,
"FloorDiv_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::FloorMod, DataType::Float32, FloorModCnnl,
"FloorMod_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::SquaredDifference, DataType::Float32,
SquaredDifferenceCnnl, "SquaredDifference_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Equal, DataType::Float32, EqualCnnl,
"Equal_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Greater, DataType::Float32,
GreaterThanCnnl, "GreaterThan_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::GreaterOrEqual, DataType::Float32,
GreaterEqualCnnl, "GreaterEqual_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Less, DataType::Float32, LessThanCnnl,
"LessThan_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::LessOrEqual, DataType::Float32,
LessEqualCnnl, "LessEqual_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::And, DataType::Float32, AndCnnl,
"And_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Or, DataType::Float32, OrCnnl,
"Or_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Xor, DataType::Float32, XorCnnl,
"Xor_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Not, DataType::Float32, NotCnnl,
"Not_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::BitwiseAnd, DataType::Float32, BitAndCnnl,
"BitAnd_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::BitwiseOr, DataType::Float32, BitOrCnnl,
"BitOr_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::BitwiseXor, DataType::Float32, BitXorCnnl,
"BitXor_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::BitwiseNot, DataType::Float32, BitNotCnnl,
"BitNot_cnnl_BANG_Float32");
// REGISTER_KERNEL(Device::BANG, OpType::BitLeftShift, DataType::Float32,
REGISTER_KERNEL(Device::BANG, OpType::Div, DivCnnl, "Div_cnnl");
REGISTER_KERNEL(Device::BANG, OpType::Max, MaximumCnnl, "Maximum_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Min, MinimumCnnl, "Minimum_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::MSELoss, MSELossCnnl,
"MSELoss_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Pow, PowerCnnl, "Power_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::FloorDiv, FloorDivCnnl,
"FloorDiv_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::FloorMod, FloorModCnnl,
"FloorMod_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::SquaredDifference, SquaredDifferenceCnnl,
"SquaredDifference_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Equal, EqualCnnl, "Equal_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Greater, GreaterThanCnnl,
"GreaterThan_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::GreaterOrEqual, GreaterEqualCnnl,
"GreaterEqual_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Less, LessThanCnnl, "LessThan_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::LessOrEqual, LessEqualCnnl,
"LessEqual_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::And, AndCnnl, "And_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Or, OrCnnl, "Or_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Xor, XorCnnl, "Xor_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Not, NotCnnl, "Not_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::BitwiseAnd, BitAndCnnl,
"BitAnd_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::BitwiseOr, BitOrCnnl, "BitOr_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::BitwiseXor, BitXorCnnl,
"BitXor_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::BitwiseNot, BitNotCnnl,
"BitNot_cnnl_BANG");
// REGISTER_KERNEL(Device::BANG, OpType::BitLeftShift,
// BitLeftShiftCnnl,
// "BitLeftShift_cnnl_BANG_Float32");
// REGISTER_KERNEL(Device::BANG, OpType::BitRightShift, DataType::Float32,
// "BitLeftShift_cnnl_BANG");
// REGISTER_KERNEL(Device::BANG, OpType::BitRightShift,
// BitRightShiftCnnl,
// "BitRightShift_cnnl_BANG_Float32");
// REGISTER_KERNEL(Device::BANG, OpType::Pow, DataType::Float32,
// "BitRightShift_cnnl_BANG");
// REGISTER_KERNEL(Device::BANG, OpType::Pow,
// ElementWiseBang,
// "Pow_Bang_Float32");
// "Pow_Bang");
}; // namespace infini


@ -7,6 +7,7 @@ class ErfCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -36,7 +37,6 @@ class ErfCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Erf, DataType::Float32, ErfCnnl,
"Erf_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Erf, ErfCnnl, "Erf_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class ExpCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -36,7 +37,6 @@ class ExpCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Exp, DataType::Float32, ExpCnnl,
"Exp_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Exp, ExpCnnl, "Exp_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class FillCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<FillObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
@ -29,7 +30,6 @@ class FillCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Fill, DataType::Float32, FillCnnl,
"Fill_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Fill, FillCnnl, "Fill_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class FloorCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -35,7 +36,7 @@ class FloorCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Floor, DataType::Float32, FloorCnnl,
REGISTER_KERNEL(Device::BANG, OpType::Floor, FloorCnnl,
"Floor_cnnl_BANG_Float32");
}; // namespace infini


@ -7,6 +7,7 @@ class GatherCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<GatherObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -49,7 +50,6 @@ class GatherCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Gather, DataType::Float32, GatherCnnl,
"Gather_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Gather, GatherCnnl, "Gather_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class HardtanhCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<HardtanhObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -30,7 +31,7 @@ class HardtanhCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Hardtanh, DataType::Float32, HardtanhCnnl,
"Hardtanh_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Hardtanh, HardtanhCnnl,
"Hardtanh_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class L2LossCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<L2LossObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -28,7 +29,6 @@ class L2LossCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::L2Loss, DataType::Float32, L2LossCnnl,
"L2Loss_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::L2Loss, L2LossCnnl, "L2Loss_cnnl_BANG");
}; // namespace infini


@ -8,6 +8,7 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<LayerNormObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -58,7 +59,7 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::LayerNormalization, DataType::Float32,
LayerNormCnnl, "LayerNorm_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::LayerNormalization, LayerNormCnnl,
"LayerNorm_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class LogCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<LogObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -51,7 +52,6 @@ class LogCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Log, DataType::Float32, LogCnnl,
"Log_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Log, LogCnnl, "Log_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class LRNCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<LRNObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -56,7 +57,6 @@ class LRNCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::LRN, DataType::Float32, LRNCnnl,
"LRN_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::LRN, LRNCnnl, "LRN_cnnl_BANG");
}; // namespace infini


@ -8,6 +8,7 @@ class MatmulCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<MatmulObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
auto input_num = op->numInputs();
@ -107,6 +108,5 @@ class MatmulCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::MatMul, DataType::Float32, MatmulCnnl,
"Matmul_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::MatMul, MatmulCnnl, "Matmul_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class NegTensorCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -35,7 +36,6 @@ class NegTensorCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Neg, DataType::Float32, NegTensorCnnl,
"Neg_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Neg, NegTensorCnnl, "Neg_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class PadCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<PadObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -57,7 +58,6 @@ class PadCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Pad, DataType::Float32, PadCnnl,
"Pad_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Pad, PadCnnl, "Pad_cnnl_BANG");
}; // namespace infini


@ -8,6 +8,7 @@ class PoolingCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<PoolingObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
@ -68,8 +69,8 @@ class avgPoolCnnl : public PoolingCnnl {
}
};
REGISTER_KERNEL(Device::BANG, OpType::MaxPool, DataType::Float32, maxPoolCnnl,
"MaxPool_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::AveragePool, DataType::Float32,
avgPoolCnnl, "AvgPool_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::MaxPool, maxPoolCnnl,
"MaxPool_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::AveragePool, avgPoolCnnl,
"AvgPool_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class ReciprocalCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -35,7 +36,7 @@ class ReciprocalCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Reciprocal, DataType::Float32,
ReciprocalCnnl, "Reciprocal_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Reciprocal, ReciprocalCnnl,
"Reciprocal_cnnl_BANG");
}; // namespace infini


@ -9,6 +9,7 @@ class ReduceCnnlBase : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ReduceBaseObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
@ -73,9 +74,9 @@ class ReduceSumCnnl : public ReduceCnnlBase {
cnnlReduceOp_t getReduceOp() const override { return CNNL_REDUCE_ADD; }
};
REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, DataType::Float32,
ReduceMeanCnnl, "ReduceMean_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::ReduceSum, DataType::Float32,
ReduceSumCnnl, "ReduceSum_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, ReduceMeanCnnl,
"ReduceMean_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::ReduceSum, ReduceSumCnnl,
"ReduceSum_cnnl_BANG");
}; // namespace infini


@ -13,9 +13,9 @@ class CopyBang : public BangKernelWithoutConfig {
auto dim = op->getInputs(0)->getDims();
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, dim.size(),
dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT8,
dim.size() * op->getDType().getSize(), dim.data()));
cnnlStatus_t stat =
cnnlCopy(context->cnnlHandle(), aDesc, inData, aDesc, outData);
if (stat != CNNL_STATUS_SUCCESS)
@ -25,13 +25,8 @@ class CopyBang : public BangKernelWithoutConfig {
}
};
// reshape/flatten/identity all act as copying from input to output.
REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Float32, CopyBang,
"Reshape_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Int64, CopyBang,
"Reshape_BANG_Int64");
REGISTER_KERNEL(Device::BANG, OpType::Flatten, DataType::Float32, CopyBang,
"Flatten_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Identity, DataType::Float32, CopyBang,
"Identity_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Reshape, CopyBang, "Reshape_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Flatten, CopyBang, "Flatten_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Identity, CopyBang, "Identity_BANG");
} // namespace infini
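
The INT8 descriptor above is a byte-view trick: by scaling the count with
op->getDType().getSize(), one copy kernel can serve every data type. A small
sketch of the underlying idea, assuming only that shape and element size are
known at run time (illustrative, not project code):

#include <cstddef>
#include <vector>

// Bytes occupied by a tensor: element count times element size. Describing
// the buffer as that many INT8 elements makes a raw copy dtype-agnostic.
size_t byteLength(const std::vector<int> &dims, size_t elemSize) {
    size_t count = 1;
    for (int d : dims)
        count *= static_cast<size_t>(d);
    return count * elemSize;
}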


@ -7,6 +7,7 @@ class RsqrtCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -36,7 +37,6 @@ class RsqrtCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Rsqrt, DataType::Float32, RsqrtCnnl,
"Rsqrt_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Rsqrt, RsqrtCnnl, "Rsqrt_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class SplitCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<SplitObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
int num = op->numOutputs();
int axis = op->getDim();
@ -49,6 +50,5 @@ class SplitCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Split, DataType::Float32, SplitCnnl,
"Split_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Split, SplitCnnl, "Split_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class SqrtCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -36,7 +37,6 @@ class SqrtCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Sqrt, DataType::Float32, SqrtCnnl,
"Sqrt_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Sqrt, SqrtCnnl, "Sqrt_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class TransposeCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<TransposeObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -52,6 +53,7 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<DepthToSpaceObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -101,9 +103,9 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Transpose, DataType::Float32,
TransposeCnnl, "Transpose_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Transpose, TransposeCnnl,
"Transpose_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::DepthToSpace, DataType::Float32,
DepthToSpaceCnnl, "DepthToSpace_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::DepthToSpace, DepthToSpaceCnnl,
"DepthToSpace_cnnl_BANG");
}; // namespace infini


@ -9,6 +9,7 @@ class TrigonCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -150,29 +151,17 @@ class ATanHCnnl : public TrigonCnnl {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Sin, DataType::Float32, SinCnnl,
"Sin_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Cos, DataType::Float32, CosCnnl,
"Cos_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Tan, DataType::Float32, TanCnnl,
"Tan_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Asin, DataType::Float32, ASinCnnl,
"ASin_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Acos, DataType::Float32, ACosCnnl,
"ACos_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Atan, DataType::Float32, ATanCnnl,
"ATan_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Sinh, DataType::Float32, SinHCnnl,
"SinH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Cosh, DataType::Float32, CosHCnnl,
"CosH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Tanh, DataType::Float32, TanHCnnl,
"TanH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Asinh, DataType::Float32, ASinHCnnl,
"ASinH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Acosh, DataType::Float32, ACosHCnnl,
"ACosH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Atanh, DataType::Float32, ATanHCnnl,
"ATanH_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Sin, SinCnnl, "Sin_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Cos, CosCnnl, "Cos_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Tan, TanCnnl, "Tan_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Asin, ASinCnnl, "ASin_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Acos, ACosCnnl, "ACos_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Atan, ATanCnnl, "ATan_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Sinh, SinHCnnl, "SinH_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Cosh, CosHCnnl, "CosH_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Tanh, TanHCnnl, "TanH_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Asinh, ASinHCnnl, "ASinH_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Acosh, ACosHCnnl, "ACosH_cnnl_BANG");
REGISTER_KERNEL(Device::BANG, OpType::Atanh, ATanHCnnl, "ATanH_cnnl_BANG");
}; // namespace infini


@ -7,6 +7,7 @@ class WhereCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<WhereObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
@ -67,7 +68,6 @@ class WhereCnnl : public BangKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::BANG, OpType::Where, DataType::Float32, WhereCnnl,
"Where_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::Where, WhereCnnl, "Where_cnnl_BANG");
}; // namespace infini


@ -3,9 +3,9 @@
namespace infini {
template <typename T> class NaiveConcat : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NaiveConcat : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<ConcatObj>(_op);
auto inputs = op->getInputs(), outputs = op->getOutputs();
auto dim = op->getDim();
@ -41,11 +41,25 @@ template <typename T> class NaiveConcat : public CpuKernelWithoutConfig {
}
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
REGISTER_KERNEL(Device::CPU, OpType::Concat, DataType::UInt32,
NaiveConcat<uint32_t>, "ConcatNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Concat, DataType::Float32,
NaiveConcat<float>, "ConcatNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Concat, NaiveConcat, "ConcatNaive_CPU");
} // namespace infini
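
The CASE macro relies on a compile-time index-to-type table: DT<N>::t
presumably names the C++ type whose DataType index is N, so a switch over
the runtime index instantiates the matching doCompute<T>. A self-contained
sketch of that mechanism, with the table reduced to the two cases used here
(an assumption, not the project's actual DT definition):

#include <cstdint>
#include <cstdio>

template <int N> struct DT;                        // index -> type table
template <> struct DT<1>  { using t = float;    }; // DataType::Float32
template <> struct DT<12> { using t = uint32_t; }; // DataType::UInt32

template <typename T> void doCompute() { std::printf("%zu bytes\n", sizeof(T)); }

void dispatch(int dataTypeIdx) {
    switch (dataTypeIdx) {
    case 1:  doCompute<DT<1>::t>();  break;
    case 12: doCompute<DT<12>::t>(); break;
    default: break; // unsupported dtype; the real kernels halt here
    }
}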


@ -3,9 +3,9 @@
namespace infini {
template <typename T> class NaiveConv : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NaiveConv : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<ConvObj>(_op);
T *iptr = op->getInputs(0)->getRawDataPtr<T *>();
T *wptr = op->getInputs(1)->getRawDataPtr<T *>();
@ -50,11 +50,25 @@ template <typename T> class NaiveConv : public CpuKernelWithoutConfig {
}
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::UInt32,
NaiveConv<uint32_t>, "ConvNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::Float32, NaiveConv<float>,
"ConvNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Conv, NaiveConv, "ConvNaive_CPU");
} // namespace infini


@ -3,10 +3,45 @@
#include "utils/operator_utils.h"
namespace infini {
template <typename T> class NativeElementWise : public CpuKernelWithoutConfig {
virtual T doCompute(T val0, T val1) const = 0;
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NativeElementWise : public CpuKernelWithoutConfig {
template <typename T> static T addCompute(T val0, T val1) {
return val0 + val1;
}
template <typename T> static T subCompute(T val0, T val1) {
return val0 - val1;
}
template <typename T> static T mulCompute(T val0, T val1) {
return val0 * val1;
}
template <typename T> static T divCompute(T val0, T val1) {
return (T)(val0 / val1);
}
template <typename T> static T equalCompute(T val0, T val1) {
return (T)(val0 == val1);
}
template <typename T> static T greaterOrEqualCompute(T val0, T val1) {
return (T)(val0 >= val1);
}
template <typename T> static T greaterCompute(T val0, T val1) {
return (T)(val0 > val1);
}
template <typename T> static T lessOrEqualCompute(T val0, T val1) {
return (T)(val0 <= val1);
}
template <typename T> static T lessCompute(T val0, T val1) {
return (T)(val0 < val1);
}
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<ElementWiseObj>(_op);
T *inptr0 = op->getInputs(0)->getRawDataPtr<T *>();
T *inptr1 = op->getInputs(1)->getRawDataPtr<T *>();
@ -35,77 +70,77 @@ template <typename T> class NativeElementWise : public CpuKernelWithoutConfig {
Shape strideB = getStride(b);
auto n = op->getOutput()->size();
T (*_doCompute)(T val0, T val1);
switch (op->getOpType().underlying()) {
case OpType::Add:
_doCompute = addCompute<T>;
break;
case OpType::Sub:
_doCompute = subCompute<T>;
break;
case OpType::Mul:
_doCompute = mulCompute<T>;
break;
case OpType::Div:
_doCompute = divCompute<T>;
break;
case OpType::Equal:
_doCompute = equalCompute<T>;
break;
case OpType::GreaterOrEqual:
_doCompute = greaterOrEqualCompute<T>;
break;
case OpType::Greater:
_doCompute = greaterCompute<T>;
break;
case OpType::LessOrEqual:
_doCompute = lessOrEqualCompute<T>;
break;
case OpType::Less:
_doCompute = lessCompute<T>;
break;
default:
IT_TODO_HALT();
}
for (size_t i = 0; i < n; ++i) {
auto shapeIndexC = locate_index(i, shapeC);
auto indexA = delocate_index(shapeIndexC, a, strideA);
auto indexB = delocate_index(shapeIndexC, b, strideB);
outptr[i] = doCompute(inptr0[indexA], inptr1[indexB]);
outptr[i] = _doCompute(inptr0[indexA], inptr1[indexB]);
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
template <typename T> class NaiveAdd : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return val0 + val1; }
};
template <typename T> class NaiveSub : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return val0 - val1; }
};
template <typename T> class NaiveMul : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return val0 * val1; }
};
template <typename T> class NaiveDiv : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 / val1); }
};
template <typename T> class NaiveEqual : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 == val1); }
};
template <typename T> class NaiveGreaterEqual : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 >= val1); }
};
template <typename T> class NaiveGreaterThan : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 > val1); }
};
template <typename T> class NaiveLessEqual : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 <= val1); }
};
template <typename T> class NaiveLessThan : public NativeElementWise<T> {
T doCompute(T val0, T val1) const override { return (T)(val0 < val1); }
};
REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::UInt32, NaiveAdd<uint32_t>,
"addNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::Float32, NaiveAdd<float>,
"addNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::UInt32, NaiveSub<uint32_t>,
"subNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::Float32, NaiveSub<float>,
"subNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::UInt32, NaiveMul<uint32_t>,
"mulNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::Float32, NaiveMul<float>,
"mulNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::UInt32, NaiveDiv<uint32_t>,
"divNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::Float32, NaiveDiv<float>,
"divNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Equal, DataType::UInt32,
NaiveEqual<uint32_t>, "equalNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Equal, DataType::Float32,
NaiveEqual<float>, "equalNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, DataType::UInt32,
NaiveGreaterEqual<uint32_t>, "greaterEqualNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, DataType::Float32,
NaiveGreaterEqual<float>, "greaterEqualNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Greater, DataType::UInt32,
NaiveGreaterThan<uint32_t>, "greaterThanNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Greater, DataType::Float32,
NaiveGreaterThan<float>, "greaterThanNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, DataType::UInt32,
NaiveLessEqual<uint32_t>, "lessEqualNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, DataType::Float32,
NaiveLessEqual<float>, "lessEqualNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Less, DataType::UInt32,
NaiveLessThan<uint32_t>, "lessEqualNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Less, DataType::Float32,
NaiveLessThan<float>, "lessEqualNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Add, NativeElementWise, "addNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Sub, NativeElementWise, "subNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Mul, NativeElementWise, "mulNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Div, NativeElementWise, "divNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Equal, NativeElementWise,
"equalNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, NativeElementWise,
"greaterEqualNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Greater, NativeElementWise,
"greaterThanNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, NativeElementWise,
"lessEqualNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Less, NativeElementWise,
"lessEqualNaive_CPU");
}; // namespace infini
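
Folding the nine per-op classes into one kernel replaces a virtual call per
element with a single function-pointer selection per invocation. A minimal
sketch of that shape (enum and names illustrative):

#include <cstddef>

enum class Op { Add, Sub, Mul }; // stand-in for the project's OpType

template <typename T> static T addC(T a, T b) { return a + b; }
template <typename T> static T subC(T a, T b) { return a - b; }
template <typename T> static T mulC(T a, T b) { return a * b; }

// Select the scalar kernel once, then reuse it for the whole tensor.
template <typename T>
void elementWise(Op op, const T *a, const T *b, T *out, size_t n) {
    T (*f)(T, T) = nullptr;
    switch (op) {
    case Op::Add: f = addC<T>; break;
    case Op::Sub: f = subC<T>; break;
    case Op::Mul: f = mulC<T>; break;
    }
    for (size_t i = 0; i < n; ++i)
        out[i] = f(a[i], b[i]);
}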


@ -3,9 +3,9 @@
namespace infini {
template <typename T> class NaiveMatmul : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NaiveMatmul : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<MatmulObj>(_op);
IT_ASSERT(op->getInputs().size() == 2, "Bias is not supported yet.");
T *A = op->getInputs(0)->getRawDataPtr<T *>();
@ -23,11 +23,25 @@ template <typename T> class NaiveMatmul : public CpuKernelWithoutConfig {
}
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
REGISTER_KERNEL(Device::CPU, OpType::MatMul, DataType::UInt32,
NaiveMatmul<uint32_t>, "MatmulNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::MatMul, DataType::Float32,
NaiveMatmul<float>, "MatmulNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::MatMul, NaiveMatmul, "MatmulNaive_CPU");
} // namespace infini


@ -80,8 +80,8 @@ class MemboundInterpreter : public Kernel {
}
};
REGISTER_KERNEL(Device::CPU, OpType::MemBound, DataType::UInt32,
MemboundInterpreter, "MemboundInterpreter_CPU");
REGISTER_KERNEL(Device::CPU, OpType::MemBound, MemboundInterpreter,
"MemboundInterpreter_CPU");
} // namespace infini


@ -2,42 +2,10 @@
#include "core/kernel.h"
namespace infini {
template <typename T> class NativePooling : public CpuKernelWithoutConfig {
virtual T getPoolingValue(int kh, int kw, int posh, int posw, int ih,
int iw, T *inptr) const = 0;
void compute(const Operator &_op,
const RuntimeObj *context) const override {
auto op = as<PoolingObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
const auto [n, c, ih, iw, kh, kw] = op->getNCHWRS();
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
if (dh != 1 || dw != 1)
IT_TODO_HALT(); // To support dilated pooling
auto outDim = op->getOutput()->getDims();
int oh = outDim[2], ow = outDim[3];
for (auto i = 0; i < n; i++) {
for (auto j = 0; j < c; j++) {
auto inoffset = i * (c * ih * iw) + j * ih * iw;
for (auto h = 0; h < oh; h++) {
for (auto w = 0; w < ow; w++) {
// TODO: verify ceil mode
T val =
getPoolingValue(kh, kw, h * sh - ph, w * sw - pw,
ih, iw, inptr + inoffset);
auto outoffset =
w + h * ow + j * (oh * ow) + i * (c * oh * ow);
outptr[outoffset] = val;
}
}
}
}
}
};
template <typename T> class NaiveMaxPool : public NativePooling<T> {
T getPoolingValue(int kh, int kw, int posh, int posw, int ih, int iw,
T *inptr) const override {
class NativePooling : public CpuKernelWithoutConfig {
template <typename T>
static T getMaxPoolingValue(int kh, int kw, int posh, int posw, int ih,
int iw, T *inptr) {
T maxval = 0;
for (auto k = 0; k < kh; k++) {
for (auto l = 0; l < kw; l++) {
@ -53,11 +21,10 @@ template <typename T> class NaiveMaxPool : public NativePooling<T> {
}
return maxval;
}
};
template <typename T> class NaiveAvgPool : public NativePooling<T> {
T getPoolingValue(int kh, int kw, int posh, int posw, int ih, int iw,
T *inptr) const override {
template <typename T>
static T getAvgPoolingValue(int kh, int kw, int posh, int posw, int ih,
int iw, T *inptr) {
T sum = 0;
for (auto k = 0; k < kh; k++) {
for (auto l = 0; l < kw; l++) {
@ -71,12 +38,70 @@ template <typename T> class NaiveAvgPool : public NativePooling<T> {
}
return T(sum / (kh * kw));
}
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<PoolingObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
const auto [n, c, ih, iw, kh, kw] = op->getNCHWRS();
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
if (dh != 1 || dw != 1)
IT_TODO_HALT(); // To support dilated pooling
auto outDim = op->getOutput()->getDims();
int oh = outDim[2], ow = outDim[3];
T(*_doCompute)
(int kh, int kw, int posh, int posw, int ih, int iw, T *inptr);
switch (op->getOpType().underlying()) {
case OpType::MaxPool:
_doCompute = getMaxPoolingValue<T>;
break;
case OpType::AveragePool:
_doCompute = getAvgPoolingValue<T>;
break;
default:
IT_TODO_HALT();
}
for (auto i = 0; i < n; i++) {
for (auto j = 0; j < c; j++) {
auto inoffset = i * (c * ih * iw) + j * ih * iw;
for (auto h = 0; h < oh; h++) {
for (auto w = 0; w < ow; w++) {
// TODO: verify ceil mode
T val = _doCompute(kh, kw, h * sh - ph, w * sw - pw, ih,
iw, inptr + inoffset);
auto outoffset =
w + h * ow + j * (oh * ow) + i * (c * oh * ow);
outptr[outoffset] = val;
}
}
}
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::UInt32,
NaiveMaxPool<uint32_t>, "maxPoolNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::Float32,
NaiveMaxPool<float>, "maxPoolNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::AveragePool, DataType::Float32,
NaiveAvgPool<float>, "AvgPoolNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, NativePooling,
"maxPoolNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::AveragePool, NativePooling,
"avgPoolNaive_CPU");
} // namespace infini


@ -3,9 +3,9 @@
namespace infini {
template <typename T> class NaiveSplit : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NaiveSplit : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<SplitObj>(_op);
auto inputs = op->getInputs(), outputs = op->getOutputs();
auto dim = op->getDim();
@ -40,11 +40,24 @@ template <typename T> class NaiveSplit : public CpuKernelWithoutConfig {
}
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
REGISTER_KERNEL(Device::CPU, OpType::Split, DataType::UInt32,
NaiveSplit<uint32_t>, "SplitNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Split, DataType::Float32,
NaiveSplit<float>, "SplitNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Split, NaiveSplit, "SplitNaive_CPU");
} // namespace infini


@ -14,9 +14,9 @@ inline Shape idx2Pos(const Shape &shape, size_t idx) {
return pos;
}
template <typename T> class NaiveTranspose : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NaiveTranspose : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<TransposeObj>(_op);
auto inputs = op->getInputs(), outputs = op->getOutputs();
const auto &inDim = inputs[0]->getDims();
@ -35,11 +35,26 @@ template <typename T> class NaiveTranspose : public CpuKernelWithoutConfig {
outPtr[outIdx] = inPtr[inIdx];
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
REGISTER_KERNEL(Device::CPU, OpType::Transpose, DataType::UInt32,
NaiveTranspose<uint32_t>, "TransposeNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Transpose, DataType::Float32,
NaiveTranspose<float>, "TransposeNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Transpose, NaiveTranspose,
"TransposeNaive_CPU");
} // namespace infini


@ -4,25 +4,170 @@
#include "operators/softmax.h"
namespace infini {
template <typename T> class NativeUnary : public CpuKernelWithoutConfig {
virtual T doCompute(T val) const = 0;
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NativeUnary : public CpuKernelWithoutConfig {
template <typename T> static T reluCompute(T val) {
return std::max(T(0), val);
}
template <typename T> static T sigmoidCompute(T val) {
return 1 / (1 + pow(E_CONSTANT, -val));
}
template <typename T> static T hardSigmoidCompute(T val) {
return std::max(T(0), std::min(T(1), T(0.2) * val + T(0.5)));
}
template <typename T> static T hardSwishCompute(T val) {
return val *
std::max(T(0), std::min(T(1), val * T(1.0 / 6.0) + T(0.5)));
}
template <typename T> static T tanhCompute(T val) {
return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) /
(pow(E_CONSTANT, val) + pow(E_CONSTANT, -val));
}
template <typename T> static T absCompute(T val) {
return val < 0 ? -val : val;
}
template <typename T> static T sqrtCompute(T val) { return std::sqrt(val); }
template <typename T> static T cosCompute(T val) { return std::cos(val); }
template <typename T> static T sinCompute(T val) { return std::sin(val); }
template <typename T> static T tanCompute(T val) { return std::tan(val); }
template <typename T> static T sinhCompute(T val) { return std::sinh(val); }
template <typename T> static T coshCompute(T val) { return std::cosh(val); }
template <typename T> static T geluCompute(T val) {
return 0.5 * val * (1 + std::erf(val / std::sqrt(2)));
}
template <typename T> static T erfCompute(T val) { return std::erf(val); }
template <typename T> static T aCosCompute(T val) { return std::acos(val); }
template <typename T> static T aCoshCompute(T val) {
return std::acosh(val);
}
template <typename T> static T aSinCompute(T val) { return std::asin(val); }
template <typename T> static T aSinhCompute(T val) {
return std::asinh(val);
}
template <typename T> static T aTanCompute(T val) { return std::atan(val); }
template <typename T> static T aTanhCompute(T val) {
return std::atanh(val);
}
template <typename T> static T negCompute(T val) { return -val; }
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<UnaryObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
auto outDim = op->getOutput()->getDims();
auto n = op->getOutput()->size();
T (*_doCompute)(T val);
switch (op->getOpType().underlying()) {
case OpType::Relu:
_doCompute = reluCompute<T>;
break;
case OpType::Gelu:
_doCompute = geluCompute<T>;
break;
case OpType::Sigmoid:
_doCompute = sigmoidCompute<T>;
break;
case OpType::HardSigmoid:
_doCompute = hardSigmoidCompute<T>;
break;
case OpType::HardSwish:
_doCompute = hardSwishCompute<T>;
break;
case OpType::Tanh:
_doCompute = tanhCompute<T>;
break;
case OpType::Abs:
_doCompute = absCompute<T>;
break;
case OpType::Sqrt:
_doCompute = sqrtCompute<T>;
break;
case OpType::Erf:
_doCompute = erfCompute<T>;
break;
case OpType::Neg:
_doCompute = negCompute<T>;
break;
case OpType::Cos:
_doCompute = cosCompute<T>;
break;
case OpType::Sin:
_doCompute = sinCompute<T>;
break;
case OpType::Tan:
_doCompute = tanCompute<T>;
break;
case OpType::Sinh:
_doCompute = sinhCompute<T>;
break;
case OpType::Cosh:
_doCompute = coshCompute<T>;
break;
case OpType::Acos:
_doCompute = aCosCompute<T>;
break;
case OpType::Asin:
_doCompute = aSinCompute<T>;
break;
case OpType::Asinh:
_doCompute = aSinhCompute<T>;
break;
case OpType::Atan:
_doCompute = aTanCompute<T>;
break;
case OpType::Atanh:
_doCompute = aTanhCompute<T>;
break;
default:
IT_TODO_HALT();
}
for (size_t offset = 0; offset < n; offset++) {
outptr[offset] = doCompute(inptr[offset]);
outptr[offset] = _doCompute(inptr[offset]);
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
class NaiveSoftmax : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<SoftmaxObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
@ -37,98 +182,28 @@ template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig {
outptr[offset] = pow(E_CONSTANT, inptr[offset]) / sum;
}
}
};
template <typename T> class NaiveRelu : public NativeUnary<T> {
T doCompute(T val) const override { return std::max(T(0), val); }
};
template <typename T> class NaiveSigmoid : public NativeUnary<T> {
T doCompute(T val) const override {
return 1 / (1 + pow(E_CONSTANT, -val));
}
};
template <typename T> class NaiveHardSigmoid : public NativeUnary<T> {
T doCompute(T val) const override {
return std::max(T(0), std::min(T(1), T(0.2) * val + T(0.5)));
}
};
template <typename T> class NaiveHardSwish : public NativeUnary<T> {
T doCompute(T val) const override {
return val *
std::max(T(0), std::min(T(1), val * T(1.0 / 6.0) + T(0.5)));
}
};
template <typename T> class NaiveTanh : public NativeUnary<T> {
T doCompute(T val) const override {
return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) /
(pow(E_CONSTANT, val) + pow(E_CONSTANT, -val));
}
};
template <typename T> class NaiveAbs : public NativeUnary<T> {
T doCompute(T val) const override { return val < 0 ? -val : val; }
};
template <typename T> class NaiveSqrt : public NativeUnary<T> {
T doCompute(T val) const override { return std::sqrt(val); }
};
template <typename T> class NaiveCos : public NativeUnary<T> {
T doCompute(T val) const override { return std::cos(val); }
};
template <typename T> class NaiveSin : public NativeUnary<T> {
T doCompute(T val) const override { return std::sin(val); }
};
template <typename T> class NaiveTan : public NativeUnary<T> {
T doCompute(T val) const override { return std::tan(val); }
};
template <typename T> class NaiveSinh : public NativeUnary<T> {
T doCompute(T val) const override { return std::sinh(val); }
};
template <typename T> class NaiveCosh : public NativeUnary<T> {
T doCompute(T val) const override { return std::cosh(val); }
};
template <typename T> class NaiveGelu : public NativeUnary<T> {
T doCompute(T val) const override {
return 0.5 * val * (1 + std::erf(val / std::sqrt(2)));
}
};
template <typename T> class NaiveErf : public NativeUnary<T> {
T doCompute(T val) const override { return std::erf(val); }
};
template <typename T> class NaiveACos : public NativeUnary<T> {
T doCompute(T val) const override { return std::acos(val); }
};
template <typename T> class NaiveACosh : public NativeUnary<T> {
T doCompute(T val) const override { return std::acosh(val); }
};
template <typename T> class NaiveASin : public NativeUnary<T> {
T doCompute(T val) const override { return std::asin(val); }
};
template <typename T> class NaiveASinh : public NativeUnary<T> {
T doCompute(T val) const override { return std::asinh(val); }
};
template <typename T> class NaiveATanh : public NativeUnary<T> {
T doCompute(T val) const override { return std::atanh(val); }
};
template <typename T> class NaiveNeg : public NativeUnary<T> {
T doCompute(T val) const override { return -val; }
};
template <typename T> class Clip : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
class Clip : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<ClipObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
@ -143,11 +218,28 @@ template <typename T> class Clip : public CpuKernelWithoutConfig {
: val;
}
}
};
template <typename T> class Log : public CpuKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
class Log : public CpuKernelWithoutConfig {
template <typename T>
void doCompute(const Operator &_op, const RuntimeObj *context) const {
auto op = as<LogObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
@ -176,70 +268,50 @@ template <typename T> class Log : public CpuKernelWithoutConfig {
}
}
}
void compute(const Operator &_op,
const RuntimeObj *context) const override {
#define CASE(N) \
case N: \
doCompute<DT<N>::t>(_op, context)
int dataTypeIdx = _op->getDType().getIndex();
switch (dataTypeIdx) {
CASE(1); // DataType::Float32
break;
CASE(12); // DataType::UInt32
break;
default:
IT_TODO_HALT();
}
}
};
template <typename T> class NaiveATan : public NativeUnary<T> {
T doCompute(T val) const override { return std::atan(val); }
};
REGISTER_KERNEL(Device::CPU, OpType::Relu, NativeUnary, "reluNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Gelu, NativeUnary, "geluNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, NativeUnary, "sigmoidNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, NativeUnary,
"hardSigmoidNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::HardSwish, NativeUnary,
"hardSwishNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Tanh, NativeUnary, "tanhNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Abs, NativeUnary, "absNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Sqrt, NativeUnary, "sqrtNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Erf, NativeUnary, "erfNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Neg, NativeUnary, "negNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Cos, NativeUnary, "Cos_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Sin, NativeUnary, "Sin_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Tan, NativeUnary, "Tan_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Sinh, NativeUnary, "Sinh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Cosh, NativeUnary, "Cosh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Acos, NativeUnary, "ACos_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Acosh, NativeUnary, "ACosh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Asin, NativeUnary, "ASin_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Asinh, NativeUnary, "ASinh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Atan, NativeUnary, "Atan_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Atanh, NativeUnary, "ATanh_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::UInt32,
NaiveRelu<uint32_t>, "reluNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::Float32, NaiveRelu<float>,
"reluNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Gelu, DataType::UInt32, NaiveGelu<float>,
"geluNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Gelu, DataType::Float32, NaiveGelu<float>,
"geluNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::UInt32,
NaiveSigmoid<uint32_t>, "sigmoidNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::Float32,
NaiveSigmoid<float>, "sigmoidNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, DataType::Float32,
NaiveHardSigmoid<float>, "hardSigmoidNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::HardSwish, DataType::Float32,
NaiveHardSwish<float>, "hardSwishNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::UInt32,
NaiveTanh<uint32_t>, "tanhNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::Float32, NaiveTanh<float>,
"tanhNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::UInt32, NaiveAbs<uint32_t>,
"absNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs<float>,
"absNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt<float>,
"sqrtNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Erf, DataType::Float32, NaiveErf<float>,
"erfNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Neg, DataType::Float32, NaiveNeg<float>,
"negNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,
NaiveSoftmax<float>, "softmaxNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Clip, DataType::Float32, Clip<float>,
"Clip_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Atan, DataType::Float32, NaiveATan<float>,
"Atan_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Log, DataType::Float32, Log<float>,
"Log_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Cos, DataType::Float32, NaiveCos<float>,
"Cos_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sin, DataType::Float32, NaiveSin<float>,
"Sin_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Tan, DataType::Float32, NaiveTan<float>,
"Tan_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sinh, DataType::Float32, NaiveSinh<float>,
"Sinh_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Cosh, DataType::Float32, NaiveCosh<float>,
"Cosh_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Acos, DataType::Float32, NaiveACos<float>,
"ACos_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Acosh, DataType::Float32,
NaiveACosh<float>, "ACosh_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Asin, DataType::Float32, NaiveASin<float>,
"ASin_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Asinh, DataType::Float32,
NaiveASinh<float>, "ASinh_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Atanh, DataType::Float32,
NaiveATanh<float>, "ATanh_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, NaiveSoftmax, "softmaxNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Clip, Clip, "Clip_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Log, Log, "Log_CPU");
}; // namespace infini
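
One small observation on the helpers above: sigmoidCompute and tanhCompute
are written with pow(E_CONSTANT, x), for which std::exp(x) is the equivalent
and more conventional form, and std::tanh exists directly. A reference
sketch of the same math:

#include <cmath>

template <typename T> T sigmoidRef(T x) { return T(1) / (T(1) + std::exp(-x)); }
template <typename T> T tanhRef(T x) { return std::tanh(x); }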


@ -48,13 +48,13 @@ class G2BMMCudnn : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<G2BMMObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
bool success = g2bmmKernel(op, context);
IT_ASSERT(success);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::G2BMM, DataType::Float32, G2BMMCudnn,
"G2BMM_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::G2BMM, G2BMMCudnn, "G2BMM_cuDNN_CUDA");
} // namespace infini


@ -49,13 +49,13 @@ class GBMMCudnn : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<GBMMObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
bool success = gbmmKernel(op, context);
IT_ASSERT(success);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::GBMM, DataType::Float32, GBMMCudnn,
"GBMM_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::GBMM, GBMMCudnn, "GBMM_cuDNN_CUDA");
} // namespace infini


@ -39,8 +39,8 @@ class AllGatherNCCL : public CudaKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::AllGather, DataType::Float32,
AllGatherNCCL, "AllGather_NCCL_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::AllGather, AllGatherNCCL,
"AllGather_NCCL_CUDA");
} // namespace infini
#endif


@ -13,15 +13,24 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
void *input = op->getInputs(0)->getRawDataPtr<void *>();
void *output = op->getOutput()->getRawDataPtr<void *>();
IT_ASSERT(op->getDType() == DataType::Float32);
ncclDataType_t ncclType = ncclFloat;
if (op->getDType() == DataType::Float16) {
ncclType = ncclFloat16;
} else if (op->getDType() == DataType::Int8) {
ncclType = ncclInt8;
} else if (op->getDType() == DataType::Float32) {
ncclType = ncclFloat;
} else {
IT_TODO_HALT();
}
size_t count = op->getInputs(0)->size();
ncclComm_t comm =
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
.getNcclComm();
// TODO: Using default stream 0 for now.
checkNcclError(ncclAllReduce(input, output, count, ncclFloat,
getRedOp(), comm, 0));
checkNcclError(
ncclAllReduce(input, output, count, ncclType, getRedOp(), comm, 0));
}
virtual ncclRedOp_t getRedOp() const = 0;
@ -43,16 +52,16 @@ class AllReduceAvgNCCL : public AllReduceNCCL {
ncclRedOp_t getRedOp() const override { return ncclAvg; }
};
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceSum, DataType::Float32,
AllReduceSumNCCL, "AllReduce_Sum_NCCL_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceProd, DataType::Float32,
AllReduceProdNCCL, "AllReduce_Prod_NCCL_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMin, DataType::Float32,
AllReduceMinNCCL, "AllReduce_Min_NCCL_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMax, DataType::Float32,
AllReduceMaxNCCL, "AllReduce_Max_NCCL_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceAvg, DataType::Float32,
AllReduceAvgNCCL, "AllReduce_Avg_NCCL_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceSum, AllReduceSumNCCL,
"AllReduce_Sum_NCCL_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceProd, AllReduceProdNCCL,
"AllReduce_Prod_NCCL_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMin, AllReduceMinNCCL,
"AllReduce_Min_NCCL_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMax, AllReduceMaxNCCL,
"AllReduce_Max_NCCL_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceAvg, AllReduceAvgNCCL,
"AllReduce_Avg_NCCL_CUDA");
} // namespace infini
#endif
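
The dtype branch added above is the usual shape of a framework-to-NCCL type
mapping. A standalone sketch of the same mapping; the DType enum is an
illustrative stand-in, while the ncclDataType_t values are the library's
real enumerators:

#include <nccl.h>
#include <stdexcept>

enum class DType { Float32, Float16, Int8 };

ncclDataType_t toNcclType(DType dt) {
    switch (dt) {
    case DType::Float32: return ncclFloat;
    case DType::Float16: return ncclFloat16;
    case DType::Int8:    return ncclInt8;
    }
    throw std::runtime_error("unsupported dtype for NCCL");
}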


@ -40,6 +40,7 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
IT_ASSERT(_op->getDType() == DataType::Float32);
do_compute(_op->getInputs()[0], _op->getInputs()[1],
_op->getInputs()[2], _op->getInputs()[3],
_op->getInputs()[4], _op->getInputs()[5],
@ -47,6 +48,6 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
}
};
REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, DataType::Float32,
AttentionKVCacheCuda, "AttentionKVCache_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, AttentionKVCacheCuda,
"AttentionKVCache_CUDA");
} // namespace infini


@ -10,6 +10,7 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
auto op = as<BatchNormObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
cudnnStatus_t stat;
IT_ASSERT(op->getDType() == DataType::Float32);
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
void *const meanData = (op->getInputs(1)->getRawDataPtr<void *>());
@ -59,6 +60,6 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::BatchNormalization, DataType::Float32,
BatchNormCudnn, "BatchNorm_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::BatchNormalization, BatchNormCudnn,
"BatchNorm_cuDNN_CUDA");
} // namespace infini


@ -25,8 +25,8 @@ class BroadcastNCCL : public CudaKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Broadcast, DataType::Float32,
BroadcastNCCL, "Broadcast_NCCL_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Broadcast, BroadcastNCCL,
"Broadcast_NCCL_CUDA");
} // namespace infini
#endif


@ -9,7 +9,7 @@ class ClipCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ClipObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
auto min = op->getMin();
@ -21,7 +21,6 @@ class ClipCuda : public CudaKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Clip, DataType::Float32, ClipCuda,
"Clip_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Clip, ClipCuda, "Clip_CUDA");
}; // namespace infini


@ -1,10 +1,12 @@
#include "operators/conv.h"
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include <chrono>
#include <functional>
#include <limits>
#include <tuple>
namespace infini {
struct ConvCuDnnPerfRecordObj : public PerfRecordObj {
@ -56,8 +58,11 @@ class convCudnn : public Kernel {
const ConvCuDnnPerfRecord &record) const {
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
if (op->getInputs().size() > 2) // Bias is not supported yet
// Bias is not supported yet
if (op->getInputs().size() > 2) {
IT_TODO_HALT();
}
auto cudnnDataType = cudnnDataTypeConvert(op->getDType());
// void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
@ -72,27 +77,26 @@ class convCudnn : public Kernel {
cudnnTensorDescriptor_t inDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, channels, h, w));
inDesc, CUDNN_TENSOR_NCHW, cudnnDataType, n, channels, h, w));
// get kernels
cudnnFilterDescriptor_t knDesc;
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW, f,
channelsPerGrp, r, s));
checkCudnnError(cudnnSetFilter4dDescriptor(
knDesc, cudnnDataType, CUDNN_TENSOR_NCHW, f, channelsPerGrp, r, s));
// get bias
cudnnTensorDescriptor_t biasDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
biasDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, f, 1, 1));
checkCudnnError(cudnnSetTensor4dDescriptor(biasDesc, CUDNN_TENSOR_NCHW,
cudnnDataType, 1, f, 1, 1));
// get convlution descriptor
// get convolution descriptor
cudnnConvolutionDescriptor_t convDesc;
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
// TODO: CUDNN_CONVOLUTION is a tunable argument
checkCudnnError(cudnnSetConvolution2dDescriptor(
convDesc, ph, pw, sh, sw, dh, dw, MODES[record->mode],
CUDNN_DATA_FLOAT));
cudnnDataType));
if (g > 1) {
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
}
@ -120,14 +124,14 @@ class convCudnn : public Kernel {
assert(false);
}
// get output descriptor
int outn, outc, outh, outw;
checkCudnnError(cudnnGetConvolution2dForwardOutputDim(
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
cudnnTensorDescriptor_t outDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT, outn, outc,
outh, outw));
checkCudnnError(cudnnSetTensor4dDescriptor(
outDesc, CUDNN_TENSOR_NCHW, cudnnDataType, outn, outc, outh, outw));
IT_ASSERT((vector{outn, outc, outh, outw}) ==
op->getOutput()->getDims(),
"cuDNN output shape mismatches with OP output shape");
@ -151,55 +155,9 @@ class convCudnn : public Kernel {
inData, knDesc, knData, convDesc,
ALGOS[record->algo], wsData, wsSize,
&beta, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS)
if (stat != CUDNN_STATUS_SUCCESS) {
return false;
// TODO:
// // bias
// if (bias != nullptr) {
// auto sz = op.getOutputs()[0]->size();
// // TODO: element wise
// t += sz * 2 / 400;
// }
// // act
// if (act != None) {
// stat = cudnnActivationForward(cudnnHandle(), actDesc,
// &alpha, inDesc, inData,
// &beta, outDesc, outData);
// checkCudaError(cudaDeviceSynchronize());
// end = ch::high_resolution_clock::now();
// if (stat != CUDNN_STATUS_SUCCESS) {
// durtime = INFINITY;
// break;
// }
// t +=
// ch::duration_cast<ch::duration<double>>(end -
// beg).count() * 1000; // ms
// }
// best = ConvResult{durtime, ALGOS[i], wsSize, false};
// // w/ bias & act
// for (int j = 0; j < rounds + warmupRounds; ++j) {
// cudnnStatus_t stat;
// if (j == warmupRounds) {
// checkCudaError(cudaDeviceSynchronize());
// beg = ch::high_resolution_clock::now();
// }
// stat = cudnnConvolutionBiasActivationForward(
// cudnnHandle(), &alpha, inDesc, inData, knDesc, knData,
// convDesc, ALGOS[i], wsData, wsSize, &beta, outDesc,
// outData, biasDesc, biasData, actDesc, outDesc, outData);
// if (stat != CUDNN_STATUS_SUCCESS) {
// // checkCudnnError(stat);
// // Do not checkCudnnError since not all algorithms are
// // supported
// durtime_fuse = INFINITY;
// break;
// }
// }
// Destroy calls in CUDA do not require sync. But cuDNN does not state
// whether sync is required before destruction.
}
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
@ -238,10 +196,12 @@ class convCudnn : public Kernel {
stat = cudnnGetConvolutionForwardWorkspaceSize(
context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
ALGOS[record.algo], &record.workspaceSize);
if (stat != CUDNN_STATUS_SUCCESS)
if (stat != CUDNN_STATUS_SUCCESS) {
continue;
if (record.workspaceSize > context->getWorkspaceSize())
}
if (record.workspaceSize > context->getWorkspaceSize()) {
continue;
}
CudaPtr wsData = context->getWorkspace(record.workspaceSize);
float alpha = 1.f, beta = 0.f;
@ -249,8 +209,9 @@ class convCudnn : public Kernel {
context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
knData, convDesc, ALGOS[record.algo], wsData,
record.workspaceSize, &beta, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS)
if (stat != CUDNN_STATUS_SUCCESS) {
continue;
}
record.time = timeit(
[&]() {
cudnnConvolutionForward(context->cudnnHandle(), &alpha,
@ -263,8 +224,9 @@ class convCudnn : public Kernel {
// printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time);
// Update the tune result
if (ret.time > record.time)
if (ret.time > record.time) {
ret = record;
}
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
@ -291,8 +253,7 @@ class convCudnn : public Kernel {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float32, convCudnn,
"Conv_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Conv, convCudnn, "Conv_cuDNN_CUDA");
REGISTER_CONSTRUCTOR(1, ConvCuDnnPerfRecordObj::from_json);
} // namespace infini
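
The kernel above becomes precision-generic by routing every descriptor
through cudnnDataTypeConvert. That helper is presumably a small
dtype-to-cuDNN-enum mapping along these lines (a sketch under that
assumption; the cuDNN enumerators are real, the DType enum illustrative):

#include <cudnn.h>

enum class DType { Float32, Float16 };

cudnnDataType_t toCudnnType(DType dt) {
    switch (dt) {
    case DType::Float16: return CUDNN_DATA_HALF;
    case DType::Float32: return CUDNN_DATA_FLOAT;
    }
    return CUDNN_DATA_FLOAT; // unreachable for this two-case enum
}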


@ -1,261 +0,0 @@
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include "operators/conv.h"
#include <chrono>
#include <functional>
#include <limits>
#include <tuple>
namespace infini {
struct ConvCuDnnPerfRecordObj : public PerfRecordObj {
int algo = 0; // cudnnConvolutionFwdAlgo_t
int mode = 1;
size_t workspaceSize = 100000;
bool fuseAct = false;
void to_json(json &j) override {
j["type"] = 1;
j["data"] = std::make_tuple(algo, mode, fuseAct, time, workspaceSize);
}
static PerfRecord from_json(const json &j) {
ConvCuDnnPerfRecordObj tmp;
auto [Algo, Mode, FuseAct, Time, WorkspaceSize] =
j["data"].get<tuple<int, int, bool, double, size_t>>();
tmp.algo = Algo;
tmp.mode = Mode;
tmp.fuseAct = FuseAct;
tmp.time = Time;
tmp.workspaceSize = WorkspaceSize;
return make_ref<ConvCuDnnPerfRecordObj>(tmp);
}
};
using ConvCuDnnPerfRecord = Ref<ConvCuDnnPerfRecordObj>;
class convCudnnFP16 : public Kernel {
static constexpr int N_ALGO = 8;
static constexpr int N_MODE = 2;
static constexpr cudnnConvolutionFwdAlgo_t ALGOS[8] = {
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED};
static constexpr cudnnConvolutionMode_t MODES[2] = {
CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION};
std::tuple<void *, void *, void *, cudnnTensorDescriptor_t,
cudnnFilterDescriptor_t, cudnnTensorDescriptor_t,
cudnnConvolutionDescriptor_t, cudnnActivationDescriptor_t,
cudnnTensorDescriptor_t>
createCuDNNDescriptor(const Ref<ConvObj> &op,
const ConvCuDnnPerfRecord &record) const {
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
// Bias is not supported yet
if (op->getInputs().size() > 2) {
IT_TODO_HALT();
}
// void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
const auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
const int cpg = op->getChannelPerGroup();
const int g = c / cpg;
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
int channelsPerGrp = cpg, channels = c;
// get inputs
cudnnTensorDescriptor_t inDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(inDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_HALF, n, channels,
h, w)); /*fp16 type*/
// get kernels
cudnnFilterDescriptor_t knDesc;
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
checkCudnnError(cudnnSetFilter4dDescriptor(
knDesc, CUDNN_DATA_HALF, /*fp16 type*/
CUDNN_TENSOR_NCHW, f, channelsPerGrp, r, s));
// get bias
cudnnTensorDescriptor_t biasDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(biasDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_HALF, 1, f, 1,
1)); /*fp16 type*/
// get convolution descriptor
cudnnConvolutionDescriptor_t convDesc;
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
// TODO: CUDNN_CONVOLUTION is a tunable argument
checkCudnnError(cudnnSetConvolution2dDescriptor(
convDesc, ph, pw, sh, sw, dh, dw, MODES[record->mode],
CUDNN_DATA_HALF)); /*fp16 type*/
if (g > 1) {
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
}
// get activation descriptor
cudnnActivationDescriptor_t actDesc;
checkCudnnError(cudnnCreateActivationDescriptor(&actDesc));
// NOT_PROPAGATE_NAN is required by
// cudnnConvolutionBiasActivationForward
switch (op->getAct()) {
case ActType::Relu:
checkCudnnError(cudnnSetActivationDescriptor(
actDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0));
break;
case ActType::Sigmoid:
checkCudnnError(cudnnSetActivationDescriptor(
actDesc, CUDNN_ACTIVATION_SIGMOID, CUDNN_NOT_PROPAGATE_NAN, 0));
break;
case ActType::None:
checkCudnnError(
cudnnSetActivationDescriptor(actDesc, CUDNN_ACTIVATION_IDENTITY,
CUDNN_NOT_PROPAGATE_NAN, 0));
break;
default:
assert(false);
}
// get output descriptor
int outn, outc, outh, outw;
checkCudnnError(cudnnGetConvolution2dForwardOutputDim(
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
cudnnTensorDescriptor_t outDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_HALF, outn, outc,
outh, outw));
IT_ASSERT((vector{outn, outc, outh, outw}) ==
op->getOutput()->getDims(),
"cuDNN output shape mismatches with OP output shape");
return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc);
}
bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
const CudaRuntimeObj *context) const {
cudnnStatus_t stat;
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc] =
createCuDNNDescriptor(op, record);
size_t wsSize = record->workspaceSize;
CudaPtr wsData = context->getWorkspace(wsSize);
float alpha = 1.f, beta = 0.f;
stat = cudnnConvolutionForward(context->cudnnHandle(), &alpha, inDesc,
inData, knDesc, knData, convDesc,
ALGOS[record->algo], wsData, wsSize,
&beta, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS) {
return false;
}
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
return true;
}
void compute(const Operator &op, const RuntimeObj *context) const override {
auto record = make_ref<ConvCuDnnPerfRecordObj>(); // with parameters in
// default ctor
compute(op, record, context);
}
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
ConvCuDnnPerfRecordObj ret;
ret.time = std::numeric_limits<double>::max();
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
auto op = as<ConvObj>(_op);
// Both modes have the same performance. Only run cross-correlation.
for (int mode = 1; mode < 2; mode++) {
// Try every possible algorithm of convolution
for (int algo = 0; algo < N_ALGO; algo++) {
auto recordRef = make_ref<ConvCuDnnPerfRecordObj>();
auto &record = *recordRef;
record.mode = mode;
record.algo = algo;
cudnnStatus_t stat;
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc] =
createCuDNNDescriptor(op, recordRef);
// get workspace
stat = cudnnGetConvolutionForwardWorkspaceSize(
context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
ALGOS[record.algo], &record.workspaceSize);
if (stat != CUDNN_STATUS_SUCCESS) {
continue;
}
if (record.workspaceSize > context->getWorkspaceSize()) {
continue;
}
CudaPtr wsData = context->getWorkspace(record.workspaceSize);
float alpha = 1.f, beta = 0.f;
stat = cudnnConvolutionForward(
context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
knData, convDesc, ALGOS[record.algo], wsData,
record.workspaceSize, &beta, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS) {
continue;
}
record.time = timeit(
[&]() {
cudnnConvolutionForward(context->cudnnHandle(), &alpha,
inDesc, inData, knDesc, knData,
convDesc, ALGOS[record.algo],
wsData, record.workspaceSize,
&beta, outDesc, outData);
},
[&]() { context->sync(); });
// printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time);
// Update the tune result
if (ret.time > record.time) {
ret = record;
}
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
}
}
// printf("the best algo is %d, the best conv mode is %d\n", ret.algo,
// ret.mode);
IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
"algorithm "
"found");
return make_ref<ConvCuDnnPerfRecordObj>(ret);
}
void compute(const Operator &_op, const PerfRecord &_record,
const RuntimeObj *_context) const override {
auto op = as<ConvObj>(_op);
auto record = as<ConvCuDnnPerfRecordObj>(_record);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
bool success = cuDNNUnfused(op, record, context);
IT_ASSERT(success);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float16, convCudnnFP16,
"Conv_cuDNN_CUDA_Float16");
} // namespace infini
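// This commit moves kernel lookup to be keyed by (Device, OpType), with the
// data type handled either by a registry-side filter or by guards inside the
// kernel itself. A minimal sketch of the in-kernel guard pattern, as several
// kernels in this diff now use (class and op names are illustrative only):
namespace infini {
class UnaryCudaSketch : public CudaKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        // runtime dtype guard replaces the old per-dtype registration
        IT_ASSERT(_op->getDType() == DataType::Float32 ||
                  _op->getDType() == DataType::Float16);
        // ...typed dispatch on _op->getDType() goes here...
    }
};
// REGISTER_KERNEL(Device::CUDA, OpType::Relu, UnaryCudaSketch, "Relu_CUDA");
} // namespace infini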

View File

@ -219,6 +219,7 @@ class convBackwardDataCudnn : public Kernel {
void compute(const Operator &op, const RuntimeObj *context) const override {
// with parameters in default ctor
auto record = make_ref<ConvTransposedCuDnnPerfRecordObj>();
IT_ASSERT(op->getDType() == DataType::Float32);
compute(op, record, context);
}
@ -300,8 +301,9 @@ class convBackwardDataCudnn : public Kernel {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::ConvTranspose, DataType::Float32,
convBackwardDataCudnn, "ConvTranposed_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::ConvTransNHWC, DataType::Float32,
convBackwardDataCudnn, "ConvTranposedNHWC_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::ConvTranspose, convBackwardDataCudnn,
"ConvTranposed_cuDNN_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::ConvTransNHWC, convBackwardDataCudnn,
"ConvTranposedNHWC_cuDNN_CUDA");
} // namespace infini

View File

@ -2,6 +2,7 @@
#include "cuda/cuda_element_wise.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
namespace infini {
class ElementWiseCudnn : public CudaKernelWithoutConfig {
@ -44,22 +45,21 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
auto cudnnDataType = cudnnDataTypeConvert(op->getDType());
// get inputs
checkCudnnError(cudnnCreateTensorDescriptor(&aDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(aDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT, a[0], a[1],
a[2], a[3]));
checkCudnnError(cudnnSetTensor4dDescriptor(
aDesc, CUDNN_TENSOR_NCHW, cudnnDataType, a[0], a[1], a[2], a[3]));
checkCudnnError(cudnnCreateTensorDescriptor(&bDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(bDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT, b[0], b[1],
b[2], b[3]));
checkCudnnError(cudnnSetTensor4dDescriptor(
bDesc, CUDNN_TENSOR_NCHW, cudnnDataType, b[0], b[1], b[2], b[3]));
// get outputs
checkCudnnError(cudnnCreateTensorDescriptor(&cDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(cDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT, c[0], c[1],
c[2], c[3]));
checkCudnnError(cudnnSetTensor4dDescriptor(
cDesc, CUDNN_TENSOR_NCHW, cudnnDataType, c[0], c[1], c[2], c[3]));
// get op descriptor
cudnnOpTensorDescriptor_t opDesc;
@ -127,40 +127,33 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
if (op->getOpType() == OpType::Div)
div_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
b[2], b[3], c[0], c[1], c[2], c[3]);
else if (op->getOpType() == OpType::Pow)
pow_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
b[2], b[3], c[0], c[1], c[2], c[3]);
else if (op->getOpType() == OpType::Add) {
add_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
b[2], b[3], c[0], c[1], c[2], c[3]);
const int dType = _op->getDType().getIndex();
if (op->getOpType() == OpType::Div) {
div_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0],
b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
} else if (op->getOpType() == OpType::Add) {
add_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0],
b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
} else if (op->getOpType() == OpType::Pow) {
pow_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0],
b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
} else if (op->getOpType() == OpType::Less) {
less_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
b[2], b[3], c[0], c[1], c[2], c[3]);
} else
less_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3],
b[0], b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
} else {
IT_TODO_HALT();
}
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Float32, AddCudnn,
"Add_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Sub, DataType::Float32, SubCudnn,
"Sub_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Mul, DataType::Float32, MulCudnn,
"Mul_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Min, DataType::Float32, MinCudnn,
"Min_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Max, DataType::Float32, MaxCudnn,
"Max_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Add, AddCudnn, "Add_cuDNN_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Sub, SubCudnn, "Sub_cuDNN_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Mul, MulCudnn, "Mul_cuDNN_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Min, MinCudnn, "Min_cuDNN_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Max, MaxCudnn, "Max_cuDNN_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Div, ElementWiseCuda, "Div_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Pow, ElementWiseCuda, "Pow_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Less, ElementWiseCuda, "Less_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Div, DataType::Float32, ElementWiseCuda,
"Div_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Int64, ElementWiseCuda,
"Add_CUDA_Int64");
REGISTER_KERNEL(Device::CUDA, OpType::Pow, DataType::Float32, ElementWiseCuda,
"Pow__CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Less, DataType::Int64, ElementWiseCuda,
"Less__CUDA_Int64");
}; // namespace infini

View File

@ -1,4 +1,5 @@
#include "cuda/cuda_common.h"
#include "cuda/cuda_utility.h"
#include <math.h>
constexpr unsigned int num_threads() { return 32 * 4; }
@ -129,44 +130,113 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
}
}
#define CASE(OP, T) \
_##OP##_kernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
#define SWITCH_DTYPE(OP, DTYPE) \
switch (DTYPE) { \
case 1: \
CASE(OP, 1) \
break; \
case 2: \
CASE(OP, 2) \
break; \
case 3: \
CASE(OP, 3) \
break; \
case 4: \
CASE(OP, 4) \
break; \
case 5: \
CASE(OP, 5) \
break; \
case 6: \
CASE(OP, 6) \
break; \
case 7: \
CASE(OP, 7) \
break; \
case 10: \
CASE(OP, 10) \
break; \
case 11: \
CASE(OP, 11) \
break; \
case 12: \
CASE(OP, 12) \
break; \
case 13: \
CASE(OP, 13) \
break; \
case 16: \
CASE(OP, 16) \
break; \
default: \
IT_TODO_HALT(); \
}
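// The integer cases above appear to follow ONNX's TensorProto dtype codes
// (1 = float, 2 = uint8, 3 = int8, ..., 7 = int64, 10 = float16,
// 11 = double, 16 = bfloat16). For one concrete case, SWITCH_DTYPE(div, 10)
// reduces to a single typed launch (assuming DT_CUDA<10>::t maps to half):
//
//     _div_kernel<half><<<gridsize, blocksize>>>(
//         a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);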
namespace infini {
void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) {
int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size();
_div_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1,
b2, b3, c0, c1, c2, c3);
SWITCH_DTYPE(div, dType)
}
void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) {
int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size();
_add_kernel<int64_t><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0,
b1, b2, b3, c0, c1, c2, c3);
SWITCH_DTYPE(add, dType)
}
void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) {
int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size();
_pow_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1,
b2, b3, c0, c1, c2, c3);
if (dType == 1) {
_pow_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0,
b1, b2, b3, c0, c1, c2, c3);
} else if (dType == 3) {
_pow_kernel<int8_t><<<gridsize, blocksize>>>(
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
} else if (dType == 10) {
    // pow has no half-precision kernel here, so stage through float.
    // a, b and c are device pointers; they cannot be dereferenced on the
    // host, so the conversion round-trips through cudaMemcpy.
    int a_size = a0 * a1 * a2 * a3;
    int b_size = b0 * b1 * b2 * b3;
    int c_size = c0 * c1 * c2 * c3;
    vector<half> a_half(a_size), b_half(b_size), c_half(c_size);
    vector<float> a_float(a_size), b_float(b_size), c_float(c_size);
    cudaMemcpy(a_half.data(), a, a_size * sizeof(half), cudaMemcpyDeviceToHost);
    cudaMemcpy(b_half.data(), b, b_size * sizeof(half), cudaMemcpyDeviceToHost);
    for (int i = 0; i < a_size; ++i)
        a_float[i] = __half2float(a_half[i]);
    for (int i = 0; i < b_size; ++i)
        b_float[i] = __half2float(b_half[i]);
    float *a_dev, *b_dev, *c_dev;
    cudaMalloc(&a_dev, a_size * sizeof(float));
    cudaMalloc(&b_dev, b_size * sizeof(float));
    cudaMalloc(&c_dev, c_size * sizeof(float));
    cudaMemcpy(a_dev, a_float.data(), a_size * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(b_dev, b_float.data(), b_size * sizeof(float), cudaMemcpyHostToDevice);
    _pow_kernel<float><<<gridsize, blocksize>>>(
        a_dev, b_dev, c_dev, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
    cudaMemcpy(c_float.data(), c_dev, c_size * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < c_size; ++i)
        c_half[i] = __float2half(c_float[i]);
    cudaMemcpy(c, c_half.data(), c_size * sizeof(half), cudaMemcpyHostToDevice);
    cudaFree(a_dev);
    cudaFree(b_dev);
    cudaFree(c_dev);
} else {
IT_TODO_HALT();
}
}
void less_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
void less_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) {
int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size();
_less_kernel<int64_t><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0,
b1, b2, b3, c0, c1, c2, c3);
SWITCH_DTYPE(less, dType)
}
}; // namespace infini
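// An alternative to the host-staging fp16 fallback in pow_kernel above is a
// dedicated kernel that converts per element on the device. A sketch
// (broadcast indexing omitted for brevity; the name is illustrative):
__global__ void _pow_kernel_half_sketch(void *x, void *y, void *z, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        // compute in float, store back as half
        float xv = __half2float(((half *)x)[i]);
        float yv = __half2float(((half *)y)[i]);
        ((half *)z)[i] = __float2half(powf(xv, yv));
    }
}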

View File

@ -25,12 +25,12 @@ class ExpandCuda : public CudaKernelWithoutConfig {
inputShape.data[i] = in_Shape[i];
outputsize *= out_Shape[i];
}
expandKernel((float *)inputData, (float *)outputData, nDims, outputsize,
const int dType = op->getDType().getIndex();
expandKernel(dType, inputData, outputData, nDims, outputsize,
inputShape, outputShape);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Expand, DataType::Float32, ExpandCuda,
"Expand_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Expand, ExpandCuda, "Expand_CUDA");
}; // namespace infini

View File

@ -1,12 +1,14 @@
#include "core/common.h"
#include "cuda/cuda_common.h"
#include "cuda/cuda_utility.h"
#include "utils/small_array.h"
constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
__global__ void _expandKernel(float *input, float *output, int nDims,
template <class T>
__global__ void _expandKernel(void *input, void *output, int nDims,
int outputsize, infini::SmallArray inputShape,
infini::SmallArray outputShape) {
@ -33,17 +35,64 @@ __global__ void _expandKernel(float *input, float *output, int nDims,
temp *= inputShape.data[i];
v = v / outputShape.data[i];
}
output[outputIdx] = input[inputIdx];
((T *)output)[outputIdx] = ((T *)input)[inputIdx];
}
}
namespace infini {
void expandKernel(float *input, float *output, int nDims, int outputsize,
SmallArray inputShape, SmallArray outputShape) {
#define CASE(T) \
_expandKernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \
input, output, nDims, outputsize, inputShape, outputShape);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASE(1) \
break; \
case 2: \
CASE(2) \
break; \
case 3: \
CASE(3) \
break; \
case 4: \
CASE(4) \
break; \
case 5: \
CASE(5) \
break; \
case 6: \
CASE(6) \
break; \
case 7: \
CASE(7) \
break; \
case 10: \
CASE(10) \
break; \
case 11: \
CASE(11) \
break; \
case 12: \
CASE(12) \
break; \
case 13: \
CASE(13) \
break; \
case 16: \
CASE(16) \
break; \
default: \
IT_TODO_HALT(); \
}
void expandKernel(int dType, void *input, void *output, int nDims,
int outputsize, SmallArray inputShape,
SmallArray outputShape) {
int blocksize = block_work_size();
int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
_expandKernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
inputShape, outputShape);
SWITCH_DTYPE(dType)
}
} // namespace infini
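// The CASE/SWITCH_DTYPE machinery above relies on a DT_CUDA<index>::t trait
// that maps a dtype code to a CUDA element type (presumably defined in
// cuda/cuda_utility.h, included above). A hypothetical sketch of its shape:
//
//     template <int I> struct DT_CUDA;
//     template <> struct DT_CUDA<1> { using t = float; };
//     template <> struct DT_CUDA<3> { using t = int8_t; };
//     template <> struct DT_CUDA<7> { using t = int64_t; };
//     template <> struct DT_CUDA<10> { using t = half; };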

View File

@ -8,6 +8,7 @@ class ExtendCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ExtendObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto inData = op->getInputs(0)->getRawDataPtr<float *>();
auto outData = op->getOutputs()[0]->getRawDataPtr<float *>();
int blockSize = 1;
@ -22,6 +23,5 @@ class ExtendCuda : public CudaKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Extend, DataType::Float32, ExtendCuda,
"Extend_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Extend, ExtendCuda, "Extend_CUDA");
} // namespace infini

View File

@ -15,12 +15,23 @@ class GatherCuda : public CudaKernelWithoutConfig {
GatherMetaData metaData;
initGatherMetaData(metaData, op);
auto inData = input->getRawDataPtr<float *>();
auto outData = op->getOutput()->getRawDataPtr<float *>();
gather_kernel(inData, outData, metaData, op->getOutput()->size());
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
if (op->getDType() == DataType::Float32) {
gather_kernel<float>((float *)inputData, (float *)outputData,
metaData, op->getOutput()->size());
} else if (op->getDType() == DataType::Float16) {
gather_kernel<half>((half *)inputData, (half *)outputData, metaData,
op->getOutput()->size());
} else if (op->getDType() == DataType::Int8) {
gather_kernel<int8_t>((int8_t *)inputData, (int8_t *)outputData,
metaData, op->getOutput()->size());
} else {
IT_ASSERT(false);
}
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Gather, DataType::Float32, GatherCuda,
"Gather_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Gather, GatherCuda, "Gather_CUDA");
} // namespace infini

View File

@ -28,27 +28,32 @@ __device__ T gatheredOffset2Offset(int gOffset,
return offset;
}
template <typename T>
__global__ void _gather_kernel(float *in, float *out,
template <typename dataT, typename T>
__global__ void _gather_kernel(dataT *in, dataT *out,
infini::GatherMetaData metaData, size_t num) {
T tid = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
while (tid < num) {
if (tid < num) {
T offset = gatheredOffset2Offset<T>(tid, metaData);
out[tid] = in[offset];
tid += stride;
}
}
namespace infini {
void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num) {
template <typename T>
void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num) {
int blockSize = 32 * 16;
int gridSize = (num + blockSize - 1) / blockSize;
if (metaData.indexType == DataType::Int64) {
_gather_kernel<int64_t>
_gather_kernel<T, int64_t>
<<<gridSize, blockSize>>>(in, out, metaData, num);
} else {
_gather_kernel<int><<<gridSize, blockSize>>>(in, out, metaData, num);
_gather_kernel<T, int><<<gridSize, blockSize>>>(in, out, metaData, num);
}
}
template void gather_kernel<float>(float *in, float *out,
GatherMetaData metaData, size_t num);
template void gather_kernel<half>(half *in, half *out, GatherMetaData metaData,
size_t num);
template void gather_kernel<int8_t>(int8_t *in, int8_t *out,
GatherMetaData metaData, size_t num);
} // namespace infini
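// The explicit instantiations above let the host-side gather.cc, which sees
// only a declaration of gather_kernel<T>, link against the float, half and
// int8_t variants. The matching declaration would look like this (assumed
// to live in a shared CUDA gather header):
//
//     namespace infini {
//     template <typename T>
//     void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num);
//     } // namespace infini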

View File

@ -21,8 +21,7 @@ class GatherElementsCuda : public CudaKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, DataType::Float32,
GatherElementsCuda, "GatherELements_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, DataType::Int32,
GatherElementsCuda, "GatherElements_CUDA_Int32");
REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, GatherElementsCuda,
"GatherELements_CUDA");
} // namespace infini

View File

@ -24,8 +24,10 @@ class LayerNormCuda : public CudaKernelWithoutConfig {
int dimsize = dims[op->getAxis()];
int size = op->getOutput(0)->size();
int scaleSize = op->getInputs(1)->size();
if (op->getDType() == DataType::Float32) {
if (op->numInputs() == 3) {
void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const biasData =
(op->getInputs(2)->getRawDataPtr<void *>());
int biasSize = op->getInputs(2)->size();
// printf("kernel bias:true:%d\n", 1);
LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
@ -36,10 +38,27 @@ class LayerNormCuda : public CudaKernelWithoutConfig {
LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
scaleSize, dimsize, stride, (float *)outputData);
}
} else if (op->getDType() == DataType::Float16) {
if (op->numInputs() == 3) {
void *const biasData =
(op->getInputs(2)->getRawDataPtr<void *>());
int biasSize = op->getInputs(2)->size();
// printf("kernel bias:true:%d\n", 1);
LaynormKernel((half *)inputData, (half *)scaleData, eps, size,
scaleSize, dimsize, stride, (half *)outputData,
(half *)biasData, biasSize);
} else {
// printf("kernel bias:false:%d\n", 0);
LaynormKernel((half *)inputData, (half *)scaleData, eps, size,
scaleSize, dimsize, stride, (half *)outputData);
}
} else {
IT_ASSERT(false);
}
}
};
REGISTER_KERNEL(Device::CUDA, OpType::LayerNormalization, DataType::Float32,
LayerNormCuda, "LayerNorm_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::LayerNormalization, LayerNormCuda,
"LayerNorm_CUDA");
}; // namespace infini
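// Both dtype branches above evaluate the same normalization over each slice
// of length dimsize:
//
//     mu      = (1/dimsize) * sum_i x[i]
//     sigma^2 = (1/dimsize) * sum_i (x[i] - mu)^2
//     y[i]    = scale[i] * (x[i] - mu) / sqrt(sigma^2 + eps) (+ bias[i])
//
// The kernels in the next file fold the division into a float-precision
// reciprocal (__fdividef) so the half instantiation never divides in half
// precision.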

View File

@ -1,43 +1,41 @@
#include "cuda/cuda_common.h"
#include <cub/cub.cuh>
template <int BLOCK_DIM>
template <typename T, int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__
void blockLaynormKernel(const float *input, const float *scale,
const int dimsize, const int stride, float *output,
const float eps, int scaleSize, const float *bias,
int biasSize) {
void blockLaynormKernel(const T *input, const T *scale, const int dimsize,
const int stride, T *output, const T eps,
int scaleSize, const T *bias, int biasSize) {
// len(scale) = len(bias) = dimsize
int tmp = blockIdx.x % stride;
int tid = (blockIdx.x - tmp) * dimsize + tmp;
float muPartial = 0.0f;
T muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
}
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float mu;
float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
__shared__ T mu;
T muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
if (threadIdx.x ==
0) { // only threadIdx.x == 0 writes the result back to memory
mu = muBlock / dimsize;
mu = muBlock * static_cast<T>(__fdividef(1.0F, dimsize));
}
__syncthreads();
float sigma2Partial = 0.0f;
T sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
sigma2Partial +=
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
}
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
__shared__ float sigma2;
float sigma2Block =
BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
__shared__ T sigma2;
T sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
if (threadIdx.x ==
0) { // only threadIdx.x == 0 writes the result back to memory
sigma2 = sigma2Block / dimsize;
sigma2 = sigma2Block * static_cast<T>(__fdividef(1.0F, dimsize));
}
__syncthreads();
if (biasSize == dimsize) {
@ -47,8 +45,9 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) /
sqrt(sigma2 + eps) +
mu) *
static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
bias[threadIdx.x + ph * BLOCK_DIM];
}
} else {
@ -57,8 +56,9 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) /
sqrt(sigma2 + eps) +
mu) *
static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
bias[threadIdx.x + ph * BLOCK_DIM];
}
}
@ -69,8 +69,9 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) /
sqrt(sigma2 + eps) +
mu) *
static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
bias[0];
}
} else {
@ -79,50 +80,50 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
mu) /
sqrt(sigma2 + eps) +
mu) *
static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
bias[0];
}
}
}
}
//-----------------
template <int BLOCK_DIM>
template <typename T, int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__
void blockLaynormKernel(const float *input, const float *scale,
const int dimsize, const int stride, float *output,
const float eps, int scaleSize) {
void blockLaynormKernel(const T *input, const T *scale, const int dimsize,
const int stride, T *output, const T eps,
int scaleSize) {
// len(scale) = len(bias) = dimsize
int tmp = blockIdx.x % stride;
int tid = (blockIdx.x - tmp) * dimsize + tmp;
float muPartial = 0.0f;
T muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
}
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float mu;
float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
__shared__ T mu;
T muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
if (threadIdx.x ==
0) { // only threadIdx.x == 0 writes the result back to memory
mu = muBlock / dimsize;
mu = muBlock * static_cast<T>(__fdividef(1.0F, dimsize));
}
__syncthreads();
float sigma2Partial = 0.0f;
T sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
sigma2Partial +=
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
}
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
__shared__ float sigma2;
float sigma2Block =
BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
__shared__ T sigma2;
T sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
if (threadIdx.x ==
0) { // only threadIdx.x == 0 writes the result back to memory
sigma2 = sigma2Block / dimsize;
sigma2 = sigma2Block * static_cast<T>(__fdividef(1.0F, dimsize));
}
__syncthreads();
if (scaleSize == dimsize) {
@ -130,16 +131,18 @@ __launch_bounds__(BLOCK_DIM) __global__
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) /
sqrt(sigma2 + eps);
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
static_cast<T>(
__fdividef(1.0F, sqrt(static_cast<float>(sigma2 + eps))));
}
} else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) /
sqrt(sigma2 + eps);
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
static_cast<T>(
__fdividef(1.0F, sqrt(static_cast<float>(sigma2 + eps))));
}
}
}
@ -158,33 +161,33 @@ __inline__ __device__ T WarpAllReduce(T val) {
}
return val;
}
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void warpLaynormKernel(const float *input, const float *scale,
template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void warpLaynormKernel(const T *input, const T *scale,
const int dimsize, const int stride,
float *output, const float eps, int scaleSize,
int otherSize, const float *bias,
int biasSize) {
T *output, const T eps, int scaleSize,
int otherSize, const T *bias, int biasSize) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
if (otherIdx < otherSize) {
__shared__ float muTotal[BLOCK_DIM_y];
__shared__ float sigma2Total[BLOCK_DIM_y];
__shared__ T muTotal[BLOCK_DIM_y];
__shared__ T sigma2Total[BLOCK_DIM_y];
float muPartial = 0.0f;
T muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
}
muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial);
muPartial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(muPartial);
if (threadIdx.x == 0)
muTotal[threadIdx.y] = muPartial / dimsize;
muTotal[threadIdx.y] =
muPartial * static_cast<T>(__fdividef(1.0F, dimsize));
//--------------------------------------------
float sigma2Partial = 0.0f;
T sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
sigma2Partial +=
@ -194,10 +197,11 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
muTotal[threadIdx.y]);
}
sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial);
sigma2Partial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(sigma2Partial);
if (threadIdx.x == 0)
sigma2Total[threadIdx.y] = sigma2Partial / dimsize;
sigma2Total[threadIdx.y] =
sigma2Partial * static_cast<T>(__fdividef(1.0F, dimsize));
//--------------------------------------------
if (biasSize == dimsize) {
@ -209,8 +213,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
scale[threadIdx.x + ph * BLOCK_DIM_x] *
(input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps) +
muTotal[threadIdx.y]) *
static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(
sigma2Total[threadIdx.y] + eps)))) +
bias[threadIdx.x + ph * BLOCK_DIM_x];
}
} else {
@ -221,8 +227,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
scale[0] *
(input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps) +
muTotal[threadIdx.y]) *
static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(
sigma2Total[threadIdx.y] + eps)))) +
bias[threadIdx.x + ph * BLOCK_DIM_x];
}
}
@ -235,8 +243,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
scale[threadIdx.x + ph * BLOCK_DIM_x] *
(input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps) +
muTotal[threadIdx.y]) *
static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(
sigma2Total[threadIdx.y] + eps)))) +
bias[0];
}
} else {
@ -247,40 +257,43 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
scale[0] *
(input[tid +
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps) +
muTotal[threadIdx.y]) *
static_cast<T>(__fdividef(
1.0F, sqrt(static_cast<float>(
sigma2Total[threadIdx.y] + eps)))) +
bias[0];
}
}
}
}
}
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void warpLaynormKernel(const float *input, const float *scale,
template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void warpLaynormKernel(const T *input, const T *scale,
const int dimsize, const int stride,
float *output, const float eps, int scaleSize,
T *output, const T eps, int scaleSize,
int otherSize) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
if (otherIdx < otherSize) {
__shared__ float muTotal[BLOCK_DIM_y];
__shared__ float sigma2Total[BLOCK_DIM_y];
__shared__ T muTotal[BLOCK_DIM_y];
__shared__ T sigma2Total[BLOCK_DIM_y];
float muPartial = 0.0f;
T muPartial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
}
muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial);
muPartial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(muPartial);
if (threadIdx.x == 0)
muTotal[threadIdx.y] = muPartial / dimsize;
muTotal[threadIdx.y] =
muPartial * static_cast<T>(__fdividef(1.0F, dimsize));
//--------------------------------------------
float sigma2Partial = 0.0f;
T sigma2Partial = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
sigma2Partial +=
@ -290,10 +303,11 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
muTotal[threadIdx.y]);
}
sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial);
sigma2Partial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(sigma2Partial);
if (threadIdx.x == 0)
sigma2Total[threadIdx.y] = sigma2Partial / dimsize;
sigma2Total[threadIdx.y] =
sigma2Partial * static_cast<T>(__fdividef(1.0F, dimsize));
//--------------------------------------------
if (scaleSize == dimsize) {
@ -302,8 +316,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[threadIdx.x + ph * BLOCK_DIM_x] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps);
muTotal[threadIdx.y]) *
static_cast<T>(
__fdividef(1.0F, sqrt(static_cast<float>(
sigma2Total[threadIdx.y] + eps))));
}
} else {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
@ -311,8 +327,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
scale[0] *
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
muTotal[threadIdx.y]) /
sqrt(sigma2Total[threadIdx.y] + eps);
muTotal[threadIdx.y]) *
static_cast<T>(
__fdividef(1.0F, sqrt(static_cast<float>(
sigma2Total[threadIdx.y] + eps))));
}
}
}
@ -325,7 +343,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockLaynormKernel<1024>
blockLaynormKernel<float, 1024>
<<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
eps, scaleSize, bias, biasSize);
} else if (dimsize > 31) {
@ -335,7 +353,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>(
warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 15) {
@ -345,7 +363,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>(
warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 7) {
@ -355,7 +373,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>(
warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else {
@ -365,7 +383,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>(
warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
}
@ -378,7 +396,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockLaynormKernel<1024><<<num_block, BLOCK_DIM>>>(
blockLaynormKernel<float, 1024><<<num_block, BLOCK_DIM>>>(
input, scale, dimsize, stride, output, eps, scaleSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
@ -387,7 +405,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>(
warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
@ -396,7 +414,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>(
warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
@ -405,7 +423,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>(
warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else {
int BLOCK_DIM_x = 4;
@ -414,7 +432,108 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>(
warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
}
}
//-----------------
void LaynormKernel(const half *input, const half *scale, const half eps,
int size, int scaleSize, const int dimsize, const int stride,
half *output, const half *bias, int biasSize) {
int num_block = size / dimsize;
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockLaynormKernel<half, 1024>
<<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
eps, scaleSize, bias, biasSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
}
}
void LaynormKernel(const half *input, const half *scale, const half eps,
int size, int scaleSize, const int dimsize, const int stride,
half *output) {
int num_block = size / dimsize;
if (dimsize > 1024) {
int BLOCK_DIM = 1024;
blockLaynormKernel<half, 1024><<<num_block, BLOCK_DIM>>>(
input, scale, dimsize, stride, output, eps, scaleSize);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>(
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
}
}
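// The float and half LaynormKernel overloads above are line-for-line
// duplicates; a templated launcher could collapse them. A minimal sketch
// (hypothetical helper; only two of the five size buckets shown):
template <typename T>
void laynormLaunchSketch(const T *input, const T *scale, const T eps,
                         int size, int scaleSize, const int dimsize,
                         const int stride, T *output) {
    int num_block = size / dimsize;
    if (dimsize > 1024) {
        blockLaynormKernel<T, 1024><<<num_block, 1024>>>(
            input, scale, dimsize, stride, output, eps, scaleSize);
    } else {
        dim3 block_dim(32, 32, 1);
        dim3 grid_dim((num_block + 31) / 32, 1, 1);
        warpLaynormKernel<T, 32, 32><<<grid_dim, block_dim>>>(
            input, scale, dimsize, stride, output, eps, scaleSize, num_block);
    }
}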

View File

@ -2,6 +2,7 @@
#include "core/kernel.h"
#include "cuda/cuda_expand.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "utils/small_array.h"
namespace infini {
@ -48,11 +49,12 @@ class matmulCublas : public Kernel {
auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N;
const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? k : n,
ldc = n;
float alpha = 1.f, beta = 0.f;
if (op->numInputs() == 2) { // no bias
beta = 0.f;
} else { // broadcast bias to output
beta = 1.f;
float alpha_naive = 1.f, beta_naive = 0.f;
auto dataType = op->getDType();
auto cuDataType = cublasDataTypeConvert(dataType);
IT_ASSERT(cuDataType != CUDA_R_8I, "matmul don't support int8 dtype.");
if (op->numInputs() == 3) { // have bias
beta_naive = 1.f;
auto inC = op->getInputs(2);
auto out = op->getOutput();
SmallArray inputShape, outputShape;
@ -69,8 +71,9 @@ class matmulCublas : public Kernel {
if (i >= offset)
inputShape.data[i] = inC->getDims()[i - offset];
}
expandKernel(inC->getRawDataPtr<float *>(),
out->getRawDataPtr<float *>(), nDims, outputsize,
const int dType = dataType.getIndex();
expandKernel(dType, inC->getRawDataPtr<void *>(),
out->getRawDataPtr<void *>(), nDims, outputsize,
inputShape, outputShape);
}
// TODO:use compute type
@ -89,16 +92,38 @@ class matmulCublas : public Kernel {
(dimB == 3 && op->getInputs(1)->getDims()[0] == 1))
? 0 // Broadcast the batch dimension if batch size is 1
: n * k;
if (dataType == DataType::Float16) {
half alpha_half = static_cast<half>(alpha_naive);
half beta_half = static_cast<half>(beta_naive);
stat = cublasGemmStridedBatchedEx(
context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
CUDA_R_32F, ldb, strideB, inAData, CUDA_R_32F, lda, strideA,
&beta, outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F,
context->cublasHandle(), opB, opA, n, m, k, &alpha_half,
inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
strideA, &beta_half, outData, cuDataType, ldc, m * n, b,
cuDataType, (cublasGemmAlgo_t)record->algo);
} else {
stat = cublasGemmStridedBatchedEx(
context->cublasHandle(), opB, opA, n, m, k, &alpha_naive,
inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
strideA, &beta_naive, outData, cuDataType, ldc, m * n, b,
cuDataType, (cublasGemmAlgo_t)record->algo);
}
} else {
if (dataType == DataType::Float16) {
half alpha_half = static_cast<half>(alpha_naive);
half beta_half = static_cast<half>(beta_naive);
stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
&alpha_half, inBData, cuDataType, ldb,
inAData, cuDataType, lda, &beta_half,
outData, cuDataType, ldc, cuDataType,
(cublasGemmAlgo_t)record->algo);
} else {
stat = cublasGemmEx(
context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData,
CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo);
stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
&alpha_naive, inBData, cuDataType, ldb,
inAData, cuDataType, lda, &beta_naive,
outData, cuDataType, ldc, cuDataType,
(cublasGemmAlgo_t)record->algo);
}
}
// if (stat != CUBLAS_STATUS_SUCCESS)
// cout << cublasGetErrorString(stat);
@ -140,8 +165,9 @@ class matmulCublas : public Kernel {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::MatMul, DataType::Float32, matmulCublas,
"Matmul_cuBLAS_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::MatMul, matmulCublas,
"Matmul_cuBLAS_CUDA");
REGISTER_CONSTRUCTOR(2, MatmulCublasPerfRecordObj::from_json);
}; // namespace infini
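// Passing cuDataType as the last type argument above also makes it the
// compute type, so fp16 GEMMs accumulate in half. If fp32 accumulation were
// preferred for accuracy, the fp16 branch could keep float alpha/beta and
// request a 32-bit compute type instead (sketch of the non-batched call):
//
//     stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
//                         &alpha_naive, inBData, CUDA_R_16F, ldb, inAData,
//                         CUDA_R_16F, lda, &beta_naive, outData, CUDA_R_16F,
//                         ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo);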

View File

@ -229,9 +229,8 @@ class MemboundTVMExtractSource : public Kernel {
}
};
// REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32,
// MemboundTVMExtractSource,
// "Memobund_TVM_Ansor_extract_source");
REGISTER_KERNEL(Device::CUDA, OpType::MemBound, MemboundTVMExtractSource,
"Memobund_TVM_Ansor_extract_source");
}; // namespace infini
#endif

View File

@ -216,9 +216,9 @@ class MemboundTVMPackedFunction : public Kernel {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32,
MemboundTVMPackedFunction,
REGISTER_KERNEL(Device::CUDA, OpType::MemBound, MemboundTVMPackedFunction,
"Memobund_TVM_Ansor_packed_funciton");
}; // namespace infini
#endif

View File

@ -39,10 +39,8 @@ class SliceCuda : private PadSliceCudaCompute, public CudaKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Float32, SliceCuda,
"Slice__CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Int64, SliceCuda,
"Slice__CUDA_Int64");
REGISTER_KERNEL(Device::CUDA, OpType::Pad, DataType::Float32, PadCuda,
"Pad__CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Slice, SliceCuda, "Slice__CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Pad, PadCuda, "Pad__CUDA");
} // namespace infini

View File

@ -1,6 +1,7 @@
#include "core/data_type.h"
#include "cuda/cuda_common.h"
#include "cuda/cuda_pad_slice.h"
#include "cuda/cuda_utility.h"
__device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
TransMetaData metaData,
@ -21,39 +22,83 @@ __device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
}
template <typename T>
__global__ void _pad_slice_kernel(T *part, T *whole, TransMetaData metaData,
int nDims, int num, bool isPad) {
__global__ void _pad_slice_kernel(void *part, void *whole,
TransMetaData metaData, int nDims, int num,
bool isPad) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num)
if (tid >= num) {
return;
}
int stride = blockDim.x * gridDim.x;
while (tid < num) {
int offset = WholeTensorOffset2PartTensorOffset(tid, metaData, nDims);
if (isPad)
if (offset < 0)
whole[tid] = 0;
else
whole[tid] = part[offset];
else if (offset >= 0)
part[offset] = whole[tid];
if (isPad) {
if (offset < 0) {
((T *)whole)[tid] = static_cast<T>(0.f);
} else {
((T *)whole)[tid] = ((T *)part)[offset];
}
} else if (offset >= 0) {
((T *)part)[offset] = ((T *)whole)[tid];
}
tid += stride;
}
}
namespace infini {
#define CASE(T) \
_pad_slice_kernel<DT_CUDA<T>::t><<<gridSize, blockSize>>>( \
partData, wholeData, metadata, nDims, num, isPad);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASE(1) \
break; \
case 2: \
CASE(2) \
break; \
case 3: \
CASE(3) \
break; \
case 4: \
CASE(4) \
break; \
case 5: \
CASE(5) \
break; \
case 6: \
CASE(6) \
break; \
case 7: \
CASE(7) \
break; \
case 10: \
CASE(10) \
break; \
case 11: \
CASE(11) \
break; \
case 12: \
CASE(12) \
break; \
case 13: \
CASE(13) \
break; \
case 16: \
CASE(16) \
break; \
default: \
IT_TODO_HALT(); \
}
void pad_slice_kernel(void *partData, void *wholeData,
const TransMetaData &metadata, int nDims, int num,
bool isPad) {
int blockSize = 32 * 16;
int gridSize = (num + blockSize - 1) / blockSize;
if (metadata.DType == DataType::Int64.getIndex()) {
_pad_slice_kernel<int64_t>
<<<gridSize, blockSize>>>((int64_t *)partData, (int64_t *)wholeData,
metadata, nDims, num, isPad);
} else if (metadata.DType == DataType::Float32.getIndex()) {
_pad_slice_kernel<float><<<gridSize, blockSize>>>(
(float *)partData, (float *)wholeData, metadata, nDims, num, isPad);
}
int dType = metadata.DType;
SWITCH_DTYPE(dType)
}
} // namespace infini

View File

@ -8,6 +8,7 @@ class poolingCudnn : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<PoolingObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
@ -76,8 +77,9 @@ class avgPoolCudnn : public poolingCudnn {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::MaxPool, DataType::Float32, maxPoolCudnn,
"MaxPool_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::AveragePool, DataType::Float32,
avgPoolCudnn, "AvgPool_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::MaxPool, maxPoolCudnn,
"MaxPool_cuDNN_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::AveragePool, avgPoolCudnn,
"AvgPool_cuDNN_CUDA");
}; // namespace infini

View File

@ -40,8 +40,7 @@ class RecvNCCL : public CudaKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Recv, DataType::Float32, RecvNCCL,
"Recv_NCCL_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Recv, RecvNCCL, "Recv_NCCL_CUDA");
} // namespace infini
#endif

View File

@ -1,6 +1,7 @@
#include "operators/reduce.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
namespace infini {
class ReduceCudnnBase : public CudaKernelWithoutConfig {
@ -46,12 +47,12 @@ class ReduceCudnnBase : public CudaKernelWithoutConfig {
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
cudnnTensorDescriptor_t outDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
auto cudnnDataType = cudnnDataTypeConvert(op->getDType());
if (nInDims > 3) {
checkCudnnError(cudnnSetTensorNdDescriptor(
inDesc, CUDNN_DATA_FLOAT, nInDims, inDimArray, inStrideArray));
checkCudnnError(
cudnnSetTensorNdDescriptor(outDesc, CUDNN_DATA_FLOAT, nInDims,
outDimArray, outStrideArray));
inDesc, cudnnDataType, nInDims, inDimArray, inStrideArray));
checkCudnnError(cudnnSetTensorNdDescriptor(
outDesc, cudnnDataType, nInDims, outDimArray, outStrideArray));
} else {
int idims[4] = {1, 1, 1, 1}, odims[4] = {1, 1, 1, 1};
for (int i = 0; i < nInDims; ++i) {
@ -62,20 +63,19 @@ class ReduceCudnnBase : public CudaKernelWithoutConfig {
}
checkCudnnError(cudnnSetTensor4dDescriptor(
inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, idims[0], idims[1],
inDesc, CUDNN_TENSOR_NCHW, cudnnDataType, idims[0], idims[1],
idims[2], idims[3]));
checkCudnnError(cudnnSetTensor4dDescriptor(
outDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, odims[0],
odims[1], odims[2], odims[3]));
outDesc, CUDNN_TENSOR_NCHW, cudnnDataType, odims[0], odims[1],
odims[2], odims[3]));
}
// get reduce descriptor
cudnnReduceTensorDescriptor_t reduceDesc;
checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduceDesc));
checkCudnnError(cudnnSetReduceTensorDescriptor(
reduceDesc, getReduceOp(), CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES,
CUDNN_32BIT_INDICES));
reduceDesc, getReduceOp(), cudnnDataType, CUDNN_NOT_PROPAGATE_NAN,
CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES));
// get workspace
size_t workspaceSize = 0;
@ -120,8 +120,9 @@ class ReduceSumCudnn : public ReduceCudnnBase {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::ReduceMean, DataType::Float32,
ReduceMeanCudnn, "ReduceMean_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::ReduceSum, DataType::Float32,
ReduceSumCudnn, "ReduceSum_cuDNN_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::ReduceMean, ReduceMeanCudnn,
"ReduceMean_cuDNN_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::ReduceSum, ReduceSumCudnn,
"ReduceSum_cuDNN_CUDA");
}; // namespace infini
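// cudnnDataTypeConvert (from cuda/cuda_utility.h, included above) presumably
// maps the framework DataType onto the cuDNN enum; a sketch of its shape:
//
//     cudnnDataType_t cudnnDataTypeConvert(DataType t) {
//         if (t == DataType::Float32) return CUDNN_DATA_FLOAT;
//         if (t == DataType::Float16) return CUDNN_DATA_HALF;
//         IT_TODO_HALT();
//     }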

View File

@ -11,19 +11,12 @@ class CopyCuda : public CudaKernelWithoutConfig {
}
};
// reshape/flatten/identity all act as copying from input to output.
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Float32, CopyCuda,
"Reshape_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int64, CopyCuda,
"Reshape_CUDA_Int64");
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int32, CopyCuda,
"Reshape_CUDA_Int32");
REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda,
"Flatten_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Squeeze, DataType::Float32, CopyCuda,
"Squeeze_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Unsqueeze, DataType::Float32, CopyCuda,
"Unsqueeze_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda,
"Identity_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, CopyCuda, "Reshape_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Flatten, CopyCuda, "Flatten_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Identity, CopyCuda, "Identity_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Squeeze, CopyCuda, "Squeeze_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Unsqueeze, CopyCuda, "Unsqueeze_CUDA");
} // namespace infini
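// A single registration per op suffices here because CopyCuda never
// interprets the element type; it only moves bytes. A hypothetical body
// making that explicit (the real one is elided in this hunk):
//
//     void compute(const Operator &op, const RuntimeObj *ctx) const override {
//         auto in = op->getInputs(0);
//         auto out = op->getOutput();
//         cudaMemcpyAsync(out->getRawDataPtr<void *>(),
//                         in->getRawDataPtr<void *>(), in->getBytes(),
//                         cudaMemcpyDeviceToDevice);
//     }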

View File

@ -6,6 +6,7 @@ class ResizeCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ResizeObj>(_op);
IT_ASSERT(op->getDType() == DataType::Float32);
auto in = op->getInputs(0);
auto out = op->getOutputs()[0];
@ -48,7 +49,6 @@ class ResizeCuda : public CudaKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Resize, DataType::Float32, ResizeCuda,
"Resize_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Resize, ResizeCuda, "Resize_CUDA");
} // namespace infini

View File

@ -36,8 +36,7 @@ class SendNCCL : public CudaKernelWithoutConfig {
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Send, DataType::Float32, SendNCCL,
"Send_NCCL_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Send, SendNCCL, "Send_NCCL_CUDA");
} // namespace infini
#endif

View File

@ -20,11 +20,17 @@ class SoftmaxCuda : public CudaKernelWithoutConfig {
int stride = op->getInputs(0)->getStride().at(op->getAxis());
int num_blocks = size / dimsize;
if (op->getDType() == DataType::Float32) {
softmax_kernel(num_blocks, (float *)input, (float *)output, size,
dimsize, stride);
} else if (op->getDType() == DataType::Float16) {
softmax_kernel(num_blocks, (half *)input, (half *)output, size,
dimsize, stride);
} else {
IT_ASSERT(false);
}
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, SoftmaxCuda,
"Softmax_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Softmax, SoftmaxCuda, "Softmax_CUDA");
} // namespace infini

View File

@ -1,6 +1,5 @@
#include "cuda/cuda_common.h"
#include <cub/cub.cuh>
struct __align__(8) DataMaxSum { // update the global max and sum, store the
// output at max_tmp and sum_tmp
float max_tmp; // store max
@ -16,9 +15,9 @@ __device__ __forceinline__ DataMaxSum reduce_dms_op(DataMaxSum a,
return bigger;
}
template <int BLOCK_DIM>
template <typename T, int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel(
float *__restrict input, float *__restrict output, int size, int dimsize,
T *__restrict input, T *__restrict output, int size, int dimsize,
int stride) { // if set axis = 1, inputShape=[I,J,K,S]
// tid = i(JKS) + j(KS) + k(S) + s
@ -33,15 +32,33 @@ __launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel(
dms_partial.max_tmp = -__FLT_MAX__;
dms_partial.sum_tmp = 0.0f;
DataMaxSum dms_input;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
int remain = dimsize % BLOCK_DIM;
int step = (dimsize - remain) / BLOCK_DIM + 1; // step <= numPerThread
if (threadIdx.x < remain) {
for (int ind = 0; ind < step; ind++) {
dms_input.max_tmp =
input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
input[tid + (threadIdx.x * step + ind) * stride];
dms_input.sum_tmp = 1.0f;
dms_partial = reduce_dms_op(dms_partial,
dms_partial =
reduce_dms_op(dms_partial,
dms_input); // reduce the data to one block
}
} else {
for (int ind = 0; ind < step - 1; ind++) {
dms_input.max_tmp =
input[tid + (remain * step +
(threadIdx.x - remain) * (step - 1) + ind) *
stride];
dms_input.sum_tmp = 1.0f;
dms_partial =
reduce_dms_op(dms_partial,
dms_input); // reduce the data to one block
}
}
typedef cub::BlockReduce<DataMaxSum, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ DataMaxSum dms_total;
@ -53,13 +70,103 @@ __launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel(
}
__syncthreads();
//-----------------
if (threadIdx.x < remain) {
for (int ind = 0; ind < step; ind++) {
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
__expf(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
output[tid + (threadIdx.x * step + ind) * stride] =
__expf(static_cast<float>(
input[tid + (threadIdx.x * step + ind) * stride]) -
dms_total.max_tmp) *
__fdividef(1.0F, dms_total.sum_tmp);
}
} else {
for (int ind = 0; ind < step - 1; ind++) {
output[tid +
(remain * step + (threadIdx.x - remain) * (step - 1) + ind) *
stride] =
__expf(static_cast<float>(
input[tid +
(remain * step +
(threadIdx.x - remain) * (step - 1) + ind) *
stride]) -
dms_total.max_tmp) *
__fdividef(1.0F, dms_total.sum_tmp);
}
}
}
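// Coverage check for the remain/step partition used above: the first
// `remain` threads handle `step` elements each and the rest handle
// `step - 1`, which tiles the row exactly because
// BLOCK_DIM * (step - 1) + remain == dimsize.
constexpr bool softmaxPartitionCoversAll(int dimsize, int blockDim) {
    int remain = dimsize % blockDim;
    int step = (dimsize - remain) / blockDim + 1;
    return remain * step + (blockDim - remain) * (step - 1) == dimsize;
}
// e.g. dimsize = 10, BLOCK_DIM = 4: threads 0-1 take 3 elements, 2-3 take 2
static_assert(softmaxPartitionCoversAll(10, 4), "partition covers the row");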
template <typename T, int BLOCK_DIM, int numPerThread>
__global__ void
_blockSoftmaxKernel(T *__restrict input, T *__restrict output, int size,
int dimsize,
int stride) { // if set axis = 1, inputShape=[I,J,K,S]
// tid = i(JKS) + j(KS) + k(S) + s
// gridDim.x = size/dimsize = IKS
// blockIdx.x = i(KS) + k(S) + s, blockIdx.x % stride = k(S) + s
int tid =
blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) *
dimsize; // now, tid = i(JKS) + k(S) + s;
int remain = dimsize % BLOCK_DIM;
int step = (dimsize - remain) / BLOCK_DIM + 1; // step <= numPerThread
float dataPerThread[numPerThread];
DataMaxSum dms_partial;
dms_partial.max_tmp = -__FLT_MAX__;
dms_partial.sum_tmp = 0.0f;
DataMaxSum dms_input;
if (threadIdx.x < remain) {
for (int ind = 0; ind < step; ind++) {
dataPerThread[ind] =
input[tid + (threadIdx.x * step + ind) * stride];
dms_input.max_tmp = dataPerThread[ind];
dms_input.sum_tmp = 1.0f;
dms_partial =
reduce_dms_op(dms_partial,
dms_input); // reduce the data to one block
}
} else {
for (int ind = 0; ind < step - 1; ind++) {
dataPerThread[ind] =
input[tid + (remain * step +
(threadIdx.x - remain) * (step - 1) + ind) *
stride];
dms_input.max_tmp = dataPerThread[ind];
dms_input.sum_tmp = 1.0f;
dms_partial =
reduce_dms_op(dms_partial,
dms_input); // reduce the data to one block
}
}
typedef cub::BlockReduce<DataMaxSum, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ DataMaxSum dms_total;
DataMaxSum dms_block =
BlockReduce(temp_storage).Reduce(dms_partial, reduce_dms_op);
    if (threadIdx.x ==
        0) { // only thread 0 writes the block-wide result to shared memory
dms_total = dms_block;
}
__syncthreads();
//-----------------
if (threadIdx.x < remain) {
for (int ind = 0; ind < step; ind++) {
output[tid + (threadIdx.x * step + ind) * stride] =
__expf(dataPerThread[ind] - dms_total.max_tmp) *
__fdividef(1.0F, dms_total.sum_tmp);
}
} else {
for (int ind = 0; ind < step - 1; ind++) {
output[tid +
(remain * step + (threadIdx.x - remain) * (step - 1) + ind) *
stride] =
__expf(dataPerThread[ind] - dms_total.max_tmp) *
__fdividef(1.0F, dms_total.sum_tmp);
}
}
}
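// This variant differs from the one above only in the dataPerThread register
// cache: each input element is loaded from global memory once instead of
// twice (once for the max/sum reduction, once for the final store). It is
// valid only while step <= numPerThread, which the dispatchers below enforce
// through their dimsize thresholds.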
template <typename T> struct SumOp {
@@ -81,14 +188,14 @@ __inline__ __device__ T WarpAllReduce(T val) {
}
return val;
}
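// WarpAllReduce spreads the reduction across the BLOCK_DIM_x lanes of a row
// with shuffle intrinsics (typically a __shfl_xor_sync butterfly), so every
// lane ends up holding the full result without touching shared memory.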
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
__global__ void _warpSoftmaxKernel(float *__restrict input,
float *__restrict output, int size,
int dimsize, int stride) {
template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y, int numPerThreadx>
__global__ void _warpSoftmaxKernel(T *__restrict input, T *__restrict output,
int size, int dimsize, int stride) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int otherSize = size / dimsize;
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
float dataPerThreadx[numPerThreadx];
if (otherIdx < otherSize) {
__shared__ float max_total[BLOCK_DIM_y];
@@ -96,9 +203,9 @@ __global__ void _warpSoftmaxKernel(float *__restrict input,
float max_data = -__FLT_MAX__;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
max_data =
max(max_data,
input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride]);
dataPerThreadx[ph] =
input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
max_data = max(max_data, dataPerThreadx[ph]);
}
max_data = WarpAllReduce<MaxOp, float, BLOCK_DIM_x>(max_data);
@@ -110,9 +217,9 @@ __global__ void _warpSoftmaxKernel(float *__restrict input,
float sum_data = 0.0f;
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
sum_data +=
__expf(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
max_total[threadIdx.y]);
dataPerThreadx[ph] =
__expf(dataPerThreadx[ph] - max_total[threadIdx.y]);
sum_data += dataPerThreadx[ph];
}
sum_data = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sum_data);
@@ -124,9 +231,7 @@ __global__ void _warpSoftmaxKernel(float *__restrict input,
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
__expf(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
max_total[threadIdx.y]) *
__fdividef(1.0F, sum_total[threadIdx.y]);
dataPerThreadx[ph] * __fdividef(1.0F, sum_total[threadIdx.y]);
}
}
}
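// For short rows a 1024-thread block would be mostly idle, so each row is
// handled by the BLOCK_DIM_x lanes of one warp row while BLOCK_DIM_y rows
// share a block. The dataPerThreadx cache needs
//   numPerThreadx >= ceil(dimsize / BLOCK_DIM_x),
// which the dispatch below maintains, e.g. <32, 32, 32> covers dimsize <= 1024.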
@@ -137,10 +242,35 @@ namespace infini {
void softmax_kernel(int num_blocks, float *input, float *output, int size,
int dimsize, int stride) {
if (dimsize > 1024) {
if (dimsize > 1024 * 128) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<1024>
_blockSoftmaxKernel<float, 1024>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 64) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 128>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 32) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 64>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 16) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 32>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 4) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 16>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 4>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
@@ -149,7 +279,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<32, 32>
_warpSoftmaxKernel<float, 32, 32, 32>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
@@ -158,7 +288,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<16, 64>
_warpSoftmaxKernel<float, 16, 64, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
@@ -167,7 +297,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<8, 128>
_warpSoftmaxKernel<float, 8, 128, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
} else {
int BLOCK_DIM_x = 4;
@@ -176,7 +306,79 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<4, 256>
_warpSoftmaxKernel<float, 4, 256, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
}
}
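// A minimal host-side usage sketch (illustrative names, not part of this
// diff): softmax over axis 1 of a float tensor shaped [I, J, K, S].
inline void launch_softmax_axis1(float *d_in, float *d_out, int I, int J,
                                 int K, int S) {
    int size = I * J * K * S; // total element count
    int dimsize = J;          // length of the softmax axis
    int stride = K * S;       // distance between j and j+1 in memory
    softmax_kernel(size / dimsize, d_in, d_out, size, dimsize, stride);
}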
//------------------
void softmax_kernel(int num_blocks, half *input, half *output, int size,
int dimsize, int stride) {
if (dimsize > 1024 * 128) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 64) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 128>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 32) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 64>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 16) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 32>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 4) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 16>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 1024) {
int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 4>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
} else if (dimsize > 31) {
int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32;
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<half, 32, 32, 32>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
} else if (dimsize > 15) {
int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64;
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<half, 16, 64, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
} else if (dimsize > 7) {
int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128;
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<half, 8, 128, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
} else {
int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256;
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<half, 4, 256, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
}
}
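// The half path changes only the storage type: DataMaxSum and the per-thread
// caches stay float, __expf/__fdividef run in float, and the result is
// narrowed back to half on the final store, keeping the fp16 reduction
// numerically stable.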
@@ -7,7 +7,8 @@
namespace infini {
class CudaCompute {
void initComposedTensorMetadata(ComposedTensorMetadata &metadata,
template <typename T>
void initComposedTensorMetadata(ComposedTensorMetadata<T> &metadata,
Tensor tensor) const {
int nDims = tensor->getRank();
auto strides = tensor->getStride();
@@ -16,10 +17,10 @@ class CudaCompute {
metadata.dimSize[i] = tensor->getDims().at(i);
metadata.stride[i] = strides.at(i);
}
metadata.data = tensor->getRawDataPtr<float *>();
metadata.data = tensor->getRawDataPtr<T *>();
}
void initElementTensorMetadata(ElementTensorMetadata &metadata,
template <typename T>
void initElementTensorMetadata(ElementTensorMetadata<T> &metadata,
TensorVec tensors, int idx, int dim,
int &dimBgIdx, int &batchCounter) const {
int nTensors = tensors.size();
@@ -27,7 +28,7 @@ class CudaCompute {
++batchCounter) {
auto tensor = tensors.at(idx + batchCounter);
auto dimSize = tensor->getDims()[dim];
metadata.data[batchCounter] = tensor->getRawDataPtr<float *>();
metadata.data[batchCounter] = tensor->getRawDataPtr<T *>();
metadata.dimBgNo[batchCounter] = dimBgIdx;
metadata.dimSize[batchCounter] = dimSize;
metadata.nElements[batchCounter] = tensor->size();
@@ -36,17 +37,17 @@ class CudaCompute {
}
public:
template <typename T>
void do_compute(Tensor composedTensor, TensorVec elementsTensor, int dim,
int nDims, bool isSplit) const {
IT_ASSERT(nDims <= DIM_MAX_SIZE);
ComposedTensorMetadata composedMetadata;
initComposedTensorMetadata(composedMetadata, composedTensor);
ComposedTensorMetadata<T> composedMetadata;
initComposedTensorMetadata<T>(composedMetadata, composedTensor);
int dimBgNo = 0;
        int nElements = elementsTensor.size();
        for (int i = 0; i < nElements; i += BATCH_SIZE) {
ElementTensorMetadata elemMetadata;
ElementTensorMetadata<T> elemMetadata;
int batchCounter = 0;
initElementTensorMetadata(elemMetadata, elementsTensor, i, dim,
dimBgNo, batchCounter);
@@ -74,23 +75,38 @@ class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig {
}
}
}
do_compute(_op->getOutput(), _op->getInputs(),
as<ConcatObj>(_op)->getDim(), _op->getOutput()->getRank(),
false);
if (_op->getDType() == DataType::Float32) {
do_compute<float>(_op->getOutput(), _op->getInputs(),
as<ConcatObj>(_op)->getDim(),
_op->getOutput()->getRank(), false);
} else if (_op->getDType() == DataType::Float16) {
do_compute<half>(_op->getOutput(), _op->getInputs(),
as<ConcatObj>(_op)->getDim(),
_op->getOutput()->getRank(), false);
} else {
IT_ASSERT(false);
}
}
};
class SplitCuda : private CudaCompute, public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
do_compute(_op->getInputs(0), _op->getOutputs(),
as<SplitObj>(_op)->getDim(), _op->getInputs(0)->getRank(),
true);
if (_op->getDType() == DataType::Float32) {
do_compute<float>(_op->getInputs(0), _op->getOutputs(),
as<SplitObj>(_op)->getDim(),
_op->getInputs(0)->getRank(), true);
} else if (_op->getDType() == DataType::Float16) {
do_compute<half>(_op->getInputs(0), _op->getOutputs(),
as<SplitObj>(_op)->getDim(),
_op->getInputs(0)->getRank(), true);
} else {
IT_ASSERT(false);
}
}
};
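// The Float32/Float16 branches above differ only in the template argument; a
// small helper along these lines (a sketch only, the name is made up) could
// fold the dtype dispatch into one place:
//
// template <class Fn> void dispatchFloatHalf(const DataType &dt, Fn &&fn) {
//     if (dt == DataType::Float32)
//         fn(float{});
//     else if (dt == DataType::Float16)
//         fn(half{});
//     else
//         IT_ASSERT(false); // other dtypes unsupported here
// }
//
// ConcatCuda/SplitCuda would then call, e.g.:
//     dispatchFloatHalf(_op->getDType(), [&](auto t) {
//         do_compute<decltype(t)>(_op->getInputs(0), _op->getOutputs(),
//                                 as<SplitObj>(_op)->getDim(),
//                                 _op->getInputs(0)->getRank(), true);
//     });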
REGISTER_KERNEL(Device::CUDA, OpType::Concat, DataType::Float32, ConcatCuda,
"Concat_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Split, DataType::Float32, SplitCuda,
"Split_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Concat, ConcatCuda, "Concat_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Split, SplitCuda, "Split_CUDA");
} // namespace infini

Some files were not shown because too many files have changed in this diff.