forked from jiuyuan/InfiniTensor
Modify kernel registration & support fp16 (#205)
* - Remove dataType from the kernel registration. * - support fp16 for conv * - cpu kernel: adapt the new registration mechanism * modified all register kernel * add where fp16 * add layernorm fp16 * add split_concat fp16 * - element_wise support fp16 * feat: support transpose fp16 * feat: support sliceOp fp16 * - unary support fp16 * - feat: support reduceOp fp16 * feat: support matmulOp/expandOp fp16 * feat: support powOp int8 * add cuda cast & support half-precision for gather * style: fix style * feat:support int8 for gather * style:fix style * modified test_cuda_conv_transposed * fix: fix dist code to support fp16 * fix(graph.cc): fix topo_sort * fix: fix recv and send kernel registration * feat: add field tensors for stub * refactor(frontend): 先排序后构图 Signed-off-by: YdrMaster <ydrml@hotmail.com> * fix: 为中间结果提供tensor到node的mapping * fix (slice): add guard for area out of range * fix: fix matmul fp16 * fix: fix re-dataMalloc for weight tensor and use of naive allocator * feat: add dataType filter for cuda kernel * feat: bang kernel adapt the new registration mechanism * fix: fix some error on mlu * feat: intelcpu kernel adapt the new registration mechanism * feat: modify kernel registration on kunlun * fix intelcpu compiler bug * feat: bang reshape support all dataType * fix: fix bang reduce * fix(all_reduce.cc): fix as reviewer suggessted * fix: fix style and restore unary test codes --------- Signed-off-by: YdrMaster <ydrml@hotmail.com> Co-authored-by: xgqdut2016 <kenan_gewei@163.com> Co-authored-by: xgqdut2016 <140036308+xgqdut2016@users.noreply.github.com> Co-authored-by: zhangyunze <z13785159769@163.com> Co-authored-by: OdinaryWord <sx-hz@163.com> Co-authored-by: YdrMaster <ydrml@hotmail.com> Co-authored-by: panzezhong <panzezhong@qiyuanlab.com>
This commit is contained in:
parent
58993d4339
commit
51086d2b8d
|
@ -137,7 +137,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
place[node.output[0]] = Shard(list(perm).index(plc.dim))
|
||||
|
||||
def shard_node(node: NodeProto):
|
||||
if node.op_type in ["Relu", "Tanh", "Softmax"]:
|
||||
if node.op_type in ["Relu", "Tanh", "Softmax", "Cast"]:
|
||||
place[node.output[0]] = place[node.input[0]]
|
||||
elif node.op_type in ["Where"]:
|
||||
place[node.output[0]] = place[node.input[1]]
|
||||
|
@ -177,7 +177,14 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
|
|||
input in data for input in node.input
|
||||
):
|
||||
# FIXME(constroy): the last MatMul should not be sharded as TP.
|
||||
if node.output[0] in output:
|
||||
if (
|
||||
node.output[0] in output
|
||||
or (
|
||||
index + 1 < len(model.graph.node)
|
||||
and model.graph.node[index + 1].output[0]
|
||||
)
|
||||
in output
|
||||
):
|
||||
continue
|
||||
groups = 1
|
||||
# If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups
|
||||
|
|
|
@ -30,7 +30,6 @@ class Kernel {
|
|||
public:
|
||||
Kernel() {}
|
||||
virtual ~Kernel() {}
|
||||
|
||||
/**
|
||||
* @param op The operator to be executed.
|
||||
* @param record The parameters for kernel execution. If extra parameters
|
||||
|
@ -130,15 +129,16 @@ class CpuKernelWithoutConfig : public Kernel {
|
|||
|
||||
} // namespace infini
|
||||
|
||||
#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, cnt) \
|
||||
#define _REGISTER_KERNEL_1(device, opType, kernel, name, cnt) \
|
||||
namespace infini { \
|
||||
static const bool _CAT(_register_kernel_, cnt) = \
|
||||
KernelRegistry::getInstance().registerKernel( \
|
||||
KernelAttrs{device, opType, dataType}, new kernel(), name); \
|
||||
KernelRegistry::getInstance().registerKernel(KernelAttrs{device, \
|
||||
opType}, \
|
||||
new kernel(), name); \
|
||||
}
|
||||
|
||||
#define REGISTER_KERNEL(device, opType, dataType, kernel, name) \
|
||||
_REGISTER_KERNEL_1(device, opType, dataType, kernel, name, __COUNTER__)
|
||||
#define REGISTER_KERNEL(device, opType, kernel, name) \
|
||||
_REGISTER_KERNEL_1(device, opType, kernel, name, __COUNTER__)
|
||||
|
||||
#define _REGISTER_CONSTRUCTOR_1(type, constructor, cnt) \
|
||||
namespace infini { \
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
#include "core/tensor.h"
|
||||
|
||||
namespace infini {
|
||||
using KernelAttrs = std::tuple<Device, OpType::underlying_t, DataType>;
|
||||
using KernelAttrs = std::tuple<Device, OpType::underlying_t>;
|
||||
|
||||
struct OpPerfKey {
|
||||
HashType hash;
|
||||
|
@ -90,6 +90,7 @@ class OperatorObj : public Object {
|
|||
OpType getOpType() const { return type; }
|
||||
// HACK: set correct data type
|
||||
DataType getDType() const { return getInputs(0)->getDType(); }
|
||||
DataType getOutDType() const { return getOutput()->getDType(); }
|
||||
virtual int numInputs() const = 0;
|
||||
virtual int numOutputs() const = 0;
|
||||
|
||||
|
|
|
@ -44,8 +44,16 @@ class TensorObj : public TensorBaseObj {
|
|||
bool isOutput() const { return tensorType == TensorType::output; }
|
||||
bool isOthers() const { return tensorType == TensorType::others; }
|
||||
void setWeight() { tensorType = TensorType::weight; }
|
||||
void setInput() { tensorType = TensorType::input; }
|
||||
void setOutput() { tensorType = TensorType::output; }
|
||||
void setInput() {
|
||||
if (!this->isWeight()) {
|
||||
tensorType = TensorType::input;
|
||||
}
|
||||
}
|
||||
void setOutput() {
|
||||
if (!this->isWeight()) {
|
||||
tensorType = TensorType::output;
|
||||
}
|
||||
}
|
||||
string tensorTypeToString() const {
|
||||
switch (tensorType) {
|
||||
case TensorType::weight:
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
#pragma once
|
||||
|
||||
namespace infini {
|
||||
void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
|
||||
void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
|
||||
void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
|
||||
void less_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
int c3);
|
||||
void div_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
|
||||
int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
|
||||
int c2, int c3);
|
||||
void add_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
|
||||
int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
|
||||
int c2, int c3);
|
||||
void pow_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
|
||||
int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
|
||||
int c2, int c3);
|
||||
void less_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
|
||||
int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
|
||||
int c2, int c3);
|
||||
}; // namespace infini
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
#include "operators/unary.h"
|
||||
#include "utils/small_array.h"
|
||||
namespace infini {
|
||||
void expandKernel(float *input, float *output, int nDims, int outputsize,
|
||||
SmallArray inputShape, SmallArray outputShape);
|
||||
void expandKernel(int dType, void *input, void *output, int nDims,
|
||||
int outputsize, SmallArray inputShape,
|
||||
SmallArray outputShape);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -8,4 +8,10 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
void LaynormKernel(const float *input, const float *scale, const float eps,
|
||||
int size, int scaleSize, const int dimsize, const int stride,
|
||||
float *output);
|
||||
void LaynormKernel(const half *input, const half *scale, const half eps,
|
||||
int size, int scaleSize, const int dimsize, const int stride,
|
||||
half *output, const half *bias, int biasSize);
|
||||
void LaynormKernel(const half *input, const half *scale, const half eps,
|
||||
int size, int scaleSize, const int dimsize, const int stride,
|
||||
half *output);
|
||||
}; // namespace infini
|
||||
|
|
|
@ -3,4 +3,6 @@
|
|||
namespace infini {
|
||||
void softmax_kernel(int num_blocks, float *input, float *output, int size,
|
||||
int dimsize, int stride);
|
||||
}
|
||||
void softmax_kernel(int num_blocks, half *input, half *output, int size,
|
||||
int dimsize, int stride);
|
||||
} // namespace infini
|
||||
|
|
|
@ -8,8 +8,8 @@ const int DIM_MAX_SIZE = 8;
|
|||
// Concat operator acts like element tensors composing to one big tensor,and
|
||||
// split operator acts like one big tensor being composed by element
|
||||
// tensors.
|
||||
struct ElementTensorMetadata {
|
||||
float *data[BATCH_SIZE];
|
||||
template <typename T> struct ElementTensorMetadata {
|
||||
T *data[BATCH_SIZE];
|
||||
int dimBgNo[BATCH_SIZE]; // the dimention begin no of the element tensor in
|
||||
// the composed tensor.
|
||||
int dimSize[BATCH_SIZE]; // the dimention size of the element tensor.
|
||||
|
@ -20,16 +20,17 @@ struct ElementTensorMetadata {
|
|||
data[i], dimBgNo[i], dimSize[i], nElements[i]);
|
||||
}
|
||||
};
|
||||
|
||||
struct ComposedTensorMetadata {
|
||||
template <typename T> struct ComposedTensorMetadata {
|
||||
int dimSize[DIM_MAX_SIZE];
|
||||
int stride[DIM_MAX_SIZE];
|
||||
float *data;
|
||||
T *data;
|
||||
};
|
||||
|
||||
namespace infini {
|
||||
void split_concat_kernel(const ElementTensorMetadata &eleMeta,
|
||||
const ComposedTensorMetadata &compMeta, int dim,
|
||||
void split_concat_kernel(const ElementTensorMetadata<float> &eleMeta,
|
||||
const ComposedTensorMetadata<float> &compMeta, int dim,
|
||||
int batchSize, int nDims, bool isSplit);
|
||||
void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta,
|
||||
const ComposedTensorMetadata<half> &compMeta, int dim,
|
||||
int batchSize, int nDims, bool isSplit);
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
void transpose_kernel(float *input, float *output, int nDims, int size,
|
||||
void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
|
||||
SmallArray strides, SmallArray outputShape);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -3,48 +3,21 @@
|
|||
#include "operators/unary.h"
|
||||
|
||||
namespace infini {
|
||||
void softmax_kernel(float *input, float *output, size_t num);
|
||||
void relu_kernel(float *input, float *output, size_t num);
|
||||
void sigmoid_kernel(float *input, float *output, size_t num);
|
||||
void tanh_kernel(float *input, float *output, size_t num);
|
||||
void abs_kernel(float *input, float *output, size_t num);
|
||||
void sqrt_kernel(float *input, float *output, size_t num);
|
||||
void neg_kernel(float *input, float *output, size_t num);
|
||||
void gelu_kernel(float *input, float *output, size_t num);
|
||||
void erf_kernel(float *input, float *output, size_t num);
|
||||
void hard_sigmoid_kernel(float *input, float *output, size_t num);
|
||||
void hard_swish_kernel(float *input, float *output, size_t num);
|
||||
template <typename T> void softmax_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void relu_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void sigmoid_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void tanh_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void abs_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void sqrt_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void neg_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void gelu_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void erf_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void hard_sigmoid_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void hard_swish_kernel(T *input, T *output, size_t num);
|
||||
|
||||
void unary_kernel(const Operator &_op) {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
float *const inputData = (op->getInputs(0)->getRawDataPtr<float *>());
|
||||
float *const outputData = (op->getOutput()->getRawDataPtr<float *>());
|
||||
template <typename INPUT, typename OUTPUT>
|
||||
void cast_kernel(INPUT *input, OUTPUT *output, size_t num);
|
||||
|
||||
size_t num = op->getOutput()->size();
|
||||
if (op->getOpType() == OpType::Softmax)
|
||||
softmax_kernel(inputData, outputData, num);
|
||||
else if (op->getOpType() == OpType::Relu)
|
||||
relu_kernel(inputData, outputData, num);
|
||||
else if (op->getOpType() == OpType::Sigmoid)
|
||||
sigmoid_kernel(inputData, outputData, num);
|
||||
else if (op->getOpType() == OpType::HardSigmoid)
|
||||
hard_sigmoid_kernel(inputData, outputData, num);
|
||||
else if (op->getOpType() == OpType::HardSwish)
|
||||
hard_swish_kernel(inputData, outputData, num);
|
||||
else if (op->getOpType() == OpType::Tanh)
|
||||
tanh_kernel(inputData, outputData, num);
|
||||
else if (op->getOpType() == OpType::Abs)
|
||||
abs_kernel(inputData, outputData, num);
|
||||
else if (op->getOpType() == OpType::Sqrt)
|
||||
sqrt_kernel(inputData, outputData, num);
|
||||
else if (op->getOpType() == OpType::Gelu)
|
||||
gelu_kernel(inputData, outputData, num);
|
||||
else if (op->getOpType() == OpType::Neg)
|
||||
neg_kernel(inputData, outputData, num);
|
||||
else if (op->getOpType() == OpType::Erf)
|
||||
erf_kernel(inputData, outputData, num);
|
||||
else
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
void unary_kernel(const Operator &_op);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -1,11 +1,29 @@
|
|||
#pragma once
|
||||
#include "core/tensor.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void cudaPrintFloat(float *x, int len);
|
||||
|
||||
void cudaPrintTensor(const Tensor &tensor) {
|
||||
cudaPrintFloat(tensor->getRawDataPtr<float *>(), tensor->size());
|
||||
}
|
||||
void cudaPrintTensor(const Tensor &tensor);
|
||||
|
||||
cudnnDataType_t cudnnDataTypeConvert(DataType dataType);
|
||||
cudaDataType cublasDataTypeConvert(DataType);
|
||||
|
||||
template <int index> struct DT_CUDA {};
|
||||
template <> struct DT_CUDA<0> { using t = bool; };
|
||||
template <> struct DT_CUDA<1> { using t = float; };
|
||||
template <> struct DT_CUDA<2> { using t = unsigned char; };
|
||||
template <> struct DT_CUDA<3> { using t = char; };
|
||||
template <> struct DT_CUDA<4> { using t = unsigned short; };
|
||||
template <> struct DT_CUDA<5> { using t = short; };
|
||||
template <> struct DT_CUDA<6> { using t = int; };
|
||||
template <> struct DT_CUDA<7> { using t = long long; };
|
||||
template <> struct DT_CUDA<9> { using t = bool; };
|
||||
template <> struct DT_CUDA<10> { using t = half; };
|
||||
template <> struct DT_CUDA<11> { using t = double; };
|
||||
template <> struct DT_CUDA<12> { using t = unsigned int; };
|
||||
template <> struct DT_CUDA<13> { using t = unsigned long long; };
|
||||
template <> struct DT_CUDA<16> { using t = nv_bfloat16; };
|
||||
} // namespace infini
|
|
@ -3,10 +3,15 @@
|
|||
#include "utils/small_array.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void whereKernel(const float *inputX, const float *inputY,
|
||||
const uint8_t *condition, float *output, int nDims,
|
||||
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
|
||||
SmallArray conditionShape, SmallArray outputShape, int xSize,
|
||||
int ySize, int cSize);
|
||||
|
||||
void whereKernel(const half *inputX, const half *inputY,
|
||||
const uint8_t *condition, half *output, int nDims,
|
||||
int outputsize, SmallArray inputXShape, SmallArray inputYShape,
|
||||
SmallArray conditionShape, SmallArray outputShape, int xSize,
|
||||
int ySize, int cSize);
|
||||
}; // namespace infini
|
||||
|
|
|
@ -53,7 +53,8 @@ inline void initGatherMetaData(GatherMetaData &metaData,
|
|||
metaData.inStride[i] = in->getStride()[i];
|
||||
}
|
||||
}
|
||||
void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num);
|
||||
template <typename T>
|
||||
void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num);
|
||||
|
||||
void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,
|
||||
size_t num);
|
||||
|
|
|
@ -91,6 +91,12 @@ template <int val> class ValGenerator : public DataGenerator {
|
|||
fill<uint32_t>(data, size);
|
||||
}
|
||||
void fill(float *data, size_t size) override { fill<float>(data, size); }
|
||||
void fill_fp16(uint16_t *data, size_t size) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
float x = 1.0f * val;
|
||||
data[i] = float_to_fp16(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
typedef ValGenerator<1> OneGenerator;
|
||||
typedef ValGenerator<0> ZeroGenerator;
|
||||
|
|
|
@ -37,7 +37,7 @@ class OnnxStub:
|
|||
It can be generated from an Onnx model object.
|
||||
"""
|
||||
|
||||
def __init__(self, model: ModelProto, runtime):
|
||||
def __init__(self, model: ModelProto, runtime, use_naive_allocator: bool = False):
|
||||
# We use some user-defined operators for distributed inference
|
||||
try:
|
||||
# onnx simplifier performs inplace simplify
|
||||
|
@ -51,13 +51,43 @@ class OnnxStub:
|
|||
|
||||
self.inputs: Dict[str, backend.Tensor] = {}
|
||||
self.outputs: Dict[str, backend.Tensor] = {}
|
||||
self.tensors: Dict[str, backend.Tensor] = {}
|
||||
self.tensor_node_map: Dict[str, str] = {}
|
||||
self.initializer: Dict[int, TensorProto] = {}
|
||||
self.use_naive_allocator: bool = use_naive_allocator
|
||||
# try:
|
||||
# model = infer_shapes(model)
|
||||
# except:
|
||||
# warnings.warn("infer_shapes failed.")
|
||||
self.handler = backend.GraphHandler(runtime)
|
||||
|
||||
# 处理重名和匿名算子
|
||||
names = {}
|
||||
for node in model.graph.node:
|
||||
if node.name == "":
|
||||
node.name = "missing_name(" + node.op_type + ")"
|
||||
if node.name in names:
|
||||
names[node.name] += 1
|
||||
node.name += "_" + str(names[node.name])
|
||||
else:
|
||||
names[node.name] = 0
|
||||
# 拓扑排序
|
||||
sorted_nodes = []
|
||||
known_edge = set(t.name for t in model.graph.input)
|
||||
known_edge.update(t.name for t in model.graph.initializer)
|
||||
while len(sorted_nodes) < len(model.graph.node):
|
||||
updated = False
|
||||
for i, node in enumerate(model.graph.node):
|
||||
if all(t in known_edge for t in node.input):
|
||||
node.name = str(len(sorted_nodes)) + "_" + node.name
|
||||
sorted_nodes.append(i)
|
||||
known_edge.update(node.output)
|
||||
for t_ in node.output:
|
||||
self.tensor_node_map[t_] = node.name
|
||||
updated = True
|
||||
if not updated:
|
||||
raise Exception("Graph has cycle")
|
||||
|
||||
tensors: Dict[str, backend.Tensor] = dict()
|
||||
data: Dict[str, TensorProto] = dict()
|
||||
|
||||
|
@ -82,17 +112,8 @@ class OnnxStub:
|
|||
)
|
||||
tensors[output.name].set_output()
|
||||
|
||||
node_name = []
|
||||
new_node_name = []
|
||||
for node in model.graph.node:
|
||||
node_name.append(node.name)
|
||||
node_list = model.graph.node
|
||||
while len(node_list) != 0:
|
||||
for node in model.graph.node:
|
||||
if node.name not in node_list:
|
||||
continue
|
||||
if _analyse_node(node, tensors):
|
||||
continue
|
||||
for node_idx in sorted_nodes:
|
||||
node = model.graph.node[node_idx]
|
||||
if node.op_type == "Conv":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
|
@ -200,8 +221,7 @@ class OnnxStub:
|
|||
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
|
||||
)
|
||||
(alpha, beta, transA, transB) = (
|
||||
attributes[name]
|
||||
for name in ["alpha", "beta", "transA", "transB"]
|
||||
attributes[name] for name in ["alpha", "beta", "transA", "transB"]
|
||||
)
|
||||
# FIXME unsupport attributes: `alpha` `beta`
|
||||
assert alpha == 1.0
|
||||
|
@ -637,9 +657,7 @@ class OnnxStub:
|
|||
tensors[node.output[0]] = self.handler.concat(
|
||||
[tensors[name] for name in node.input],
|
||||
tensors.get(node.output[0]),
|
||||
next(
|
||||
(attr.i for attr in node.attribute if attr.name == "axis")
|
||||
),
|
||||
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
||||
)
|
||||
elif node.op_type == "AttentionKVCache":
|
||||
tensors[node.output[0]] = self.handler.attentionKVCache(
|
||||
|
@ -709,19 +727,11 @@ class OnnxStub:
|
|||
tensors.get(node.output[0]),
|
||||
# NOTE(constroy): `axes` is an attribute until opset version 13.
|
||||
next(
|
||||
(
|
||||
attr.ints
|
||||
for attr in node.attribute
|
||||
if attr.name == "axes"
|
||||
),
|
||||
(attr.ints for attr in node.attribute if attr.name == "axes"),
|
||||
None,
|
||||
),
|
||||
next(
|
||||
(
|
||||
attr.i
|
||||
for attr in node.attribute
|
||||
if attr.name == "keepdims"
|
||||
),
|
||||
(attr.i for attr in node.attribute if attr.name == "keepdims"),
|
||||
1,
|
||||
)
|
||||
!= 0,
|
||||
|
@ -749,9 +759,7 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
_parse_data(data[node.input[1]]),
|
||||
_parse_data(data[node.input[3]])
|
||||
if len(node.input) > 3
|
||||
else None,
|
||||
_parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
|
||||
)
|
||||
elif node.op_type == "Dropout":
|
||||
for name, tensor in zip(
|
||||
|
@ -759,9 +767,7 @@ class OnnxStub:
|
|||
self.handler.dropout(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
tensors.get(node.output[1])
|
||||
if len(node.output) > 1
|
||||
else None,
|
||||
tensors.get(node.output[1]) if len(node.output) > 1 else None,
|
||||
_parse_data(data[node.input[1]])[0]
|
||||
if len(node.input) > 1
|
||||
else 0.5,
|
||||
|
@ -865,11 +871,7 @@ class OnnxStub:
|
|||
0,
|
||||
)
|
||||
destination = next(
|
||||
(
|
||||
attr.i
|
||||
for attr in node.attribute
|
||||
if attr.name == "destination"
|
||||
),
|
||||
(attr.i for attr in node.attribute if attr.name == "destination"),
|
||||
0,
|
||||
)
|
||||
|
||||
|
@ -885,11 +887,7 @@ class OnnxStub:
|
|||
0,
|
||||
)
|
||||
destination = next(
|
||||
(
|
||||
attr.i
|
||||
for attr in node.attribute
|
||||
if attr.name == "destination"
|
||||
),
|
||||
(attr.i for attr in node.attribute if attr.name == "destination"),
|
||||
0,
|
||||
)
|
||||
|
||||
|
@ -943,7 +941,8 @@ class OnnxStub:
|
|||
node, {"alpha": 0.0001, "beta": 0.75, "bias": 1.0, "size": 1}
|
||||
)
|
||||
(alpha, beta, bias, size) = (
|
||||
attributes[name] for name in ["alpha", "beta", "bias", "size"]
|
||||
attributes[name]
|
||||
for name in ["alpha", "beta", "bias", "size"]
|
||||
)
|
||||
tensors[node.output[0]] = self.handler.lrn(
|
||||
tensors[node.input[0]],
|
||||
|
@ -955,14 +954,11 @@ class OnnxStub:
|
|||
)
|
||||
else:
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
new_node_name.append(node.name)
|
||||
# update the node_list
|
||||
node_list = list(set(node_name) - set(new_node_name))
|
||||
|
||||
################################
|
||||
# Allocate memory space for data
|
||||
################################
|
||||
self.handler.data_malloc()
|
||||
self.handler.data_malloc(self.use_naive_allocator)
|
||||
|
||||
#################################
|
||||
# Copy in data to tensor objects
|
||||
|
@ -993,6 +989,9 @@ class OnnxStub:
|
|||
# assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||
obj.copyin_numpy(to_array(tensor))
|
||||
|
||||
for name, obj in tensors.items():
|
||||
self.tensors[name] = obj
|
||||
|
||||
for output in model.graph.output:
|
||||
self.outputs[output.name] = tensors[output.name]
|
||||
|
||||
|
@ -1335,7 +1334,7 @@ class OnnxStub:
|
|||
return ctx.build(name)
|
||||
|
||||
def init(self) -> None:
|
||||
self.handler.data_malloc()
|
||||
self.handler.data_malloc(self.use_naive_allocator)
|
||||
|
||||
def optimize(self) -> None:
|
||||
self.handler.optimize()
|
||||
|
@ -1351,7 +1350,7 @@ class OnnxStub:
|
|||
oldTensor = self.inputs[oldInput]
|
||||
self.handler.change_shape(newInput, oldTensor.fuid())
|
||||
self.handler.shape_infer()
|
||||
self.handler.data_malloc()
|
||||
self.handler.data_malloc(self.use_naive_allocator)
|
||||
|
||||
def getShape(self, name: str) -> List[int]:
|
||||
if name in self.inputs:
|
||||
|
@ -1414,10 +1413,3 @@ def _parse_data_fp16(tensor: TensorProto):
|
|||
|
||||
def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
|
||||
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
|
||||
|
||||
|
||||
def _analyse_node(node: NodeProto, tensors) -> bool:
|
||||
for i in node.input:
|
||||
if i not in tensors:
|
||||
return True
|
||||
return False
|
||||
|
|
|
@ -16,8 +16,7 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
|||
std::map<OpType, int> opCnt;
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
|
|
@ -87,48 +87,33 @@ string GraphObj::toString() const {
|
|||
}
|
||||
|
||||
bool GraphObj::topo_sort() {
|
||||
if (this->sorted)
|
||||
if (this->sorted) {
|
||||
return true;
|
||||
|
||||
// std::unordered_set<Tensor> inputs;
|
||||
std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
|
||||
}
|
||||
std::vector<Operator> sorted;
|
||||
|
||||
while (!waiting.empty()) {
|
||||
std::unordered_set<OperatorObj *> flags;
|
||||
sorted.reserve(ops.size());
|
||||
flags.reserve(ops.size());
|
||||
while (sorted.size() < ops.size()) {
|
||||
// Any node is move to sorted in this loop.
|
||||
auto modified = false;
|
||||
// Find head nodes.
|
||||
for (auto it = waiting.begin(); it != waiting.end();) {
|
||||
const auto &this_inputs = (*it)->getInputs();
|
||||
// If none of the input tensors is in waiting list,
|
||||
// this node is a head node.
|
||||
const auto is_head = std::all_of(
|
||||
this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
|
||||
auto src = input->getSource();
|
||||
return src // If the source node is in the waiting
|
||||
// list, means that this node is not the
|
||||
// head node.
|
||||
? waiting.find(src) == waiting.end()
|
||||
// This tensor has no source node,
|
||||
// it must be a input tensor.
|
||||
: (/*inputs.insert(input),*/ true);
|
||||
});
|
||||
// Moves head node to sorted.
|
||||
if (is_head) {
|
||||
for (auto const &op : ops) {
|
||||
if (auto const &inputs = op->getInputs();
|
||||
flags.find(op.get()) == flags.end() &&
|
||||
std::all_of(inputs.begin(), inputs.end(),
|
||||
[&flags](auto const &input) {
|
||||
auto ptr = input->getSource().get();
|
||||
return !ptr || flags.find(ptr) != flags.end();
|
||||
})) {
|
||||
modified = true;
|
||||
sorted.emplace_back(std::move(*it));
|
||||
it = waiting.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
sorted.emplace_back(op);
|
||||
flags.insert(op.get());
|
||||
}
|
||||
}
|
||||
// Waiting list never modifies during a pass,
|
||||
// sorting fails.
|
||||
if (!modified) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Done.
|
||||
this->ops = std::move(sorted);
|
||||
return this->sorted = true;
|
||||
}
|
||||
|
@ -182,8 +167,11 @@ void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) {
|
|||
// note: behavior may not match running in non-naive mode, and it may
|
||||
// not reproduce the bug
|
||||
for (auto &tensor : tensors) {
|
||||
if (!tensor->isWeight() ||
|
||||
(tensor->isWeight() && !weightAllocated)) {
|
||||
tensor->dataMalloc();
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (memPoolSize > 0) {
|
||||
|
|
|
@ -17,8 +17,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
|||
std::map<OpType, int> opCnt;
|
||||
|
||||
for (auto &op : graph->getOperators()) {
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
@ -66,8 +65,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
|
|||
std::map<OpType, int> opCnt;
|
||||
|
||||
for (auto &op : graph->getOperators()) {
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
|
|
@ -25,8 +25,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
|
|||
auto &perfEngine = PerfEngine::getInstance();
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType().underlying(), op->getDType()};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
@ -48,8 +47,7 @@ void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
|
|||
std::map<OpType, int> opCnt;
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying(),
|
||||
DataType::Float32};
|
||||
auto kernelAttrs = KernelAttrs{device, op->getOpType().underlying()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
#include "core/data_type.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include <cstdio>
|
||||
|
||||
__global__ void cudaPrintFloatImpl(float *x, int len) {
|
||||
|
@ -18,4 +20,55 @@ void cudaPrintFloat(float *x, int len) {
|
|||
cudaDeviceSynchronize();
|
||||
}
|
||||
|
||||
void cudaPrintTensor(const Tensor &tensor) {
|
||||
cudaPrintFloat(tensor->getRawDataPtr<float *>(), tensor->size());
|
||||
}
|
||||
|
||||
cudnnDataType_t cudnnDataTypeConvert(DataType dataType) {
|
||||
if (dataType == DataType::Float32) {
|
||||
return CUDNN_DATA_FLOAT;
|
||||
}
|
||||
if (dataType == DataType::Double) {
|
||||
return CUDNN_DATA_DOUBLE;
|
||||
}
|
||||
if (dataType == DataType::Float16) {
|
||||
return CUDNN_DATA_HALF;
|
||||
}
|
||||
if (dataType == DataType::Int8) {
|
||||
return CUDNN_DATA_INT8;
|
||||
}
|
||||
if (dataType == DataType::Int32) {
|
||||
return CUDNN_DATA_INT32;
|
||||
}
|
||||
if (dataType == DataType::UInt8) {
|
||||
return CUDNN_DATA_UINT8;
|
||||
}
|
||||
if (dataType == DataType::BFloat16) {
|
||||
return CUDNN_DATA_BFLOAT16;
|
||||
}
|
||||
if (dataType == DataType::Int64) {
|
||||
return CUDNN_DATA_INT64;
|
||||
}
|
||||
if (dataType == DataType::Bool) {
|
||||
return CUDNN_DATA_BOOLEAN;
|
||||
}
|
||||
IT_ASSERT(false, "Unsupported data type");
|
||||
}
|
||||
|
||||
cudaDataType cublasDataTypeConvert(DataType dataType) {
|
||||
switch (dataType.getIndex()) {
|
||||
case 1:
|
||||
return CUDA_R_32F;
|
||||
// case 3:
|
||||
// return CUDA_R_8I;
|
||||
case 10:
|
||||
return CUDA_R_16F;
|
||||
case 11:
|
||||
return CUDA_R_64F;
|
||||
// case 16:
|
||||
// return CUDA_R_16BF;
|
||||
default:
|
||||
IT_ASSERT(false, "MatMul Unsupported data type");
|
||||
}
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -11,6 +11,7 @@ class UnaryCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -50,6 +51,7 @@ class RoundCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -80,6 +82,7 @@ class PReluCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PReluObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -119,6 +122,7 @@ class SoftmaxCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SoftmaxObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -215,15 +219,12 @@ class SigmoidCnnl : public UnaryCnnl {
|
|||
float getCoef() const override { return 0.0; }
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Relu, DataType::Float32, ReluCnnl,
|
||||
"Relu_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::PRelu, DataType::Float32, PReluCnnl,
|
||||
"PRelu_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, DataType::Float32, SigmoidCnnl,
|
||||
"Sigmoid_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Round, DataType::Float32, RoundCnnl,
|
||||
"Round_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Softmax, DataType::Float32, SoftmaxCnnl,
|
||||
"Softmax_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Relu, ReluCnnl, "Relu_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::PRelu, PReluCnnl, "PRelu_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Sigmoid, SigmoidCnnl,
|
||||
"Sigmoid_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Round, RoundCnnl, "Round_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Softmax, SoftmaxCnnl,
|
||||
"Softmax_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -10,6 +10,7 @@ class ActivationBackwardCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ActivationBackwardObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const yData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -81,11 +82,11 @@ class TanhBackwardCnnl : public ActivationBackwardCnnl {
|
|||
float getCoef() const override { return 0.0; }
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::ReluBackward, DataType::Float32,
|
||||
ReluBackwardCnnl, "ReluBackward_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::SigmoidBackward, DataType::Float32,
|
||||
SigmoidBackwardCnnl, "SigmoidBackward_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::TanhBackward, DataType::Float32,
|
||||
TanhBackwardCnnl, "TanhBackward_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::ReluBackward, ReluBackwardCnnl,
|
||||
"ReluBackward_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::SigmoidBackward, SigmoidBackwardCnnl,
|
||||
"SigmoidBackward_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::TanhBackward, TanhBackwardCnnl,
|
||||
"TanhBackward_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<BatchNormObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const input = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -101,7 +102,7 @@ class BatchNormCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::BatchNormalization, DataType::Float32,
|
||||
BatchNormCnnl, "BatchNorm_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::BatchNormalization, BatchNormCnnl,
|
||||
"BatchNorm_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -212,7 +212,6 @@ class CastCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Cast, DataType::Float32, CastCnnl,
|
||||
"Cast_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Cast, CastCnnl, "Cast_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class CeilCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -35,7 +36,6 @@ class CeilCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Ceil, DataType::Float32, CeilCnnl,
|
||||
"Ceil_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Ceil, CeilCnnl, "Ceil_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class ClipCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ClipObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -30,7 +31,6 @@ class ClipCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Clip, DataType::Float32, ClipCnnl,
|
||||
"Clip_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Clip, ClipCnnl, "Clip_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class ConcatCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConcatObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
int num = op->numInputs();
|
||||
int axis = op->getDim();
|
||||
|
@ -50,6 +51,5 @@ class ConcatCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Concat, DataType::Float32, ConcatCnnl,
|
||||
"Concat_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Concat, ConcatCnnl, "Concat_cnnl_BANG");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class ConvCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConvObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
|
@ -151,6 +152,5 @@ class ConvCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Conv, DataType::Float32, ConvCnnl,
|
||||
"Conv_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Conv, ConvCnnl, "Conv_cnnl_BANG");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConvBaseObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
|
@ -76,6 +77,6 @@ class ConvTransCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::ConvTranspose, DataType::Float32,
|
||||
ConvTransCnnl, "ConvTrans_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::ConvTranspose, ConvTransCnnl,
|
||||
"ConvTrans_cnnl_BANG");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConvBackwardFilterObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
|
@ -154,6 +155,6 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::ConvBackwardFilter, DataType::Float32,
|
||||
ConvBackwardFilterCnnl, "ConvBackwardFilter_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::ConvBackwardFilter,
|
||||
ConvBackwardFilterCnnl, "ConvBackwardFilter_cnnl_BANG");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class DetCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<DetObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -42,6 +43,5 @@ class DetCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Det, DataType::Float32, DetCnnl,
|
||||
"Det_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Det, DetCnnl, "Det_cnnl_BANG");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -11,6 +11,7 @@ class ElementWiseCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -74,6 +75,7 @@ class LogicOpCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -127,6 +129,7 @@ class BitComputeCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -179,6 +182,7 @@ class DivCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -231,6 +235,7 @@ class MaximumCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -282,6 +287,7 @@ class MinimumCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -333,6 +339,7 @@ class MSELossCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<MSELossObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -389,6 +396,7 @@ class PowerCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -442,6 +450,7 @@ class FloorDivCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -494,6 +503,7 @@ class FloorModCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -546,6 +556,7 @@ class SquaredDifferenceCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -658,62 +669,48 @@ class BitNotCnnl : public BitComputeCnnl {
|
|||
// CNNL_BLEFT_SHIFT_OP_V2; }
|
||||
// };
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Add, DataType::Float32, AddCnnl,
|
||||
"Add_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Sub, DataType::Float32, SubCnnl,
|
||||
"Sub_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Mul, DataType::Float32, MulCnnl,
|
||||
"Mul_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Add, AddCnnl, "Add_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Sub, SubCnnl, "Sub_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Mul, MulCnnl, "Mul_cnnl_BANG");
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Div, DataType::Float32, DivCnnl,
|
||||
"Div_cnnl_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Max, DataType::Float32, MaximumCnnl,
|
||||
"Maximum_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Min, DataType::Float32, MinimumCnnl,
|
||||
"Minimum_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::MSELoss, DataType::Float32, MSELossCnnl,
|
||||
"MSELoss_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Pow, DataType::Float32, PowerCnnl,
|
||||
"Power_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::FloorDiv, DataType::Float32, FloorDivCnnl,
|
||||
"FloorDiv_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::FloorMod, DataType::Float32, FloorModCnnl,
|
||||
"FloorMod_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::SquaredDifference, DataType::Float32,
|
||||
SquaredDifferenceCnnl, "SquaredDifference_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Equal, DataType::Float32, EqualCnnl,
|
||||
"Equal_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Greater, DataType::Float32,
|
||||
GreaterThanCnnl, "GreaterThan_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::GreaterOrEqual, DataType::Float32,
|
||||
GreaterEqualCnnl, "GreaterEqual_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Less, DataType::Float32, LessThanCnnl,
|
||||
"LessThan_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::LessOrEqual, DataType::Float32,
|
||||
LessEqualCnnl, "LessEqual_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::And, DataType::Float32, AndCnnl,
|
||||
"And_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Or, DataType::Float32, OrCnnl,
|
||||
"Or_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Xor, DataType::Float32, XorCnnl,
|
||||
"Xor_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Not, DataType::Float32, NotCnnl,
|
||||
"Not_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::BitwiseAnd, DataType::Float32, BitAndCnnl,
|
||||
"BitAnd_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::BitwiseOr, DataType::Float32, BitOrCnnl,
|
||||
"BitOr_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::BitwiseXor, DataType::Float32, BitXorCnnl,
|
||||
"BitXor_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::BitwiseNot, DataType::Float32, BitNotCnnl,
|
||||
"BitNot_cnnl_BANG_Float32");
|
||||
// REGISTER_KERNEL(Device::BANG, OpType::BitLeftShift, DataType::Float32,
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Div, DivCnnl, "Div_cnnl");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Max, MaximumCnnl, "Maximum_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Min, MinimumCnnl, "Minimum_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::MSELoss, MSELossCnnl,
|
||||
"MSELoss_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Pow, PowerCnnl, "Power_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::FloorDiv, FloorDivCnnl,
|
||||
"FloorDiv_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::FloorMod, FloorModCnnl,
|
||||
"FloorMod_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::SquaredDifference, SquaredDifferenceCnnl,
|
||||
"SquaredDifference_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Equal, EqualCnnl, "Equal_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Greater, GreaterThanCnnl,
|
||||
"GreaterThan_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::GreaterOrEqual, GreaterEqualCnnl,
|
||||
"GreaterEqual_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Less, LessThanCnnl, "LessThan_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::LessOrEqual, LessEqualCnnl,
|
||||
"LessEqual_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::And, AndCnnl, "And_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Or, OrCnnl, "Or_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Xor, XorCnnl, "Xor_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Not, NotCnnl, "Not_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::BitwiseAnd, BitAndCnnl,
|
||||
"BitAnd_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::BitwiseOr, BitOrCnnl, "BitOr_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::BitwiseXor, BitXorCnnl,
|
||||
"BitXor_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::BitwiseNot, BitNotCnnl,
|
||||
"BitNot_cnnl_BANG");
|
||||
// REGISTER_KERNEL(Device::BANG, OpType::BitLeftShift,
|
||||
// BitLeftShiftCnnl,
|
||||
// "BitLeftShift_cnnl_BANG_Float32");
|
||||
// REGISTER_KERNEL(Device::BANG, OpType::BitRightShift, DataType::Float32,
|
||||
// "BitLeftShift_cnnl_BANG");
|
||||
// REGISTER_KERNEL(Device::BANG, OpType::BitRightShift,
|
||||
// BitRightShiftCnnl,
|
||||
// "BitRightShift_cnnl_BANG_Float32");
|
||||
// REGISTER_KERNEL(Device::BANG, OpType::Pow, DataType::Float32,
|
||||
// "BitRightShift_cnnl_BANG");
|
||||
// REGISTER_KERNEL(Device::BANG, OpType::Pow,
|
||||
// ElementWiseBang,
|
||||
// "Pow_Bang_Float32");
|
||||
// "Pow_Bang");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class ErfCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -36,7 +37,6 @@ class ErfCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Erf, DataType::Float32, ErfCnnl,
|
||||
"Erf_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Erf, ErfCnnl, "Erf_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class ExpCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -36,7 +37,6 @@ class ExpCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Exp, DataType::Float32, ExpCnnl,
|
||||
"Exp_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Exp, ExpCnnl, "Exp_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class FillCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<FillObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
@ -29,7 +30,6 @@ class FillCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Fill, DataType::Float32, FillCnnl,
|
||||
"Fill_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Fill, FillCnnl, "Fill_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class FloorCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -35,7 +36,7 @@ class FloorCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Floor, DataType::Float32, FloorCnnl,
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Floor, FloorCnnl,
|
||||
"Floor_cnnl_BANG_Float32");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class GatherCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<GatherObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -49,7 +50,6 @@ class GatherCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Gather, DataType::Float32, GatherCnnl,
|
||||
"Gather_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Gather, GatherCnnl, "Gather_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class HardtanhCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<HardtanhObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -30,7 +31,7 @@ class HardtanhCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Hardtanh, DataType::Float32, HardtanhCnnl,
|
||||
"Hardtanh_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Hardtanh, HardtanhCnnl,
|
||||
"Hardtanh_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class L2LossCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<L2LossObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -28,7 +29,6 @@ class L2LossCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::L2Loss, DataType::Float32, L2LossCnnl,
|
||||
"L2Loss_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::L2Loss, L2LossCnnl, "L2Loss_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -8,6 +8,7 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<LayerNormObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -58,7 +59,7 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::LayerNormalization, DataType::Float32,
|
||||
LayerNormCnnl, "LayerNorm_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::LayerNormalization, LayerNormCnnl,
|
||||
"LayerNorm_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class LogCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<LogObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -51,7 +52,6 @@ class LogCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Log, DataType::Float32, LogCnnl,
|
||||
"Log_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Log, LogCnnl, "Log_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class LRNCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<LRNObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -56,7 +57,6 @@ class LRNCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::LRN, DataType::Float32, LRNCnnl,
|
||||
"LRN_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::LRN, LRNCnnl, "LRN_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -8,6 +8,7 @@ class MatmulCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<MatmulObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
auto input_num = op->numInputs();
|
||||
|
@ -107,6 +108,5 @@ class MatmulCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::MatMul, DataType::Float32, MatmulCnnl,
|
||||
"Matmul_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::MatMul, MatmulCnnl, "Matmul_cnnl_BANG");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class NegTensorCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -35,7 +36,6 @@ class NegTensorCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Neg, DataType::Float32, NegTensorCnnl,
|
||||
"Neg_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Neg, NegTensorCnnl, "Neg_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class PadCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PadObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -57,7 +58,6 @@ class PadCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Pad, DataType::Float32, PadCnnl,
|
||||
"Pad_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Pad, PadCnnl, "Pad_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -8,6 +8,7 @@ class PoolingCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PoolingObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
@ -68,8 +69,8 @@ class avgPoolCnnl : public PoolingCnnl {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::MaxPool, DataType::Float32, maxPoolCnnl,
|
||||
"MaxPool_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::AveragePool, DataType::Float32,
|
||||
avgPoolCnnl, "AvgPool_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::MaxPool, maxPoolCnnl,
|
||||
"MaxPool_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::AveragePool, avgPoolCnnl,
|
||||
"AvgPool_cnnl_BANG");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class ReciprocalCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -35,7 +36,7 @@ class ReciprocalCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Reciprocal, DataType::Float32,
|
||||
ReciprocalCnnl, "Reciprocal_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Reciprocal, ReciprocalCnnl,
|
||||
"Reciprocal_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -9,6 +9,7 @@ class ReduceCnnlBase : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ReduceBaseObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
@ -73,9 +74,9 @@ class ReduceSumCnnl : public ReduceCnnlBase {
|
|||
cnnlReduceOp_t getReduceOp() const override { return CNNL_REDUCE_ADD; }
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, DataType::Float32,
|
||||
ReduceMeanCnnl, "ReduceMean_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::ReduceSum, DataType::Float32,
|
||||
ReduceSumCnnl, "ReduceSum_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::ReduceMean, ReduceMeanCnnl,
|
||||
"ReduceMean_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::ReduceSum, ReduceSumCnnl,
|
||||
"ReduceSum_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -13,9 +13,9 @@ class CopyBang : public BangKernelWithoutConfig {
|
|||
auto dim = op->getInputs(0)->getDims();
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, dim.size(),
|
||||
dim.data()));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT8,
|
||||
dim.size() * op->getDType().getSize(), dim.data()));
|
||||
cnnlStatus_t stat =
|
||||
cnnlCopy(context->cnnlHandle(), aDesc, inData, aDesc, outData);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
|
@ -25,13 +25,8 @@ class CopyBang : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
// reshape/flatten/identity all act as copying from input to output.
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Float32, CopyBang,
|
||||
"Reshape_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Reshape, DataType::Int64, CopyBang,
|
||||
"Reshape_BANG_Int64");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Flatten, DataType::Float32, CopyBang,
|
||||
"Flatten_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Identity, DataType::Float32, CopyBang,
|
||||
"Identity_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Reshape, CopyBang, "Reshape_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Flatten, CopyBang, "Flatten_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Identity, CopyBang, "Identity_BANG");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class RsqrtCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -36,7 +37,6 @@ class RsqrtCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Rsqrt, DataType::Float32, RsqrtCnnl,
|
||||
"Rsqrt_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Rsqrt, RsqrtCnnl, "Rsqrt_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class SplitCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SplitObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
int num = op->numOutputs();
|
||||
int axis = op->getDim();
|
||||
|
@ -49,6 +50,5 @@ class SplitCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Split, DataType::Float32, SplitCnnl,
|
||||
"Split_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Split, SplitCnnl, "Split_cnnl_BANG");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class SqrtCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -36,7 +37,6 @@ class SqrtCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Sqrt, DataType::Float32, SqrtCnnl,
|
||||
"Sqrt_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Sqrt, SqrtCnnl, "Sqrt_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class TransposeCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<TransposeObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -52,6 +53,7 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<DepthToSpaceObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -101,9 +103,9 @@ class DepthToSpaceCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Transpose, DataType::Float32,
|
||||
TransposeCnnl, "Transpose_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Transpose, TransposeCnnl,
|
||||
"Transpose_cnnl_BANG");
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::DepthToSpace, DataType::Float32,
|
||||
DepthToSpaceCnnl, "DepthToSpace_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::DepthToSpace, DepthToSpaceCnnl,
|
||||
"DepthToSpace_cnnl_BANG");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -9,6 +9,7 @@ class TrigonCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -150,29 +151,17 @@ class ATanHCnnl : public TrigonCnnl {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Sin, DataType::Float32, SinCnnl,
|
||||
"Sin_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Cos, DataType::Float32, CosCnnl,
|
||||
"Cos_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Tan, DataType::Float32, TanCnnl,
|
||||
"Tan_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Asin, DataType::Float32, ASinCnnl,
|
||||
"ASin_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Acos, DataType::Float32, ACosCnnl,
|
||||
"ACos_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Atan, DataType::Float32, ATanCnnl,
|
||||
"ATan_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Sinh, DataType::Float32, SinHCnnl,
|
||||
"SinH_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Cosh, DataType::Float32, CosHCnnl,
|
||||
"CosH_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Tanh, DataType::Float32, TanHCnnl,
|
||||
"TanH_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Asinh, DataType::Float32, ASinHCnnl,
|
||||
"ASinH_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Acosh, DataType::Float32, ACosHCnnl,
|
||||
"ACosH_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Atanh, DataType::Float32, ATanHCnnl,
|
||||
"ATanH_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Sin, SinCnnl, "Sin_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Cos, CosCnnl, "Cos_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Tan, TanCnnl, "Tan_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Asin, ASinCnnl, "ASin_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Acos, ACosCnnl, "ACos_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Atan, ATanCnnl, "ATan_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Sinh, SinHCnnl, "SinH_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Cosh, CosHCnnl, "CosH_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Tanh, TanHCnnl, "TanH_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Asinh, ASinHCnnl, "ASinH_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Acosh, ACosHCnnl, "ACosH_cnnl_BANG");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Atanh, ATanHCnnl, "ATanH_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -7,6 +7,7 @@ class WhereCnnl : public BangKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<WhereObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
|
@ -67,7 +68,6 @@ class WhereCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Where, DataType::Float32, WhereCnnl,
|
||||
"Where_cnnl_BANG_Float32");
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Where, WhereCnnl, "Where_cnnl_BANG");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -3,9 +3,9 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
template <typename T> class NaiveConcat : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
class NaiveConcat : public CpuKernelWithoutConfig {
|
||||
template <typename T>
|
||||
void doCompute(const Operator &_op, const RuntimeObj *context) const {
|
||||
auto op = as<ConcatObj>(_op);
|
||||
auto inputs = op->getInputs(), outputs = op->getOutputs();
|
||||
auto dim = op->getDim();
|
||||
|
@ -41,11 +41,25 @@ template <typename T> class NaiveConcat : public CpuKernelWithoutConfig {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
#define CASE(N) \
|
||||
case N: \
|
||||
doCompute<DT<N>::t>(_op, context)
|
||||
|
||||
int dataTypeIdx = _op->getDType().getIndex();
|
||||
switch (dataTypeIdx) {
|
||||
CASE(1); // DataType::Float32
|
||||
break;
|
||||
CASE(12); // DataType::UInt32
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Concat, DataType::UInt32,
|
||||
NaiveConcat<uint32_t>, "ConcatNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Concat, DataType::Float32,
|
||||
NaiveConcat<float>, "ConcatNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Concat, NaiveConcat, "ConcatNaive_CPU");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -3,9 +3,9 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
template <typename T> class NaiveConv : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
class NaiveConv : public CpuKernelWithoutConfig {
|
||||
template <typename T>
|
||||
void doCompute(const Operator &_op, const RuntimeObj *context) const {
|
||||
auto op = as<ConvObj>(_op);
|
||||
T *iptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
T *wptr = op->getInputs(1)->getRawDataPtr<T *>();
|
||||
|
@ -50,11 +50,25 @@ template <typename T> class NaiveConv : public CpuKernelWithoutConfig {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
#define CASE(N) \
|
||||
case N: \
|
||||
doCompute<DT<N>::t>(_op, context)
|
||||
|
||||
int dataTypeIdx = _op->getDType().getIndex();
|
||||
switch (dataTypeIdx) {
|
||||
CASE(1); // DataType::Float32
|
||||
break;
|
||||
CASE(12); // DataType::UInt32
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::UInt32,
|
||||
NaiveConv<uint32_t>, "ConvNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Conv, DataType::Float32, NaiveConv<float>,
|
||||
"ConvNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Conv, NaiveConv, "ConvNaive_CPU");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -3,10 +3,45 @@
|
|||
#include "utils/operator_utils.h"
|
||||
|
||||
namespace infini {
|
||||
template <typename T> class NativeElementWise : public CpuKernelWithoutConfig {
|
||||
virtual T doCompute(T val0, T val1) const = 0;
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
class NativeElementWise : public CpuKernelWithoutConfig {
|
||||
template <typename T> static T addCompute(T val0, T val1) {
|
||||
return val0 + val1;
|
||||
}
|
||||
|
||||
template <typename T> static T subCompute(T val0, T val1) {
|
||||
return val0 - val1;
|
||||
}
|
||||
|
||||
template <typename T> static T mulCompute(T val0, T val1) {
|
||||
return val0 * val1;
|
||||
}
|
||||
|
||||
template <typename T> static T divCompute(T val0, T val1) {
|
||||
return (T)(val0 / val1);
|
||||
}
|
||||
|
||||
template <typename T> static T equalCompute(T val0, T val1) {
|
||||
return (T)(val0 == val1);
|
||||
}
|
||||
|
||||
template <typename T> static T greaterOrEqualCompute(T val0, T val1) {
|
||||
return (T)(val0 >= val1);
|
||||
}
|
||||
|
||||
template <typename T> static T greaterCompute(T val0, T val1) {
|
||||
return (T)(val0 > val1);
|
||||
}
|
||||
|
||||
template <typename T> static T lessOrEqualCompute(T val0, T val1) {
|
||||
return (T)(val0 <= val1);
|
||||
}
|
||||
|
||||
template <typename T> static T lessCompute(T val0, T val1) {
|
||||
return (T)(val0 < val1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void doCompute(const Operator &_op, const RuntimeObj *context) const {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
T *inptr0 = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
T *inptr1 = op->getInputs(1)->getRawDataPtr<T *>();
|
||||
|
@ -35,77 +70,77 @@ template <typename T> class NativeElementWise : public CpuKernelWithoutConfig {
|
|||
Shape strideB = getStride(b);
|
||||
|
||||
auto n = op->getOutput()->size();
|
||||
T (*_doCompute)(T val0, T val1);
|
||||
switch (op->getOpType().underlying()) {
|
||||
case OpType::Add:
|
||||
_doCompute = addCompute<T>;
|
||||
break;
|
||||
case OpType::Sub:
|
||||
_doCompute = subCompute<T>;
|
||||
break;
|
||||
case OpType::Mul:
|
||||
_doCompute = mulCompute<T>;
|
||||
break;
|
||||
case OpType::Div:
|
||||
_doCompute = divCompute<T>;
|
||||
break;
|
||||
case OpType::Equal:
|
||||
_doCompute = equalCompute<T>;
|
||||
break;
|
||||
case OpType::GreaterOrEqual:
|
||||
_doCompute = greaterOrEqualCompute<T>;
|
||||
break;
|
||||
case OpType::Greater:
|
||||
_doCompute = greaterCompute<T>;
|
||||
break;
|
||||
case OpType::LessOrEqual:
|
||||
_doCompute = lessOrEqualCompute<T>;
|
||||
break;
|
||||
case OpType::Less:
|
||||
_doCompute = lessCompute<T>;
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
auto shapeIndexC = locate_index(i, shapeC);
|
||||
auto indexA = delocate_index(shapeIndexC, a, strideA);
|
||||
auto indexB = delocate_index(shapeIndexC, b, strideB);
|
||||
outptr[i] = doCompute(inptr0[indexA], inptr1[indexB]);
|
||||
outptr[i] = _doCompute(inptr0[indexA], inptr1[indexB]);
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
#define CASE(N) \
|
||||
case N: \
|
||||
doCompute<DT<N>::t>(_op, context)
|
||||
|
||||
int dataTypeIdx = _op->getDType().getIndex();
|
||||
switch (dataTypeIdx) {
|
||||
CASE(1); // DataType::Float32
|
||||
break;
|
||||
CASE(12); // DataType::UInt32
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class NaiveAdd : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return val0 + val1; }
|
||||
};
|
||||
template <typename T> class NaiveSub : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return val0 - val1; }
|
||||
};
|
||||
template <typename T> class NaiveMul : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return val0 * val1; }
|
||||
};
|
||||
template <typename T> class NaiveDiv : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return (T)(val0 / val1); }
|
||||
};
|
||||
template <typename T> class NaiveEqual : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return (T)(val0 == val1); }
|
||||
};
|
||||
template <typename T> class NaiveGreaterEqual : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return (T)(val0 >= val1); }
|
||||
};
|
||||
template <typename T> class NaiveGreaterThan : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return (T)(val0 > val1); }
|
||||
};
|
||||
template <typename T> class NaiveLessEqual : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return (T)(val0 <= val1); }
|
||||
};
|
||||
template <typename T> class NaiveLessThan : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return (T)(val0 < val1); }
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::UInt32, NaiveAdd<uint32_t>,
|
||||
"addNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::Float32, NaiveAdd<float>,
|
||||
"addNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::UInt32, NaiveSub<uint32_t>,
|
||||
"subNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::Float32, NaiveSub<float>,
|
||||
"subNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::UInt32, NaiveMul<uint32_t>,
|
||||
"mulNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::Float32, NaiveMul<float>,
|
||||
"mulNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::UInt32, NaiveDiv<uint32_t>,
|
||||
"divNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::Float32, NaiveDiv<float>,
|
||||
"divNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Equal, DataType::UInt32,
|
||||
NaiveEqual<uint32_t>, "equalNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Equal, DataType::Float32,
|
||||
NaiveEqual<float>, "equalNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, DataType::UInt32,
|
||||
NaiveGreaterEqual<uint32_t>, "greaterEqualNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, DataType::Float32,
|
||||
NaiveGreaterEqual<float>, "greaterEqualNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Greater, DataType::UInt32,
|
||||
NaiveGreaterThan<uint32_t>, "greaterThanNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Greater, DataType::Float32,
|
||||
NaiveGreaterThan<float>, "greaterThanNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, DataType::UInt32,
|
||||
NaiveLessEqual<uint32_t>, "lessEqualNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, DataType::Float32,
|
||||
NaiveLessEqual<float>, "lessEqualNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Less, DataType::UInt32,
|
||||
NaiveLessThan<uint32_t>, "lessEqualNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Less, DataType::Float32,
|
||||
NaiveLessThan<float>, "lessEqualNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Add, NativeElementWise, "addNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sub, NativeElementWise, "subNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Mul, NativeElementWise, "mulNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Div, NativeElementWise, "divNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Equal, NativeElementWise,
|
||||
"equalNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::GreaterOrEqual, NativeElementWise,
|
||||
"greaterEqualNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Greater, NativeElementWise,
|
||||
"greaterThanNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::LessOrEqual, NativeElementWise,
|
||||
"lessEqualNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Less, NativeElementWise,
|
||||
"lessEqualNaive_CPU");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -3,9 +3,9 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
template <typename T> class NaiveMatmul : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
class NaiveMatmul : public CpuKernelWithoutConfig {
|
||||
template <typename T>
|
||||
void doCompute(const Operator &_op, const RuntimeObj *context) const {
|
||||
auto op = as<MatmulObj>(_op);
|
||||
IT_ASSERT(op->getInputs().size() == 2, "Bias is not supported yet.");
|
||||
T *A = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
|
@ -23,11 +23,25 @@ template <typename T> class NaiveMatmul : public CpuKernelWithoutConfig {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
#define CASE(N) \
|
||||
case N: \
|
||||
doCompute<DT<N>::t>(_op, context)
|
||||
|
||||
int dataTypeIdx = _op->getDType().getIndex();
|
||||
switch (dataTypeIdx) {
|
||||
CASE(1); // DataType::Float32
|
||||
break;
|
||||
CASE(12); // DataType::UInt32
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::MatMul, DataType::UInt32,
|
||||
NaiveMatmul<uint32_t>, "MatmulNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::MatMul, DataType::Float32,
|
||||
NaiveMatmul<float>, "MatmulNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::MatMul, NaiveMatmul, "MatmulNaive_CPU");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -80,8 +80,8 @@ class MemboundInterpreter : public Kernel {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::MemBound, DataType::UInt32,
|
||||
MemboundInterpreter, "MemboundInterpreter_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::MemBound, MemboundInterpreter,
|
||||
"MemboundInterpreter_CPU");
|
||||
|
||||
} // namespace infini
|
||||
|
||||
|
|
|
@ -2,42 +2,10 @@
|
|||
#include "core/kernel.h"
|
||||
|
||||
namespace infini {
|
||||
template <typename T> class NativePooling : public CpuKernelWithoutConfig {
|
||||
virtual T getPoolingValue(int kh, int kw, int posh, int posw, int ih,
|
||||
int iw, T *inptr) const = 0;
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
auto op = as<PoolingObj>(_op);
|
||||
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
T *outptr = op->getOutput()->getRawDataPtr<T *>();
|
||||
const auto [n, c, ih, iw, kh, kw] = op->getNCHWRS();
|
||||
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
if (dh != 1 || dw != 1)
|
||||
IT_TODO_HALT(); // To support dailated pooling
|
||||
auto outDim = op->getOutput()->getDims();
|
||||
int oh = outDim[2], ow = outDim[3];
|
||||
for (auto i = 0; i < n; i++) {
|
||||
for (auto j = 0; j < c; j++) {
|
||||
auto inoffset = i * (c * ih * iw) + j * ih * iw;
|
||||
for (auto h = 0; h < oh; h++) {
|
||||
for (auto w = 0; w < ow; w++) {
|
||||
// TODO: verify ceil mode
|
||||
T val =
|
||||
getPoolingValue(kh, kw, h * sh - ph, w * sw - pw,
|
||||
ih, iw, inptr + inoffset);
|
||||
auto outoffset =
|
||||
w + h * ow + j * (oh * ow) + i * (c * oh * ow);
|
||||
outptr[outoffset] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class NaiveMaxPool : public NativePooling<T> {
|
||||
T getPoolingValue(int kh, int kw, int posh, int posw, int ih, int iw,
|
||||
T *inptr) const override {
|
||||
class NativePooling : public CpuKernelWithoutConfig {
|
||||
template <typename T>
|
||||
static T getMaxPoolingValue(int kh, int kw, int posh, int posw, int ih,
|
||||
int iw, T *inptr) {
|
||||
T maxval = 0;
|
||||
for (auto k = 0; k < kh; k++) {
|
||||
for (auto l = 0; l < kw; l++) {
|
||||
|
@ -53,11 +21,10 @@ template <typename T> class NaiveMaxPool : public NativePooling<T> {
|
|||
}
|
||||
return maxval;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class NaiveAvgPool : public NativePooling<T> {
|
||||
T getPoolingValue(int kh, int kw, int posh, int posw, int ih, int iw,
|
||||
T *inptr) const override {
|
||||
template <typename T>
|
||||
static T getAvgPoolingValue(int kh, int kw, int posh, int posw, int ih,
|
||||
int iw, T *inptr) {
|
||||
T sum = 0;
|
||||
for (auto k = 0; k < kh; k++) {
|
||||
for (auto l = 0; l < kw; l++) {
|
||||
|
@ -71,12 +38,70 @@ template <typename T> class NaiveAvgPool : public NativePooling<T> {
|
|||
}
|
||||
return T(sum / (kh * kw));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void doCompute(const Operator &_op, const RuntimeObj *context) const {
|
||||
auto op = as<PoolingObj>(_op);
|
||||
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
T *outptr = op->getOutput()->getRawDataPtr<T *>();
|
||||
|
||||
const auto [n, c, ih, iw, kh, kw] = op->getNCHWRS();
|
||||
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
if (dh != 1 || dw != 1)
|
||||
IT_TODO_HALT(); // To support dailated pooling
|
||||
auto outDim = op->getOutput()->getDims();
|
||||
int oh = outDim[2], ow = outDim[3];
|
||||
|
||||
T(*_doCompute)
|
||||
(int kh, int kw, int posh, int posw, int ih, int iw, T *inptr);
|
||||
switch (op->getOpType().underlying()) {
|
||||
case OpType::MaxPool:
|
||||
_doCompute = getMaxPoolingValue<T>;
|
||||
break;
|
||||
case OpType::AveragePool:
|
||||
_doCompute = getAvgPoolingValue<T>;
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
||||
for (auto i = 0; i < n; i++) {
|
||||
for (auto j = 0; j < c; j++) {
|
||||
auto inoffset = i * (c * ih * iw) + j * ih * iw;
|
||||
for (auto h = 0; h < oh; h++) {
|
||||
for (auto w = 0; w < ow; w++) {
|
||||
// TODO: verify ceil mode
|
||||
T val = _doCompute(kh, kw, h * sh - ph, w * sw - pw, ih,
|
||||
iw, inptr + inoffset);
|
||||
auto outoffset =
|
||||
w + h * ow + j * (oh * ow) + i * (c * oh * ow);
|
||||
outptr[outoffset] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
#define CASE(N) \
|
||||
case N: \
|
||||
doCompute<DT<N>::t>(_op, context)
|
||||
|
||||
int dataTypeIdx = _op->getDType().getIndex();
|
||||
switch (dataTypeIdx) {
|
||||
CASE(1); // DataType::Float32
|
||||
break;
|
||||
CASE(12); // DataType::UInt32
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::UInt32,
|
||||
NaiveMaxPool<uint32_t>, "maxPoolNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::Float32,
|
||||
NaiveMaxPool<float>, "maxPoolNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::AveragePool, DataType::Float32,
|
||||
NaiveAvgPool<float>, "AvgPoolNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, NativePooling,
|
||||
"maxPoolNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::AveragePool, NativePooling,
|
||||
"avgPoolNaive_CPU");
|
||||
} // namespace infini
|
||||
|
|
|
@ -3,9 +3,9 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
template <typename T> class NaiveSplit : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
class NaiveSplit : public CpuKernelWithoutConfig {
|
||||
template <typename T>
|
||||
void doCompute(const Operator &_op, const RuntimeObj *context) const {
|
||||
auto op = as<SplitObj>(_op);
|
||||
auto inputs = op->getInputs(), outputs = op->getOutputs();
|
||||
auto dim = op->getDim();
|
||||
|
@ -40,11 +40,24 @@ template <typename T> class NaiveSplit : public CpuKernelWithoutConfig {
|
|||
}
|
||||
}
|
||||
}
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
#define CASE(N) \
|
||||
case N: \
|
||||
doCompute<DT<N>::t>(_op, context)
|
||||
|
||||
int dataTypeIdx = _op->getDType().getIndex();
|
||||
switch (dataTypeIdx) {
|
||||
CASE(1); // DataType::Float32
|
||||
break;
|
||||
CASE(12); // DataType::UInt32
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Split, DataType::UInt32,
|
||||
NaiveSplit<uint32_t>, "SplitNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Split, DataType::Float32,
|
||||
NaiveSplit<float>, "SplitNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Split, NaiveSplit, "SplitNaive_CPU");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -14,9 +14,9 @@ inline Shape idx2Pos(const Shape &shape, size_t idx) {
|
|||
return pos;
|
||||
}
|
||||
|
||||
template <typename T> class NaiveTranspose : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
class NaiveTranspose : public CpuKernelWithoutConfig {
|
||||
template <typename T>
|
||||
void doCompute(const Operator &_op, const RuntimeObj *context) const {
|
||||
auto op = as<TransposeObj>(_op);
|
||||
auto inputs = op->getInputs(), outputs = op->getOutputs();
|
||||
const auto &inDim = inputs[0]->getDims();
|
||||
|
@ -35,11 +35,26 @@ template <typename T> class NaiveTranspose : public CpuKernelWithoutConfig {
|
|||
outPtr[outIdx] = inPtr[inIdx];
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
#define CASE(N) \
|
||||
case N: \
|
||||
doCompute<DT<N>::t>(_op, context)
|
||||
|
||||
int dataTypeIdx = _op->getDType().getIndex();
|
||||
switch (dataTypeIdx) {
|
||||
CASE(1); // DataType::Float32
|
||||
break;
|
||||
CASE(12); // DataType::UInt32
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Transpose, DataType::UInt32,
|
||||
NaiveTranspose<uint32_t>, "TransposeNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Transpose, DataType::Float32,
|
||||
NaiveTranspose<float>, "TransposeNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Transpose, NaiveTranspose,
|
||||
"TransposeNaive_CPU");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -4,25 +4,170 @@
|
|||
#include "operators/softmax.h"
|
||||
|
||||
namespace infini {
|
||||
template <typename T> class NativeUnary : public CpuKernelWithoutConfig {
|
||||
virtual T doCompute(T val) const = 0;
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
class NativeUnary : public CpuKernelWithoutConfig {
|
||||
template <typename T> static T reluCompute(T val) {
|
||||
return std::max(T(0), val);
|
||||
}
|
||||
|
||||
template <typename T> static T sigmoidCompute(T val) {
|
||||
return 1 / (1 + pow(E_CONSTANT, -val));
|
||||
}
|
||||
|
||||
template <typename T> static T hardSigmoidCompute(T val) {
|
||||
return std::max(T(0), std::min(T(1), T(0.2) * val + T(0.5)));
|
||||
}
|
||||
|
||||
template <typename T> static T hardSwishCompute(T val) {
|
||||
return val *
|
||||
std::max(T(0), std::min(T(1), val * T(1.0 / 6.0) + T(0.5)));
|
||||
}
|
||||
|
||||
template <typename T> static T tanhCompute(T val) {
|
||||
return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) /
|
||||
(pow(E_CONSTANT, val) + pow(E_CONSTANT, -val));
|
||||
}
|
||||
|
||||
template <typename T> static T absCompute(T val) {
|
||||
return val < 0 ? -val : val;
|
||||
}
|
||||
|
||||
template <typename T> static T sqrtCompute(T val) { return std::sqrt(val); }
|
||||
|
||||
template <typename T> static T cosCompute(T val) { return std::cos(val); }
|
||||
|
||||
template <typename T> static T sinCompute(T val) { return std::sin(val); }
|
||||
|
||||
template <typename T> static T tanCompute(T val) { return std::tan(val); }
|
||||
|
||||
template <typename T> static T sinhCompute(T val) { return std::sinh(val); }
|
||||
|
||||
template <typename T> static T coshCompute(T val) { return std::cosh(val); }
|
||||
|
||||
template <typename T> static T geluCompute(T val) {
|
||||
return 0.5 * val * (1 + std::erf(val / std::sqrt(2)));
|
||||
}
|
||||
|
||||
template <typename T> static T erfCompute(T val) { return std::erf(val); }
|
||||
|
||||
template <typename T> static T aCosCompute(T val) { return std::acos(val); }
|
||||
|
||||
template <typename T> static T aCoshCompute(T val) {
|
||||
return std::acosh(val);
|
||||
}
|
||||
|
||||
template <typename T> static T aSinCompute(T val) { return std::asin(val); }
|
||||
|
||||
template <typename T> static T aSinhCompute(T val) {
|
||||
return std::asinh(val);
|
||||
}
|
||||
template <typename T> static T aTanCompute(T val) { return std::atan(val); }
|
||||
|
||||
template <typename T> static T aTanhCompute(T val) {
|
||||
return std::atanh(val);
|
||||
}
|
||||
template <typename T> static T negCompute(T val) { return -val; }
|
||||
|
||||
template <typename T>
|
||||
void doCompute(const Operator &_op, const RuntimeObj *context) const {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
T *outptr = op->getOutput()->getRawDataPtr<T *>();
|
||||
|
||||
auto outDim = op->getOutput()->getDims();
|
||||
auto n = op->getOutput()->size();
|
||||
|
||||
T (*_doCompute)(T val);
|
||||
switch (op->getOpType().underlying()) {
|
||||
case OpType::Relu:
|
||||
_doCompute = reluCompute<T>;
|
||||
break;
|
||||
case OpType::Gelu:
|
||||
_doCompute = geluCompute<T>;
|
||||
break;
|
||||
case OpType::Sigmoid:
|
||||
_doCompute = sigmoidCompute<T>;
|
||||
break;
|
||||
case OpType::HardSigmoid:
|
||||
_doCompute = hardSigmoidCompute<T>;
|
||||
break;
|
||||
case OpType::HardSwish:
|
||||
_doCompute = hardSwishCompute<T>;
|
||||
break;
|
||||
case OpType::Tanh:
|
||||
_doCompute = tanhCompute<T>;
|
||||
break;
|
||||
case OpType::Abs:
|
||||
_doCompute = absCompute<T>;
|
||||
break;
|
||||
case OpType::Sqrt:
|
||||
_doCompute = sqrtCompute<T>;
|
||||
break;
|
||||
case OpType::Erf:
|
||||
_doCompute = erfCompute<T>;
|
||||
break;
|
||||
case OpType::Neg:
|
||||
_doCompute = negCompute<T>;
|
||||
break;
|
||||
case OpType::Cos:
|
||||
_doCompute = cosCompute<T>;
|
||||
break;
|
||||
case OpType::Sin:
|
||||
_doCompute = sinCompute<T>;
|
||||
break;
|
||||
case OpType::Tan:
|
||||
_doCompute = tanCompute<T>;
|
||||
break;
|
||||
case OpType::Sinh:
|
||||
_doCompute = sinhCompute<T>;
|
||||
break;
|
||||
case OpType::Cosh:
|
||||
_doCompute = coshCompute<T>;
|
||||
break;
|
||||
case OpType::Acos:
|
||||
_doCompute = aCosCompute<T>;
|
||||
break;
|
||||
case OpType::Asin:
|
||||
_doCompute = aSinCompute<T>;
|
||||
break;
|
||||
case OpType::Asinh:
|
||||
_doCompute = aSinhCompute<T>;
|
||||
break;
|
||||
case OpType::Atan:
|
||||
_doCompute = aTanCompute<T>;
|
||||
break;
|
||||
case OpType::Atanh:
|
||||
_doCompute = aTanhCompute<T>;
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
||||
for (size_t offset = 0; offset < n; offset++) {
|
||||
outptr[offset] = doCompute(inptr[offset]);
|
||||
outptr[offset] = _doCompute(inptr[offset]);
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
#define CASE(N) \
|
||||
case N: \
|
||||
doCompute<DT<N>::t>(_op, context)
|
||||
|
||||
int dataTypeIdx = _op->getDType().getIndex();
|
||||
switch (dataTypeIdx) {
|
||||
CASE(1); // DataType::Float32
|
||||
break;
|
||||
CASE(12); // DataType::UInt32
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
class NaiveSoftmax : public CpuKernelWithoutConfig {
|
||||
template <typename T>
|
||||
void doCompute(const Operator &_op, const RuntimeObj *context) const {
|
||||
auto op = as<SoftmaxObj>(_op);
|
||||
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
T *outptr = op->getOutput()->getRawDataPtr<T *>();
|
||||
|
@ -37,98 +182,28 @@ template <typename T> class NaiveSoftmax : public CpuKernelWithoutConfig {
|
|||
outptr[offset] = pow(E_CONSTANT, inptr[offset]) / sum;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class NaiveRelu : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::max(T(0), val); }
|
||||
};
|
||||
template <typename T> class NaiveSigmoid : public NativeUnary<T> {
|
||||
T doCompute(T val) const override {
|
||||
return 1 / (1 + pow(E_CONSTANT, -val));
|
||||
}
|
||||
};
|
||||
template <typename T> class NaiveHardSigmoid : public NativeUnary<T> {
|
||||
T doCompute(T val) const override {
|
||||
return std::max(T(0), std::min(T(1), T(0.2) * val + T(0.5)));
|
||||
}
|
||||
};
|
||||
template <typename T> class NaiveHardSwish : public NativeUnary<T> {
|
||||
T doCompute(T val) const override {
|
||||
return val *
|
||||
std::max(T(0), std::min(T(1), val * T(1.0 / 6.0) + T(0.5)));
|
||||
}
|
||||
};
|
||||
template <typename T> class NaiveTanh : public NativeUnary<T> {
|
||||
T doCompute(T val) const override {
|
||||
return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) /
|
||||
(pow(E_CONSTANT, val) + pow(E_CONSTANT, -val));
|
||||
}
|
||||
};
|
||||
template <typename T> class NaiveAbs : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return val < 0 ? -val : val; }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveSqrt : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::sqrt(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveCos : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::cos(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveSin : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::sin(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveTan : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::tan(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveSinh : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::sinh(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveCosh : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::cosh(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveGelu : public NativeUnary<T> {
|
||||
T doCompute(T val) const override {
|
||||
return 0.5 * val * (1 + std::erf(val / std::sqrt(2)));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class NaiveErf : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::erf(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveACos : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::acos(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveACosh : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::acosh(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveASin : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::asin(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveASinh : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::asinh(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveATanh : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::atanh(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveNeg : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return -val; }
|
||||
};
|
||||
|
||||
template <typename T> class Clip : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
#define CASE(N) \
|
||||
case N: \
|
||||
doCompute<DT<N>::t>(_op, context)
|
||||
|
||||
int dataTypeIdx = _op->getDType().getIndex();
|
||||
switch (dataTypeIdx) {
|
||||
CASE(1); // DataType::Float32
|
||||
break;
|
||||
CASE(12); // DataType::UInt32
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class Clip : public CpuKernelWithoutConfig {
|
||||
template <typename T>
|
||||
void doCompute(const Operator &_op, const RuntimeObj *context) const {
|
||||
auto op = as<ClipObj>(_op);
|
||||
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
T *outptr = op->getOutput()->getRawDataPtr<T *>();
|
||||
|
@ -143,11 +218,28 @@ template <typename T> class Clip : public CpuKernelWithoutConfig {
|
|||
: val;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class Log : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
#define CASE(N) \
|
||||
case N: \
|
||||
doCompute<DT<N>::t>(_op, context)
|
||||
|
||||
int dataTypeIdx = _op->getDType().getIndex();
|
||||
switch (dataTypeIdx) {
|
||||
CASE(1); // DataType::Float32
|
||||
break;
|
||||
CASE(12); // DataType::UInt32
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class Log : public CpuKernelWithoutConfig {
|
||||
template <typename T>
|
||||
void doCompute(const Operator &_op, const RuntimeObj *context) const {
|
||||
auto op = as<LogObj>(_op);
|
||||
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
T *outptr = op->getOutput()->getRawDataPtr<T *>();
|
||||
|
@ -176,70 +268,50 @@ template <typename T> class Log : public CpuKernelWithoutConfig {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
#define CASE(N) \
|
||||
case N: \
|
||||
doCompute<DT<N>::t>(_op, context)
|
||||
|
||||
int dataTypeIdx = _op->getDType().getIndex();
|
||||
switch (dataTypeIdx) {
|
||||
CASE(1); // DataType::Float32
|
||||
break;
|
||||
CASE(12); // DataType::UInt32
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class NaiveATan : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::atan(val); }
|
||||
};
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Relu, NativeUnary, "reluNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Gelu, NativeUnary, "geluNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, NativeUnary, "sigmoidNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, NativeUnary,
|
||||
"hardSigmoidNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::HardSwish, NativeUnary,
|
||||
"hardSwishNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Tanh, NativeUnary, "tanhNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Abs, NativeUnary, "absNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sqrt, NativeUnary, "sqrtNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Erf, NativeUnary, "erfNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Neg, NativeUnary, "negNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Cos, NativeUnary, "Cos_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sin, NativeUnary, "Sin_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Tan, NativeUnary, "Tan_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sinh, NativeUnary, "Sinh_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Cosh, NativeUnary, "Cosh_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Acos, NativeUnary, "ACos_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Acosh, NativeUnary, "ACosh_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Asin, NativeUnary, "ASin_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Asinh, NativeUnary, "ASinh_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Atan, NativeUnary, "Atan_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Atanh, NativeUnary, "ATanh_CPU");
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::UInt32,
|
||||
NaiveRelu<uint32_t>, "reluNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::Float32, NaiveRelu<float>,
|
||||
"reluNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Gelu, DataType::UInt32, NaiveGelu<float>,
|
||||
"geluNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Gelu, DataType::Float32, NaiveGelu<float>,
|
||||
"geluNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::UInt32,
|
||||
NaiveSigmoid<uint32_t>, "sigmoidNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::Float32,
|
||||
NaiveSigmoid<float>, "sigmoidNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, DataType::Float32,
|
||||
NaiveHardSigmoid<float>, "hardSigmoidNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::HardSwish, DataType::Float32,
|
||||
NaiveHardSwish<float>, "hardSwishNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::UInt32,
|
||||
NaiveTanh<uint32_t>, "tanhNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::Float32, NaiveTanh<float>,
|
||||
"tanhNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::UInt32, NaiveAbs<uint32_t>,
|
||||
"absNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs<float>,
|
||||
"absNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt<float>,
|
||||
"sqrtNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Erf, DataType::Float32, NaiveErf<float>,
|
||||
"erfNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Neg, DataType::Float32, NaiveNeg<float>,
|
||||
"negNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
|
||||
NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,
|
||||
NaiveSoftmax<float>, "softmaxNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Clip, DataType::Float32, Clip<float>,
|
||||
"Clip_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Atan, DataType::Float32, NaiveATan<float>,
|
||||
"Atan_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Log, DataType::Float32, Log<float>,
|
||||
"Log_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Cos, DataType::Float32, NaiveCos<float>,
|
||||
"Cos_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sin, DataType::Float32, NaiveSin<float>,
|
||||
"Sin_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Tan, DataType::Float32, NaiveTan<float>,
|
||||
"Tan_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sinh, DataType::Float32, NaiveSinh<float>,
|
||||
"Sinh_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Cosh, DataType::Float32, NaiveCosh<float>,
|
||||
"Cosh_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Acos, DataType::Float32, NaiveACos<float>,
|
||||
"ACos_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Acosh, DataType::Float32,
|
||||
NaiveACosh<float>, "ACosh_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Asin, DataType::Float32, NaiveASin<float>,
|
||||
"ASin_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Asinh, DataType::Float32,
|
||||
NaiveASinh<float>, "ASinh_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Atanh, DataType::Float32,
|
||||
NaiveATanh<float>, "ATanh_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Softmax, NaiveSoftmax, "softmaxNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Clip, Clip, "Clip_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Log, Log, "Log_CPU");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -48,13 +48,13 @@ class G2BMMCudnn : public CudaKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<G2BMMObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
bool success = g2bmmKernel(op, context);
|
||||
IT_ASSERT(success);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::G2BMM, DataType::Float32, G2BMMCudnn,
|
||||
"G2BMM_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::G2BMM, G2BMMCudnn, "G2BMM_cuDNN_CUDA");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -49,13 +49,13 @@ class GBMMCudnn : public CudaKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<GBMMObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
bool success = gbmmKernel(op, context);
|
||||
IT_ASSERT(success);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::GBMM, DataType::Float32, GBMMCudnn,
|
||||
"GBMM_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::GBMM, GBMMCudnn, "GBMM_cuDNN_CUDA");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -39,8 +39,8 @@ class AllGatherNCCL : public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AllGather, DataType::Float32,
|
||||
AllGatherNCCL, "AllGather_NCCL_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AllGather, AllGatherNCCL,
|
||||
"AllGather_NCCL_CUDA");
|
||||
} // namespace infini
|
||||
|
||||
#endif
|
||||
|
|
|
@ -13,15 +13,24 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
|
|||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
void *input = op->getInputs(0)->getRawDataPtr<void *>();
|
||||
void *output = op->getOutput()->getRawDataPtr<void *>();
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
ncclDataType_t ncclType = ncclFloat;
|
||||
if (op->getDType() == DataType::Float16) {
|
||||
ncclType = ncclFloat16;
|
||||
} else if (op->getDType() == DataType::Int8) {
|
||||
ncclType = ncclInt8;
|
||||
} else if (op->getDType() == DataType::Float32) {
|
||||
ncclType = ncclFloat;
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
size_t count = op->getInputs(0)->size();
|
||||
|
||||
ncclComm_t comm =
|
||||
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
|
||||
.getNcclComm();
|
||||
// TODO: Using default stream 0 for now.
|
||||
checkNcclError(ncclAllReduce(input, output, count, ncclFloat,
|
||||
getRedOp(), comm, 0));
|
||||
checkNcclError(
|
||||
ncclAllReduce(input, output, count, ncclType, getRedOp(), comm, 0));
|
||||
}
|
||||
|
||||
virtual ncclRedOp_t getRedOp() const = 0;
|
||||
|
@ -43,16 +52,16 @@ class AllReduceAvgNCCL : public AllReduceNCCL {
|
|||
ncclRedOp_t getRedOp() const override { return ncclAvg; }
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceSum, DataType::Float32,
|
||||
AllReduceSumNCCL, "AllReduce_Sum_NCCL_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceProd, DataType::Float32,
|
||||
AllReduceProdNCCL, "AllReduce_Prod_NCCL_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMin, DataType::Float32,
|
||||
AllReduceMinNCCL, "AllReduce_Min_NCCL_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMax, DataType::Float32,
|
||||
AllReduceMaxNCCL, "AllReduce_Max_NCCL_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceAvg, DataType::Float32,
|
||||
AllReduceAvgNCCL, "AllReduce_Avg_NCCL_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceSum, AllReduceSumNCCL,
|
||||
"AllReduce_Sum_NCCL_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceProd, AllReduceProdNCCL,
|
||||
"AllReduce_Prod_NCCL_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMin, AllReduceMinNCCL,
|
||||
"AllReduce_Min_NCCL_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceMax, AllReduceMaxNCCL,
|
||||
"AllReduce_Max_NCCL_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AllReduceAvg, AllReduceAvgNCCL,
|
||||
"AllReduce_Avg_NCCL_CUDA");
|
||||
|
||||
} // namespace infini
|
||||
#endif
|
||||
|
|
|
@ -40,6 +40,7 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
|
|||
public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
IT_ASSERT(_op->getDType() == DataType::Float32);
|
||||
do_compute(_op->getInputs()[0], _op->getInputs()[1],
|
||||
_op->getInputs()[2], _op->getInputs()[3],
|
||||
_op->getInputs()[4], _op->getInputs()[5],
|
||||
|
@ -47,6 +48,6 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, DataType::Float32,
|
||||
AttentionKVCacheCuda, "AttentionKVCache_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, AttentionKVCacheCuda,
|
||||
"AttentionKVCache_CUDA");
|
||||
} // namespace infini
|
||||
|
|
|
@ -10,6 +10,7 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
|
|||
auto op = as<BatchNormObj>(_op);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
cudnnStatus_t stat;
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
void *const meanData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
|
@ -59,6 +60,6 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::BatchNormalization, DataType::Float32,
|
||||
BatchNormCudnn, "BatchNorm_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::BatchNormalization, BatchNormCudnn,
|
||||
"BatchNorm_cuDNN_CUDA");
|
||||
} // namespace infini
|
||||
|
|
|
@ -25,8 +25,8 @@ class BroadcastNCCL : public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Broadcast, DataType::Float32,
|
||||
BroadcastNCCL, "Broadcast_NCCL_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Broadcast, BroadcastNCCL,
|
||||
"Broadcast_NCCL_CUDA");
|
||||
} // namespace infini
|
||||
|
||||
#endif
|
||||
|
|
|
@ -9,7 +9,7 @@ class ClipCuda : public CudaKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ClipObj>(_op);
|
||||
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto min = op->getMin();
|
||||
|
@ -21,7 +21,6 @@ class ClipCuda : public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Clip, DataType::Float32, ClipCuda,
|
||||
"Clip_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Clip, ClipCuda, "Clip_CUDA");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
#include "operators/conv.h"
|
||||
#include "core/kernel.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <tuple>
|
||||
|
||||
namespace infini {
|
||||
|
||||
struct ConvCuDnnPerfRecordObj : public PerfRecordObj {
|
||||
|
@ -56,8 +58,11 @@ class convCudnn : public Kernel {
|
|||
const ConvCuDnnPerfRecord &record) const {
|
||||
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
if (op->getInputs().size() > 2) // Bias is not supported yet
|
||||
// Bias is not supported yet
|
||||
if (op->getInputs().size() > 2) {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
auto cudnnDataType = cudnnDataTypeConvert(op->getDType());
|
||||
// void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
|
@ -72,27 +77,26 @@ class convCudnn : public Kernel {
|
|||
cudnnTensorDescriptor_t inDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, channels, h, w));
|
||||
inDesc, CUDNN_TENSOR_NCHW, cudnnDataType, n, channels, h, w));
|
||||
|
||||
// get kernels
|
||||
cudnnFilterDescriptor_t knDesc;
|
||||
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
|
||||
checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
|
||||
CUDNN_TENSOR_NCHW, f,
|
||||
channelsPerGrp, r, s));
|
||||
checkCudnnError(cudnnSetFilter4dDescriptor(
|
||||
knDesc, cudnnDataType, CUDNN_TENSOR_NCHW, f, channelsPerGrp, r, s));
|
||||
// get bias
|
||||
cudnnTensorDescriptor_t biasDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
biasDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, f, 1, 1));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(biasDesc, CUDNN_TENSOR_NCHW,
|
||||
cudnnDataType, 1, f, 1, 1));
|
||||
|
||||
// get convlution descriptor
|
||||
// get convolution descriptor
|
||||
cudnnConvolutionDescriptor_t convDesc;
|
||||
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
|
||||
// TODO: CUDNN_CONVOLUTION is a tunable argument
|
||||
checkCudnnError(cudnnSetConvolution2dDescriptor(
|
||||
convDesc, ph, pw, sh, sw, dh, dw, MODES[record->mode],
|
||||
CUDNN_DATA_FLOAT));
|
||||
cudnnDataType));
|
||||
if (g > 1) {
|
||||
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
|
||||
}
|
||||
|
@ -120,14 +124,14 @@ class convCudnn : public Kernel {
|
|||
assert(false);
|
||||
}
|
||||
|
||||
// get output descriptor
|
||||
int outn, outc, outh, outw;
|
||||
checkCudnnError(cudnnGetConvolution2dForwardOutputDim(
|
||||
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
|
||||
cudnnTensorDescriptor_t outDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
|
||||
CUDNN_DATA_FLOAT, outn, outc,
|
||||
outh, outw));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
outDesc, CUDNN_TENSOR_NCHW, cudnnDataType, outn, outc, outh, outw));
|
||||
IT_ASSERT((vector{outn, outc, outh, outw}) ==
|
||||
op->getOutput()->getDims(),
|
||||
"cuDNN output shape mismatches with OP output shape");
|
||||
|
@ -151,55 +155,9 @@ class convCudnn : public Kernel {
|
|||
inData, knDesc, knData, convDesc,
|
||||
ALGOS[record->algo], wsData, wsSize,
|
||||
&beta, outDesc, outData);
|
||||
if (stat != CUDNN_STATUS_SUCCESS)
|
||||
if (stat != CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
// TODO:
|
||||
// // bias
|
||||
// if (bias != nullptr) {
|
||||
// auto sz = op.getOutputs()[0]->size();
|
||||
// // TODO: element wise
|
||||
// t += sz * 2 / 400;
|
||||
// }
|
||||
// // act
|
||||
// if (act != None) {
|
||||
// stat = cudnnActivationForward(cudnnHandle(), actDesc,
|
||||
// &alpha, inDesc, inData,
|
||||
// &beta, outDesc, outData);
|
||||
// checkCudaError(cudaDeviceSynchronize());
|
||||
// end = ch::high_resolution_clock::now();
|
||||
// if (stat != CUDNN_STATUS_SUCCESS) {
|
||||
// durtime = INFINITY;
|
||||
// break;
|
||||
// }
|
||||
// t +=
|
||||
// ch::duration_cast<ch::duration<double>>(end -
|
||||
// beg).count() * 1000; // ms
|
||||
// }
|
||||
|
||||
// best = ConvResult{durtime, ALGOS[i], wsSize, false};
|
||||
|
||||
// // w/ bias & act
|
||||
// for (int j = 0; j < rounds + warmupRounds; ++j) {
|
||||
// cudnnStatus_t stat;
|
||||
// if (j == warmupRounds) {
|
||||
// checkCudaError(cudaDeviceSynchronize());
|
||||
// beg = ch::high_resolution_clock::now();
|
||||
// }
|
||||
// stat = cudnnConvolutionBiasActivationForward(
|
||||
// cudnnHandle(), &alpha, inDesc, inData, knDesc, knData,
|
||||
// convDesc, ALGOS[i], wsData, wsSize, &beta, outDesc,
|
||||
// outData, biasDesc, biasData, actDesc, outDesc, outData);
|
||||
// if (stat != CUDNN_STATUS_SUCCESS) {
|
||||
// // checkCudnnError(stat);
|
||||
// // Do not checkCudnnError since not all algorithms are
|
||||
// // supported
|
||||
// durtime_fuse = INFINITY;
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
|
||||
// Destories in CUDA does not require sync. But cuDNN does not state
|
||||
// whether sync is required before destories.
|
||||
}
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
||||
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
|
||||
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
|
||||
|
@ -238,10 +196,12 @@ class convCudnn : public Kernel {
|
|||
stat = cudnnGetConvolutionForwardWorkspaceSize(
|
||||
context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
|
||||
ALGOS[record.algo], &record.workspaceSize);
|
||||
if (stat != CUDNN_STATUS_SUCCESS)
|
||||
if (stat != CUDNN_STATUS_SUCCESS) {
|
||||
continue;
|
||||
if (record.workspaceSize > context->getWorkspaceSize())
|
||||
}
|
||||
if (record.workspaceSize > context->getWorkspaceSize()) {
|
||||
continue;
|
||||
}
|
||||
CudaPtr wsData = context->getWorkspace(record.workspaceSize);
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
|
||||
|
@ -249,8 +209,9 @@ class convCudnn : public Kernel {
|
|||
context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
|
||||
knData, convDesc, ALGOS[record.algo], wsData,
|
||||
record.workspaceSize, &beta, outDesc, outData);
|
||||
if (stat != CUDNN_STATUS_SUCCESS)
|
||||
if (stat != CUDNN_STATUS_SUCCESS) {
|
||||
continue;
|
||||
}
|
||||
record.time = timeit(
|
||||
[&]() {
|
||||
cudnnConvolutionForward(context->cudnnHandle(), &alpha,
|
||||
|
@ -263,8 +224,9 @@ class convCudnn : public Kernel {
|
|||
// printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time);
|
||||
|
||||
// Update the tune result
|
||||
if (ret.time > record.time)
|
||||
if (ret.time > record.time) {
|
||||
ret = record;
|
||||
}
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
||||
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
|
||||
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
|
||||
|
@ -291,8 +253,7 @@ class convCudnn : public Kernel {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float32, convCudnn,
|
||||
"Conv_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Conv, convCudnn, "Conv_cuDNN_CUDA");
|
||||
|
||||
REGISTER_CONSTRUCTOR(1, ConvCuDnnPerfRecordObj::from_json);
|
||||
} // namespace infini
|
||||
|
|
|
@ -1,261 +0,0 @@
|
|||
#include "core/kernel.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "operators/conv.h"
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <tuple>
|
||||
|
||||
namespace infini {
|
||||
|
||||
struct ConvCuDnnPerfRecordObj : public PerfRecordObj {
|
||||
int algo = 0; // cudnnConvolutionFwdAlgo_t
|
||||
int mode = 1;
|
||||
size_t workspaceSize = 100000;
|
||||
bool fuseAct = false;
|
||||
void to_json(json &j) override {
|
||||
j["type"] = 1;
|
||||
j["data"] = std::make_tuple(algo, mode, fuseAct, time, workspaceSize);
|
||||
}
|
||||
static PerfRecord from_json(const json &j) {
|
||||
ConvCuDnnPerfRecordObj tmp;
|
||||
auto [Algo, Mode, FuseAct, Time, WorkspaceSize] =
|
||||
j["data"].get<tuple<int, int, bool, double, size_t>>();
|
||||
tmp.algo = Algo;
|
||||
tmp.mode = Mode;
|
||||
tmp.fuseAct = FuseAct;
|
||||
tmp.time = Time;
|
||||
tmp.workspaceSize = WorkspaceSize;
|
||||
return make_ref<ConvCuDnnPerfRecordObj>(tmp);
|
||||
}
|
||||
};
|
||||
|
||||
using ConvCuDnnPerfRecord = Ref<ConvCuDnnPerfRecordObj>;
|
||||
|
||||
class convCudnnFP16 : public Kernel {
|
||||
|
||||
static constexpr int N_ALGO = 8;
|
||||
static constexpr int N_MODE = 2;
|
||||
static constexpr cudnnConvolutionFwdAlgo_t ALGOS[8] = {
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED};
|
||||
|
||||
static constexpr cudnnConvolutionMode_t MODES[2] = {
|
||||
CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION};
|
||||
|
||||
std::tuple<void *, void *, void *, cudnnTensorDescriptor_t,
|
||||
cudnnFilterDescriptor_t, cudnnTensorDescriptor_t,
|
||||
cudnnConvolutionDescriptor_t, cudnnActivationDescriptor_t,
|
||||
cudnnTensorDescriptor_t>
|
||||
createCuDNNDescriptor(const Ref<ConvObj> &op,
|
||||
const ConvCuDnnPerfRecord &record) const {
|
||||
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
// Bias is not supported yet
|
||||
if (op->getInputs().size() > 2) {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
// void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
const auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
|
||||
const int cpg = op->getChannelPerGroup();
|
||||
const int g = c / cpg;
|
||||
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
|
||||
int channelsPerGrp = cpg, channels = c;
|
||||
|
||||
// get inputs
|
||||
cudnnTensorDescriptor_t inDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(inDesc, CUDNN_TENSOR_NCHW,
|
||||
CUDNN_DATA_HALF, n, channels,
|
||||
h, w)); /*fp16 type*/
|
||||
|
||||
// get kernels
|
||||
cudnnFilterDescriptor_t knDesc;
|
||||
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
|
||||
checkCudnnError(cudnnSetFilter4dDescriptor(
|
||||
knDesc, CUDNN_DATA_HALF, /*fp16 type*/
|
||||
CUDNN_TENSOR_NCHW, f, channelsPerGrp, r, s));
|
||||
// get bias
|
||||
cudnnTensorDescriptor_t biasDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(biasDesc, CUDNN_TENSOR_NCHW,
|
||||
CUDNN_DATA_HALF, 1, f, 1,
|
||||
1)); /*fp16 type*/
|
||||
|
||||
// get convolution descriptor
|
||||
cudnnConvolutionDescriptor_t convDesc;
|
||||
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
|
||||
// TODO: CUDNN_CONVOLUTION is a tunable argument
|
||||
checkCudnnError(cudnnSetConvolution2dDescriptor(
|
||||
convDesc, ph, pw, sh, sw, dh, dw, MODES[record->mode],
|
||||
CUDNN_DATA_HALF)); /*fp16 type*/
|
||||
if (g > 1) {
|
||||
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
|
||||
}
|
||||
|
||||
// get activation descriptor
|
||||
cudnnActivationDescriptor_t actDesc;
|
||||
checkCudnnError(cudnnCreateActivationDescriptor(&actDesc));
|
||||
// NOT_PROPAGATE_NAN is requierd by
|
||||
// cudnnConvolotionBiasActivationForward
|
||||
switch (op->getAct()) {
|
||||
case ActType::Relu:
|
||||
checkCudnnError(cudnnSetActivationDescriptor(
|
||||
actDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0));
|
||||
break;
|
||||
case ActType::Sigmoid:
|
||||
checkCudnnError(cudnnSetActivationDescriptor(
|
||||
actDesc, CUDNN_ACTIVATION_SIGMOID, CUDNN_NOT_PROPAGATE_NAN, 0));
|
||||
break;
|
||||
case ActType::None:
|
||||
checkCudnnError(
|
||||
cudnnSetActivationDescriptor(actDesc, CUDNN_ACTIVATION_IDENTITY,
|
||||
CUDNN_NOT_PROPAGATE_NAN, 0));
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
|
||||
// get output descriptor
|
||||
int outn, outc, outh, outw;
|
||||
checkCudnnError(cudnnGetConvolution2dForwardOutputDim(
|
||||
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
|
||||
cudnnTensorDescriptor_t outDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
|
||||
CUDNN_DATA_HALF, outn, outc,
|
||||
outh, outw));
|
||||
IT_ASSERT((vector{outn, outc, outh, outw}) ==
|
||||
op->getOutput()->getDims(),
|
||||
"cuDNN output shape mismatches with OP output shape");
|
||||
|
||||
return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
|
||||
convDesc, actDesc, outDesc);
|
||||
}
|
||||
|
||||
bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
|
||||
const CudaRuntimeObj *context) const {
|
||||
cudnnStatus_t stat;
|
||||
|
||||
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
|
||||
convDesc, actDesc, outDesc] =
|
||||
createCuDNNDescriptor(op, record);
|
||||
size_t wsSize = record->workspaceSize;
|
||||
CudaPtr wsData = context->getWorkspace(wsSize);
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
|
||||
stat = cudnnConvolutionForward(context->cudnnHandle(), &alpha, inDesc,
|
||||
inData, knDesc, knData, convDesc,
|
||||
ALGOS[record->algo], wsData, wsSize,
|
||||
&beta, outDesc, outData);
|
||||
if (stat != CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
||||
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
|
||||
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
|
||||
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
|
||||
return true;
|
||||
}
|
||||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
auto record = make_ref<ConvCuDnnPerfRecordObj>(); // with paramters in
|
||||
// default ctor
|
||||
compute(op, record, context);
|
||||
}
|
||||
|
||||
PerfRecord tune(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
ConvCuDnnPerfRecordObj ret;
|
||||
ret.time = std::numeric_limits<double>::max();
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
auto op = as<ConvObj>(_op);
|
||||
// Both modes have the same performance. Only run cross-correlation.
|
||||
for (int mode = 1; mode < 2; mode++) {
|
||||
// Try every possible algorithm of convolution
|
||||
for (int algo = 0; algo < N_ALGO; algo++) {
|
||||
auto recordRef = make_ref<ConvCuDnnPerfRecordObj>();
|
||||
auto &record = *recordRef;
|
||||
record.mode = mode;
|
||||
record.algo = algo;
|
||||
cudnnStatus_t stat;
|
||||
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
|
||||
convDesc, actDesc, outDesc] =
|
||||
createCuDNNDescriptor(op, recordRef);
|
||||
|
||||
// get workspace
|
||||
stat = cudnnGetConvolutionForwardWorkspaceSize(
|
||||
context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
|
||||
ALGOS[record.algo], &record.workspaceSize);
|
||||
if (stat != CUDNN_STATUS_SUCCESS) {
|
||||
continue;
|
||||
}
|
||||
if (record.workspaceSize > context->getWorkspaceSize()) {
|
||||
continue;
|
||||
}
|
||||
CudaPtr wsData = context->getWorkspace(record.workspaceSize);
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
|
||||
stat = cudnnConvolutionForward(
|
||||
context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
|
||||
knData, convDesc, ALGOS[record.algo], wsData,
|
||||
record.workspaceSize, &beta, outDesc, outData);
|
||||
if (stat != CUDNN_STATUS_SUCCESS) {
|
||||
continue;
|
||||
}
|
||||
record.time = timeit(
|
||||
[&]() {
|
||||
cudnnConvolutionForward(context->cudnnHandle(), &alpha,
|
||||
inDesc, inData, knDesc, knData,
|
||||
convDesc, ALGOS[record.algo],
|
||||
wsData, record.workspaceSize,
|
||||
&beta, outDesc, outData);
|
||||
},
|
||||
[&]() { context->sync(); });
|
||||
// printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time);
|
||||
|
||||
// Update the tune result
|
||||
if (ret.time > record.time) {
|
||||
ret = record;
|
||||
}
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
||||
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
|
||||
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
|
||||
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
|
||||
}
|
||||
}
|
||||
// printf("the best algo is %d, the best conv mode is %d\n", ret.algo,
|
||||
// ret.mode);
|
||||
IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
|
||||
"algorithm "
|
||||
"found");
|
||||
return make_ref<ConvCuDnnPerfRecordObj>(ret);
|
||||
}
|
||||
|
||||
void compute(const Operator &_op, const PerfRecord &_record,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConvObj>(_op);
|
||||
auto record = as<ConvCuDnnPerfRecordObj>(_record);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
bool success = cuDNNUnfused(op, record, context);
|
||||
IT_ASSERT(success);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float16, convCudnnFP16,
|
||||
"Conv_cuDNN_CUDA_Float16");
|
||||
|
||||
} // namespace infini
|
|
@ -219,6 +219,7 @@ class convBackwardDataCudnn : public Kernel {
|
|||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
// with paramters in default ctor
|
||||
auto record = make_ref<ConvTransposedCuDnnPerfRecordObj>();
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
compute(op, record, context);
|
||||
}
|
||||
|
||||
|
@ -300,8 +301,9 @@ class convBackwardDataCudnn : public Kernel {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::ConvTranspose, DataType::Float32,
|
||||
convBackwardDataCudnn, "ConvTranposed_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::ConvTransNHWC, DataType::Float32,
|
||||
convBackwardDataCudnn, "ConvTranposedNHWC_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::ConvTranspose, convBackwardDataCudnn,
|
||||
"ConvTranposed_cuDNN_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::ConvTransNHWC, convBackwardDataCudnn,
|
||||
"ConvTranposedNHWC_cuDNN_CUDA");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
#include "cuda/cuda_element_wise.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
|
||||
namespace infini {
|
||||
class ElementWiseCudnn : public CudaKernelWithoutConfig {
|
||||
|
@ -44,22 +45,21 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
|
|||
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
|
||||
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
|
||||
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
|
||||
|
||||
auto cudnnDataType = cudnnDataTypeConvert(op->getDType());
|
||||
// get inputs
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&aDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(aDesc, CUDNN_TENSOR_NCHW,
|
||||
CUDNN_DATA_FLOAT, a[0], a[1],
|
||||
a[2], a[3]));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
aDesc, CUDNN_TENSOR_NCHW, cudnnDataType, a[0], a[1], a[2], a[3]));
|
||||
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&bDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(bDesc, CUDNN_TENSOR_NCHW,
|
||||
CUDNN_DATA_FLOAT, b[0], b[1],
|
||||
b[2], b[3]));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
bDesc, CUDNN_TENSOR_NCHW, cudnnDataType, b[0], b[1], b[2], b[3]));
|
||||
|
||||
// get outputs
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&cDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(cDesc, CUDNN_TENSOR_NCHW,
|
||||
CUDNN_DATA_FLOAT, c[0], c[1],
|
||||
c[2], c[3]));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
cDesc, CUDNN_TENSOR_NCHW, cudnnDataType, c[0], c[1], c[2], c[3]));
|
||||
|
||||
// get op descriptor
|
||||
cudnnOpTensorDescriptor_t opDesc;
|
||||
|
@ -127,40 +127,33 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
|
|||
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
|
||||
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
|
||||
|
||||
if (op->getOpType() == OpType::Div)
|
||||
div_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
|
||||
b[2], b[3], c[0], c[1], c[2], c[3]);
|
||||
else if (op->getOpType() == OpType::Pow)
|
||||
pow_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
|
||||
b[2], b[3], c[0], c[1], c[2], c[3]);
|
||||
else if (op->getOpType() == OpType::Add) {
|
||||
add_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
|
||||
b[2], b[3], c[0], c[1], c[2], c[3]);
|
||||
const int dType = _op->getDType().getIndex();
|
||||
if (op->getOpType() == OpType::Div) {
|
||||
div_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0],
|
||||
b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
|
||||
} else if (op->getOpType() == OpType::Add) {
|
||||
add_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0],
|
||||
b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
|
||||
} else if (op->getOpType() == OpType::Pow) {
|
||||
pow_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0],
|
||||
b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
|
||||
} else if (op->getOpType() == OpType::Less) {
|
||||
less_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
|
||||
b[2], b[3], c[0], c[1], c[2], c[3]);
|
||||
} else
|
||||
less_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3],
|
||||
b[0], b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Float32, AddCudnn,
|
||||
"Add_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Sub, DataType::Float32, SubCudnn,
|
||||
"Sub_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Mul, DataType::Float32, MulCudnn,
|
||||
"Mul_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Min, DataType::Float32, MinCudnn,
|
||||
"Min_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Max, DataType::Float32, MaxCudnn,
|
||||
"Max_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Add, AddCudnn, "Add_cuDNN_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Sub, SubCudnn, "Sub_cuDNN_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Mul, MulCudnn, "Mul_cuDNN_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Min, MinCudnn, "Min_cuDNN_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Max, MaxCudnn, "Max_cuDNN_CUDA");
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Div, ElementWiseCuda, "Div_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Pow, ElementWiseCuda, "Pow_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Less, ElementWiseCuda, "Less_CUDA");
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Div, DataType::Float32, ElementWiseCuda,
|
||||
"Div_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Int64, ElementWiseCuda,
|
||||
"Add_CUDA_Int64");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Pow, DataType::Float32, ElementWiseCuda,
|
||||
"Pow__CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Less, DataType::Int64, ElementWiseCuda,
|
||||
"Less__CUDA_Int64");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include <math.h>
|
||||
|
||||
constexpr unsigned int num_threads() { return 32 * 4; }
|
||||
|
@ -129,44 +130,113 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
|
|||
}
|
||||
}
|
||||
|
||||
#define CASE(OP, T) \
|
||||
_##OP##_kernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \
|
||||
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
|
||||
|
||||
#define SWITCH_DTYPE(OP, DTYPE) \
|
||||
switch (DTYPE) { \
|
||||
case 1: \
|
||||
CASE(OP, 1) \
|
||||
break; \
|
||||
case 2: \
|
||||
CASE(OP, 2) \
|
||||
break; \
|
||||
case 3: \
|
||||
CASE(OP, 3) \
|
||||
break; \
|
||||
case 4: \
|
||||
CASE(OP, 4) \
|
||||
break; \
|
||||
case 5: \
|
||||
CASE(OP, 5) \
|
||||
break; \
|
||||
case 6: \
|
||||
CASE(OP, 6) \
|
||||
break; \
|
||||
case 7: \
|
||||
CASE(OP, 7) \
|
||||
break; \
|
||||
case 10: \
|
||||
CASE(OP, 10) \
|
||||
break; \
|
||||
case 11: \
|
||||
CASE(OP, 11) \
|
||||
break; \
|
||||
case 12: \
|
||||
CASE(OP, 12) \
|
||||
break; \
|
||||
case 13: \
|
||||
CASE(OP, 13) \
|
||||
break; \
|
||||
case 16: \
|
||||
CASE(OP, 16) \
|
||||
break; \
|
||||
default: \
|
||||
IT_TODO_HALT(); \
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
||||
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
int c3) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int num = c0 * c1 * c2 * c3;
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_div_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1,
|
||||
b2, b3, c0, c1, c2, c3);
|
||||
SWITCH_DTYPE(div, dType)
|
||||
}
|
||||
void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
void add_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
||||
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
int c3) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int num = c0 * c1 * c2 * c3;
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_add_kernel<int64_t><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0,
|
||||
b1, b2, b3, c0, c1, c2, c3);
|
||||
SWITCH_DTYPE(add, dType)
|
||||
}
|
||||
void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
||||
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
int c3) {
|
||||
int blocksize = block_work_size();
|
||||
int num = c0 * c1 * c2 * c3;
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_pow_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1,
|
||||
b2, b3, c0, c1, c2, c3);
|
||||
if (dType == 1) {
|
||||
_pow_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0,
|
||||
b1, b2, b3, c0, c1, c2, c3);
|
||||
} else if (dType == 3) {
|
||||
_pow_kernel<int8_t><<<gridsize, blocksize>>>(
|
||||
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
|
||||
} else if (dType == 10) {
|
||||
int a_size = a0 * a1 * a2 * a3;
|
||||
int b_size = b0 * b1 * b2 * b3;
|
||||
int c_size = c0 * c1 * c2 * c3;
|
||||
vector<float> a_float(a_size);
|
||||
vector<float> b_float(b_size);
|
||||
vector<float> c_float(c_size);
|
||||
for (int i = 0; i < a_size; ++i) {
|
||||
a_float[i] = __half2float(((half *)a)[i]);
|
||||
}
|
||||
void less_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
for (int i = 0; i < b_size; ++i) {
|
||||
b_float[i] = __half2float(((half *)b)[i]);
|
||||
}
|
||||
_pow_kernel<float><<<gridsize, blocksize>>>(
|
||||
a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0,
|
||||
b1, b2, b3, c0, c1, c2, c3);
|
||||
for (int i = 0; i < c_size; ++i) {
|
||||
((half *)c)[i] = __float2half(c_float[i]);
|
||||
}
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
void less_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
||||
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
int c3) {
|
||||
int blocksize = block_work_size();
|
||||
int num = c0 * c1 * c2 * c3;
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_less_kernel<int64_t><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0,
|
||||
b1, b2, b3, c0, c1, c2, c3);
|
||||
SWITCH_DTYPE(less, dType)
|
||||
}
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -25,12 +25,12 @@ class ExpandCuda : public CudaKernelWithoutConfig {
|
|||
inputShape.data[i] = in_Shape[i];
|
||||
outputsize *= out_Shape[i];
|
||||
}
|
||||
expandKernel((float *)inputData, (float *)outputData, nDims, outputsize,
|
||||
const int dType = op->getDType().getIndex();
|
||||
expandKernel(dType, inputData, outputData, nDims, outputsize,
|
||||
inputShape, outputShape);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Expand, DataType::Float32, ExpandCuda,
|
||||
"Expand_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Expand, ExpandCuda, "Expand_CUDA");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
#include "core/common.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "utils/small_array.h"
|
||||
|
||||
constexpr unsigned int num_threads() { return 32 * 4; }
|
||||
constexpr int thread_work_size() { return 4; }
|
||||
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
||||
|
||||
__global__ void _expandKernel(float *input, float *output, int nDims,
|
||||
template <class T>
|
||||
__global__ void _expandKernel(void *input, void *output, int nDims,
|
||||
int outputsize, infini::SmallArray inputShape,
|
||||
infini::SmallArray outputShape) {
|
||||
|
||||
|
@ -33,17 +35,64 @@ __global__ void _expandKernel(float *input, float *output, int nDims,
|
|||
temp *= inputShape.data[i];
|
||||
v = v / outputShape.data[i];
|
||||
}
|
||||
output[outputIdx] = input[inputIdx];
|
||||
((T *)output)[outputIdx] = ((T *)input)[inputIdx];
|
||||
}
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void expandKernel(float *input, float *output, int nDims, int outputsize,
|
||||
SmallArray inputShape, SmallArray outputShape) {
|
||||
|
||||
#define CASE(T) \
|
||||
_expandKernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \
|
||||
input, output, nDims, outputsize, inputShape, outputShape);
|
||||
|
||||
#define SWITCH_DTYPE(DTYPE) \
|
||||
switch (DTYPE) { \
|
||||
case 1: \
|
||||
CASE(1) \
|
||||
break; \
|
||||
case 2: \
|
||||
CASE(2) \
|
||||
break; \
|
||||
case 3: \
|
||||
CASE(3) \
|
||||
break; \
|
||||
case 4: \
|
||||
CASE(4) \
|
||||
break; \
|
||||
case 5: \
|
||||
CASE(5) \
|
||||
break; \
|
||||
case 6: \
|
||||
CASE(6) \
|
||||
break; \
|
||||
case 7: \
|
||||
CASE(7) \
|
||||
break; \
|
||||
case 10: \
|
||||
CASE(10) \
|
||||
break; \
|
||||
case 11: \
|
||||
CASE(11) \
|
||||
break; \
|
||||
case 12: \
|
||||
CASE(12) \
|
||||
break; \
|
||||
case 13: \
|
||||
CASE(13) \
|
||||
break; \
|
||||
case 16: \
|
||||
CASE(16) \
|
||||
break; \
|
||||
default: \
|
||||
IT_TODO_HALT(); \
|
||||
}
|
||||
|
||||
void expandKernel(int dType, void *input, void *output, int nDims,
|
||||
int outputsize, SmallArray inputShape,
|
||||
SmallArray outputShape) {
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
|
||||
_expandKernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
|
||||
inputShape, outputShape);
|
||||
SWITCH_DTYPE(dType)
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -8,6 +8,7 @@ class ExtendCuda : public CudaKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ExtendObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto inData = op->getInputs(0)->getRawDataPtr<float *>();
|
||||
auto outData = op->getOutputs()[0]->getRawDataPtr<float *>();
|
||||
int blockSize = 1;
|
||||
|
@ -22,6 +23,5 @@ class ExtendCuda : public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Extend, DataType::Float32, ExtendCuda,
|
||||
"Extend_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Extend, ExtendCuda, "Extend_CUDA");
|
||||
} // namespace infini
|
||||
|
|
|
@ -15,12 +15,23 @@ class GatherCuda : public CudaKernelWithoutConfig {
|
|||
GatherMetaData metaData;
|
||||
initGatherMetaData(metaData, op);
|
||||
|
||||
auto inData = input->getRawDataPtr<float *>();
|
||||
auto outData = op->getOutput()->getRawDataPtr<float *>();
|
||||
gather_kernel(inData, outData, metaData, op->getOutput()->size());
|
||||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
if (op->getDType() == DataType::Float32) {
|
||||
gather_kernel<float>((float *)inputData, (float *)outputData,
|
||||
metaData, op->getOutput()->size());
|
||||
} else if (op->getDType() == DataType::Float16) {
|
||||
gather_kernel<half>((half *)inputData, (half *)outputData, metaData,
|
||||
op->getOutput()->size());
|
||||
} else if (op->getDType() == DataType::Int8) {
|
||||
gather_kernel<int8_t>((int8_t *)inputData, (int8_t *)outputData,
|
||||
metaData, op->getOutput()->size());
|
||||
} else {
|
||||
IT_ASSERT(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Gather, DataType::Float32, GatherCuda,
|
||||
"Gather_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Gather, GatherCuda, "Gather_CUDA");
|
||||
} // namespace infini
|
||||
|
|
|
@ -28,27 +28,32 @@ __device__ T gatheredOffset2Offset(int gOffset,
|
|||
return offset;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void _gather_kernel(float *in, float *out,
|
||||
template <typename dataT, typename T>
|
||||
__global__ void _gather_kernel(dataT *in, dataT *out,
|
||||
infini::GatherMetaData metaData, size_t num) {
|
||||
T tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
while (tid < num) {
|
||||
if (tid < num) {
|
||||
T offset = gatheredOffset2Offset<T>(tid, metaData);
|
||||
out[tid] = in[offset];
|
||||
tid += stride;
|
||||
}
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num) {
|
||||
template <typename T>
|
||||
void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num) {
|
||||
int blockSize = 32 * 16;
|
||||
int gridSize = (num + blockSize - 1) / blockSize;
|
||||
if (metaData.indexType == DataType::Int64) {
|
||||
_gather_kernel<int64_t>
|
||||
_gather_kernel<T, int64_t>
|
||||
<<<gridSize, blockSize>>>(in, out, metaData, num);
|
||||
} else {
|
||||
_gather_kernel<int><<<gridSize, blockSize>>>(in, out, metaData, num);
|
||||
_gather_kernel<T, int><<<gridSize, blockSize>>>(in, out, metaData, num);
|
||||
}
|
||||
}
|
||||
template void gather_kernel<float>(float *in, float *out,
|
||||
GatherMetaData metaData, size_t num);
|
||||
template void gather_kernel<half>(half *in, half *out, GatherMetaData metaData,
|
||||
size_t num);
|
||||
template void gather_kernel<int8_t>(int8_t *in, int8_t *out,
|
||||
GatherMetaData metaData, size_t num);
|
||||
} // namespace infini
|
||||
|
|
|
@ -21,8 +21,7 @@ class GatherElementsCuda : public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, DataType::Float32,
|
||||
GatherElementsCuda, "GatherELements_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, DataType::Int32,
|
||||
GatherElementsCuda, "GatherElements_CUDA_Int32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::GatherElements, GatherElementsCuda,
|
||||
"GatherELements_CUDA");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -24,8 +24,10 @@ class LayerNormCuda : public CudaKernelWithoutConfig {
|
|||
int dimsize = dims[op->getAxis()];
|
||||
int size = op->getOutput(0)->size();
|
||||
int scaleSize = op->getInputs(1)->size();
|
||||
if (op->getDType() == DataType::Float32) {
|
||||
if (op->numInputs() == 3) {
|
||||
void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
void *const biasData =
|
||||
(op->getInputs(2)->getRawDataPtr<void *>());
|
||||
int biasSize = op->getInputs(2)->size();
|
||||
// printf("kernel bias:true:%d\n", 1);
|
||||
LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
|
||||
|
@ -36,10 +38,27 @@ class LayerNormCuda : public CudaKernelWithoutConfig {
|
|||
LaynormKernel((float *)inputData, (float *)scaleData, eps, size,
|
||||
scaleSize, dimsize, stride, (float *)outputData);
|
||||
}
|
||||
} else if (op->getDType() == DataType::Float16) {
|
||||
if (op->numInputs() == 3) {
|
||||
void *const biasData =
|
||||
(op->getInputs(2)->getRawDataPtr<void *>());
|
||||
int biasSize = op->getInputs(2)->size();
|
||||
// printf("kernel bias:true:%d\n", 1);
|
||||
LaynormKernel((half *)inputData, (half *)scaleData, eps, size,
|
||||
scaleSize, dimsize, stride, (half *)outputData,
|
||||
(half *)biasData, biasSize);
|
||||
} else {
|
||||
// printf("kernel bias:false:%d\n", 0);
|
||||
LaynormKernel((half *)inputData, (half *)scaleData, eps, size,
|
||||
scaleSize, dimsize, stride, (half *)outputData);
|
||||
}
|
||||
} else {
|
||||
IT_ASSERT(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::LayerNormalization, DataType::Float32,
|
||||
LayerNormCuda, "LayerNorm_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::LayerNormalization, LayerNormCuda,
|
||||
"LayerNorm_CUDA");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -1,43 +1,41 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
template <int BLOCK_DIM>
|
||||
template <typename T, int BLOCK_DIM>
|
||||
__launch_bounds__(BLOCK_DIM) __global__
|
||||
void blockLaynormKernel(const float *input, const float *scale,
|
||||
const int dimsize, const int stride, float *output,
|
||||
const float eps, int scaleSize, const float *bias,
|
||||
int biasSize) {
|
||||
void blockLaynormKernel(const T *input, const T *scale, const int dimsize,
|
||||
const int stride, T *output, const T eps,
|
||||
int scaleSize, const T *bias, int biasSize) {
|
||||
// len(scale) = len(bias) = dimsize
|
||||
int tmp = blockIdx.x % stride;
|
||||
int tid = (blockIdx.x - tmp) * dimsize + tmp;
|
||||
float muPartial = 0.0f;
|
||||
T muPartial = 0.0f;
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
|
||||
}
|
||||
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
|
||||
typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ float mu;
|
||||
float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
|
||||
__shared__ T mu;
|
||||
T muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
|
||||
if (threadIdx.x ==
|
||||
0) { // must set threadIdx.x = 0 write the output to memory
|
||||
mu = muBlock / dimsize;
|
||||
mu = muBlock * static_cast<T>(__fdividef(1.0F, dimsize));
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float sigma2Partial = 0.0f;
|
||||
T sigma2Partial = 0.0f;
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
sigma2Partial +=
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
|
||||
}
|
||||
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
|
||||
typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
|
||||
|
||||
__shared__ float sigma2;
|
||||
float sigma2Block =
|
||||
BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
|
||||
__shared__ T sigma2;
|
||||
T sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
|
||||
if (threadIdx.x ==
|
||||
0) { // must set threadIdx.x = 0 write the output to memory
|
||||
sigma2 = sigma2Block / dimsize;
|
||||
sigma2 = sigma2Block * static_cast<T>(__fdividef(1.0F, dimsize));
|
||||
}
|
||||
__syncthreads();
|
||||
if (biasSize == dimsize) {
|
||||
|
@ -47,8 +45,9 @@ __launch_bounds__(BLOCK_DIM) __global__
|
|||
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
|
||||
scale[threadIdx.x + ph * BLOCK_DIM] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
|
||||
mu) /
|
||||
sqrt(sigma2 + eps) +
|
||||
mu) *
|
||||
static_cast<T>(__fdividef(
|
||||
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
|
||||
bias[threadIdx.x + ph * BLOCK_DIM];
|
||||
}
|
||||
} else {
|
||||
|
@ -57,8 +56,9 @@ __launch_bounds__(BLOCK_DIM) __global__
|
|||
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
|
||||
scale[0] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
|
||||
mu) /
|
||||
sqrt(sigma2 + eps) +
|
||||
mu) *
|
||||
static_cast<T>(__fdividef(
|
||||
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
|
||||
bias[threadIdx.x + ph * BLOCK_DIM];
|
||||
}
|
||||
}
|
||||
|
@ -69,8 +69,9 @@ __launch_bounds__(BLOCK_DIM) __global__
|
|||
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
|
||||
scale[threadIdx.x + ph * BLOCK_DIM] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
|
||||
mu) /
|
||||
sqrt(sigma2 + eps) +
|
||||
mu) *
|
||||
static_cast<T>(__fdividef(
|
||||
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
|
||||
bias[0];
|
||||
}
|
||||
} else {
|
||||
|
@ -79,50 +80,50 @@ __launch_bounds__(BLOCK_DIM) __global__
|
|||
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
|
||||
scale[0] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
|
||||
mu) /
|
||||
sqrt(sigma2 + eps) +
|
||||
mu) *
|
||||
static_cast<T>(__fdividef(
|
||||
1.0F, sqrt(static_cast<float>(sigma2 + eps)))) +
|
||||
bias[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
//-----------------
|
||||
template <int BLOCK_DIM>
|
||||
template <typename T, int BLOCK_DIM>
|
||||
__launch_bounds__(BLOCK_DIM) __global__
|
||||
void blockLaynormKernel(const float *input, const float *scale,
|
||||
const int dimsize, const int stride, float *output,
|
||||
const float eps, int scaleSize) {
|
||||
void blockLaynormKernel(const T *input, const T *scale, const int dimsize,
|
||||
const int stride, T *output, const T eps,
|
||||
int scaleSize) {
|
||||
// len(scale) = len(bias) = dimsize
|
||||
int tmp = blockIdx.x % stride;
|
||||
int tid = (blockIdx.x - tmp) * dimsize + tmp;
|
||||
float muPartial = 0.0f;
|
||||
T muPartial = 0.0f;
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
|
||||
}
|
||||
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
|
||||
typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ float mu;
|
||||
float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
|
||||
__shared__ T mu;
|
||||
T muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum());
|
||||
if (threadIdx.x ==
|
||||
0) { // must set threadIdx.x = 0 write the output to memory
|
||||
mu = muBlock / dimsize;
|
||||
mu = muBlock * static_cast<T>(__fdividef(1.0F, dimsize));
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float sigma2Partial = 0.0f;
|
||||
T sigma2Partial = 0.0f;
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
sigma2Partial +=
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu);
|
||||
}
|
||||
typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
|
||||
typedef cub::BlockReduce<T, BLOCK_DIM> BlockReduce;
|
||||
|
||||
__shared__ float sigma2;
|
||||
float sigma2Block =
|
||||
BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
|
||||
__shared__ T sigma2;
|
||||
T sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum());
|
||||
if (threadIdx.x ==
|
||||
0) { // must set threadIdx.x = 0 write the output to memory
|
||||
sigma2 = sigma2Block / dimsize;
|
||||
sigma2 = sigma2Block * static_cast<T>(__fdividef(1.0F, dimsize));
|
||||
}
|
||||
__syncthreads();
|
||||
if (scaleSize == dimsize) {
|
||||
|
@ -130,16 +131,18 @@ __launch_bounds__(BLOCK_DIM) __global__
|
|||
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
|
||||
scale[threadIdx.x + ph * BLOCK_DIM] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) /
|
||||
sqrt(sigma2 + eps);
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
|
||||
static_cast<T>(
|
||||
__fdividef(1.0F, sqrt(static_cast<float>(sigma2 + eps))));
|
||||
}
|
||||
} else {
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
|
||||
scale[0] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) /
|
||||
sqrt(sigma2 + eps);
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - mu) *
|
||||
static_cast<T>(
|
||||
__fdividef(1.0F, sqrt(static_cast<float>(sigma2 + eps))));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -158,33 +161,33 @@ __inline__ __device__ T WarpAllReduce(T val) {
|
|||
}
|
||||
return val;
|
||||
}
|
||||
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
|
||||
__global__ void warpLaynormKernel(const float *input, const float *scale,
|
||||
template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
|
||||
__global__ void warpLaynormKernel(const T *input, const T *scale,
|
||||
const int dimsize, const int stride,
|
||||
float *output, const float eps, int scaleSize,
|
||||
int otherSize, const float *bias,
|
||||
int biasSize) {
|
||||
T *output, const T eps, int scaleSize,
|
||||
int otherSize, const T *bias, int biasSize) {
|
||||
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
|
||||
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
|
||||
if (otherIdx < otherSize) {
|
||||
|
||||
__shared__ float muTotal[BLOCK_DIM_y];
|
||||
__shared__ float sigma2Total[BLOCK_DIM_y];
|
||||
__shared__ T muTotal[BLOCK_DIM_y];
|
||||
__shared__ T sigma2Total[BLOCK_DIM_y];
|
||||
|
||||
float muPartial = 0.0f;
|
||||
T muPartial = 0.0f;
|
||||
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
|
||||
}
|
||||
|
||||
muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial);
|
||||
muPartial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(muPartial);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
muTotal[threadIdx.y] = muPartial / dimsize;
|
||||
muTotal[threadIdx.y] =
|
||||
muPartial * static_cast<T>(__fdividef(1.0F, dimsize));
|
||||
|
||||
//--------------------------------------------
|
||||
float sigma2Partial = 0.0f;
|
||||
T sigma2Partial = 0.0f;
|
||||
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||
sigma2Partial +=
|
||||
|
@ -194,10 +197,11 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
|
|||
muTotal[threadIdx.y]);
|
||||
}
|
||||
|
||||
sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial);
|
||||
sigma2Partial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(sigma2Partial);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
sigma2Total[threadIdx.y] = sigma2Partial / dimsize;
|
||||
sigma2Total[threadIdx.y] =
|
||||
sigma2Partial * static_cast<T>(__fdividef(1.0F, dimsize));
|
||||
|
||||
//--------------------------------------------
|
||||
if (biasSize == dimsize) {
|
||||
|
@ -209,8 +213,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
|
|||
scale[threadIdx.x + ph * BLOCK_DIM_x] *
|
||||
(input[tid +
|
||||
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]) /
|
||||
sqrt(sigma2Total[threadIdx.y] + eps) +
|
||||
muTotal[threadIdx.y]) *
|
||||
static_cast<T>(__fdividef(
|
||||
1.0F, sqrt(static_cast<float>(
|
||||
sigma2Total[threadIdx.y] + eps)))) +
|
||||
bias[threadIdx.x + ph * BLOCK_DIM_x];
|
||||
}
|
||||
} else {
|
||||
|
@ -221,8 +227,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
|
|||
scale[0] *
|
||||
(input[tid +
|
||||
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]) /
|
||||
sqrt(sigma2Total[threadIdx.y] + eps) +
|
||||
muTotal[threadIdx.y]) *
|
||||
static_cast<T>(__fdividef(
|
||||
1.0F, sqrt(static_cast<float>(
|
||||
sigma2Total[threadIdx.y] + eps)))) +
|
||||
bias[threadIdx.x + ph * BLOCK_DIM_x];
|
||||
}
|
||||
}
|
||||
|
@ -235,8 +243,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
|
|||
scale[threadIdx.x + ph * BLOCK_DIM_x] *
|
||||
(input[tid +
|
||||
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]) /
|
||||
sqrt(sigma2Total[threadIdx.y] + eps) +
|
||||
muTotal[threadIdx.y]) *
|
||||
static_cast<T>(__fdividef(
|
||||
1.0F, sqrt(static_cast<float>(
|
||||
sigma2Total[threadIdx.y] + eps)))) +
|
||||
bias[0];
|
||||
}
|
||||
} else {
|
||||
|
@ -247,40 +257,43 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
|
|||
scale[0] *
|
||||
(input[tid +
|
||||
(threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]) /
|
||||
sqrt(sigma2Total[threadIdx.y] + eps) +
|
||||
muTotal[threadIdx.y]) *
|
||||
static_cast<T>(__fdividef(
|
||||
1.0F, sqrt(static_cast<float>(
|
||||
sigma2Total[threadIdx.y] + eps)))) +
|
||||
bias[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
|
||||
__global__ void warpLaynormKernel(const float *input, const float *scale,
|
||||
template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y>
|
||||
__global__ void warpLaynormKernel(const T *input, const T *scale,
|
||||
const int dimsize, const int stride,
|
||||
float *output, const float eps, int scaleSize,
|
||||
T *output, const T eps, int scaleSize,
|
||||
int otherSize) {
|
||||
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
|
||||
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
|
||||
if (otherIdx < otherSize) {
|
||||
|
||||
__shared__ float muTotal[BLOCK_DIM_y];
|
||||
__shared__ float sigma2Total[BLOCK_DIM_y];
|
||||
__shared__ T muTotal[BLOCK_DIM_y];
|
||||
__shared__ T sigma2Total[BLOCK_DIM_y];
|
||||
|
||||
float muPartial = 0.0f;
|
||||
T muPartial = 0.0f;
|
||||
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||
muPartial += input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
|
||||
}
|
||||
|
||||
muPartial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(muPartial);
|
||||
muPartial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(muPartial);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
muTotal[threadIdx.y] = muPartial / dimsize;
|
||||
muTotal[threadIdx.y] =
|
||||
muPartial * static_cast<T>(__fdividef(1.0F, dimsize));
|
||||
|
||||
//--------------------------------------------
|
||||
float sigma2Partial = 0.0f;
|
||||
T sigma2Partial = 0.0f;
|
||||
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||
sigma2Partial +=
|
||||
|
@ -290,10 +303,11 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
|
|||
muTotal[threadIdx.y]);
|
||||
}
|
||||
|
||||
sigma2Partial = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sigma2Partial);
|
||||
sigma2Partial = WarpAllReduce<SumOp, T, BLOCK_DIM_x>(sigma2Partial);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
sigma2Total[threadIdx.y] = sigma2Partial / dimsize;
|
||||
sigma2Total[threadIdx.y] =
|
||||
sigma2Partial * static_cast<T>(__fdividef(1.0F, dimsize));
|
||||
|
||||
//--------------------------------------------
|
||||
if (scaleSize == dimsize) {
|
||||
|
@ -302,8 +316,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
|
|||
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
|
||||
scale[threadIdx.x + ph * BLOCK_DIM_x] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]) /
|
||||
sqrt(sigma2Total[threadIdx.y] + eps);
|
||||
muTotal[threadIdx.y]) *
|
||||
static_cast<T>(
|
||||
__fdividef(1.0F, sqrt(static_cast<float>(
|
||||
sigma2Total[threadIdx.y] + eps))));
|
||||
}
|
||||
} else {
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||
|
@ -311,8 +327,10 @@ __global__ void warpLaynormKernel(const float *input, const float *scale,
|
|||
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
|
||||
scale[0] *
|
||||
(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
muTotal[threadIdx.y]) /
|
||||
sqrt(sigma2Total[threadIdx.y] + eps);
|
||||
muTotal[threadIdx.y]) *
|
||||
static_cast<T>(
|
||||
__fdividef(1.0F, sqrt(static_cast<float>(
|
||||
sigma2Total[threadIdx.y] + eps))));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -325,7 +343,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
if (dimsize > 1024) {
|
||||
int BLOCK_DIM = 1024;
|
||||
|
||||
blockLaynormKernel<1024>
|
||||
blockLaynormKernel<float, 1024>
|
||||
<<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
|
||||
eps, scaleSize, bias, biasSize);
|
||||
} else if (dimsize > 31) {
|
||||
|
@ -335,7 +353,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>(
|
||||
warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else if (dimsize > 15) {
|
||||
|
@ -345,7 +363,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>(
|
||||
warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else if (dimsize > 7) {
|
||||
|
@ -355,7 +373,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>(
|
||||
warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else {
|
||||
|
@ -365,7 +383,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>(
|
||||
warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
}
|
||||
|
@ -378,7 +396,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
if (dimsize > 1024) {
|
||||
int BLOCK_DIM = 1024;
|
||||
|
||||
blockLaynormKernel<1024><<<num_block, BLOCK_DIM>>>(
|
||||
blockLaynormKernel<float, 1024><<<num_block, BLOCK_DIM>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize);
|
||||
} else if (dimsize > 31) {
|
||||
int BLOCK_DIM_x = 32;
|
||||
|
@ -387,7 +405,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<32, 32><<<grid_dim, block_dim>>>(
|
||||
warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else if (dimsize > 15) {
|
||||
int BLOCK_DIM_x = 16;
|
||||
|
@ -396,7 +414,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<16, 64><<<grid_dim, block_dim>>>(
|
||||
warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else if (dimsize > 7) {
|
||||
int BLOCK_DIM_x = 8;
|
||||
|
@ -405,7 +423,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<8, 128><<<grid_dim, block_dim>>>(
|
||||
warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else {
|
||||
int BLOCK_DIM_x = 4;
|
||||
|
@ -414,7 +432,108 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<4, 256><<<grid_dim, block_dim>>>(
|
||||
warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
}
|
||||
}
|
||||
//-----------------
|
||||
void LaynormKernel(const half *input, const half *scale, const half eps,
|
||||
int size, int scaleSize, const int dimsize, const int stride,
|
||||
half *output, const half *bias, int biasSize) {
|
||||
int num_block = size / dimsize;
|
||||
if (dimsize > 1024) {
|
||||
int BLOCK_DIM = 1024;
|
||||
|
||||
blockLaynormKernel<half, 1024>
|
||||
<<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output,
|
||||
eps, scaleSize, bias, biasSize);
|
||||
} else if (dimsize > 31) {
|
||||
int BLOCK_DIM_x = 32;
|
||||
int BLOCK_DIM_y = 32;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else if (dimsize > 15) {
|
||||
int BLOCK_DIM_x = 16;
|
||||
int BLOCK_DIM_y = 64;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else if (dimsize > 7) {
|
||||
int BLOCK_DIM_x = 8;
|
||||
int BLOCK_DIM_y = 128;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else {
|
||||
int BLOCK_DIM_x = 4;
|
||||
int BLOCK_DIM_y = 256;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
}
|
||||
}
|
||||
|
||||
void LaynormKernel(const half *input, const half *scale, const half eps,
|
||||
int size, int scaleSize, const int dimsize, const int stride,
|
||||
half *output) {
|
||||
int num_block = size / dimsize;
|
||||
if (dimsize > 1024) {
|
||||
int BLOCK_DIM = 1024;
|
||||
|
||||
blockLaynormKernel<half, 1024><<<num_block, BLOCK_DIM>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize);
|
||||
} else if (dimsize > 31) {
|
||||
int BLOCK_DIM_x = 32;
|
||||
int BLOCK_DIM_y = 32;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else if (dimsize > 15) {
|
||||
int BLOCK_DIM_x = 16;
|
||||
int BLOCK_DIM_y = 64;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else if (dimsize > 7) {
|
||||
int BLOCK_DIM_x = 8;
|
||||
int BLOCK_DIM_y = 128;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else {
|
||||
int BLOCK_DIM_x = 4;
|
||||
int BLOCK_DIM_y = 256;
|
||||
int num_block_x = (num_block + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>(
|
||||
input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
#include "core/kernel.h"
|
||||
#include "cuda/cuda_expand.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "utils/small_array.h"
|
||||
|
||||
namespace infini {
|
||||
|
@ -48,11 +49,12 @@ class matmulCublas : public Kernel {
|
|||
auto opB = op->getTransB() ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
const int lda = op->getTransA() ? m : k, ldb = op->getTransB() ? k : n,
|
||||
ldc = n;
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
if (op->numInputs() == 2) { // no bias
|
||||
beta = 0.f;
|
||||
} else { // broadcast bias to output
|
||||
beta = 1.f;
|
||||
float alpha_naive = 1.f, beta_naive = 0.f;
|
||||
auto dataType = op->getDType();
|
||||
auto cuDataType = cublasDataTypeConvert(dataType);
|
||||
IT_ASSERT(cuDataType != CUDA_R_8I, "matmul don't support int8 dtype.");
|
||||
if (op->numInputs() == 3) { // have bias
|
||||
beta_naive = 1.f;
|
||||
auto inC = op->getInputs(2);
|
||||
auto out = op->getOutput();
|
||||
SmallArray inputShape, outputShape;
|
||||
|
@ -69,8 +71,9 @@ class matmulCublas : public Kernel {
|
|||
if (i >= offset)
|
||||
inputShape.data[i] = inC->getDims()[i - offset];
|
||||
}
|
||||
expandKernel(inC->getRawDataPtr<float *>(),
|
||||
out->getRawDataPtr<float *>(), nDims, outputsize,
|
||||
const int dType = dataType.getIndex();
|
||||
expandKernel(dType, inC->getRawDataPtr<void *>(),
|
||||
out->getRawDataPtr<void *>(), nDims, outputsize,
|
||||
inputShape, outputShape);
|
||||
}
|
||||
// TODO:use compute type
|
||||
|
@ -89,16 +92,38 @@ class matmulCublas : public Kernel {
|
|||
(dimB == 3 && op->getInputs(1)->getDims()[0] == 1))
|
||||
? 0 // Broadcast the batch dimension if batch size is 1
|
||||
: n * k;
|
||||
if (dataType == DataType::Float16) {
|
||||
half alpha_half = static_cast<half>(alpha_naive);
|
||||
half beta_half = static_cast<half>(beta_naive);
|
||||
stat = cublasGemmStridedBatchedEx(
|
||||
context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
|
||||
CUDA_R_32F, ldb, strideB, inAData, CUDA_R_32F, lda, strideA,
|
||||
&beta, outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F,
|
||||
context->cublasHandle(), opB, opA, n, m, k, &alpha_half,
|
||||
inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
|
||||
strideA, &beta_half, outData, cuDataType, ldc, m * n, b,
|
||||
cuDataType, (cublasGemmAlgo_t)record->algo);
|
||||
|
||||
} else {
|
||||
stat = cublasGemmStridedBatchedEx(
|
||||
context->cublasHandle(), opB, opA, n, m, k, &alpha_naive,
|
||||
inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
|
||||
strideA, &beta_naive, outData, cuDataType, ldc, m * n, b,
|
||||
cuDataType, (cublasGemmAlgo_t)record->algo);
|
||||
}
|
||||
} else {
|
||||
if (dataType == DataType::Float16) {
|
||||
half alpha_half = static_cast<half>(alpha_naive);
|
||||
half beta_half = static_cast<half>(beta_naive);
|
||||
stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
|
||||
&alpha_half, inBData, cuDataType, ldb,
|
||||
inAData, cuDataType, lda, &beta_half,
|
||||
outData, cuDataType, ldc, cuDataType,
|
||||
(cublasGemmAlgo_t)record->algo);
|
||||
} else {
|
||||
stat = cublasGemmEx(
|
||||
context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
|
||||
CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData,
|
||||
CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo);
|
||||
stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
|
||||
&alpha_naive, inBData, cuDataType, ldb,
|
||||
inAData, cuDataType, lda, &beta_naive,
|
||||
outData, cuDataType, ldc, cuDataType,
|
||||
(cublasGemmAlgo_t)record->algo);
|
||||
}
|
||||
}
|
||||
// if (stat != CUBLAS_STATUS_SUCCESS)
|
||||
// cout << cublasGetErrorString(stat);
|
||||
|
@ -140,8 +165,9 @@ class matmulCublas : public Kernel {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::MatMul, DataType::Float32, matmulCublas,
|
||||
"Matmul_cuBLAS_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::MatMul, matmulCublas,
|
||||
"Matmul_cuBLAS_CUDA");
|
||||
|
||||
REGISTER_CONSTRUCTOR(2, MatmulCublasPerfRecordObj::from_json);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -229,9 +229,8 @@ class MemboundTVMExtractSource : public Kernel {
|
|||
}
|
||||
};
|
||||
|
||||
// REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32,
|
||||
// MemboundTVMExtractSource,
|
||||
// "Memobund_TVM_Ansor_extract_source");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::MemBound, MemboundTVMExtractSource,
|
||||
"Memobund_TVM_Ansor_extract_source");
|
||||
}; // namespace infini
|
||||
|
||||
#endif
|
||||
|
|
|
@ -216,9 +216,9 @@ class MemboundTVMPackedFunction : public Kernel {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32,
|
||||
MemboundTVMPackedFunction,
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::MemBound, MemboundTVMPackedFunction,
|
||||
"Memobund_TVM_Ansor_packed_funciton");
|
||||
|
||||
}; // namespace infini
|
||||
|
||||
#endif
|
||||
|
|
|
@ -39,10 +39,8 @@ class SliceCuda : private PadSliceCudaCompute, public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Float32, SliceCuda,
|
||||
"Slice__CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Int64, SliceCuda,
|
||||
"Slice__CUDA_Int64");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Pad, DataType::Float32, PadCuda,
|
||||
"Pad__CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Slice, SliceCuda, "Slice__CUDA");
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Pad, PadCuda, "Pad__CUDA");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "core/data_type.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_pad_slice.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
|
||||
__device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
|
||||
TransMetaData metaData,
|
||||
|
@ -21,39 +22,83 @@ __device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void _pad_slice_kernel(T *part, T *whole, TransMetaData metaData,
|
||||
int nDims, int num, bool isPad) {
|
||||
__global__ void _pad_slice_kernel(void *part, void *whole,
|
||||
TransMetaData metaData, int nDims, int num,
|
||||
bool isPad) {
|
||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (tid >= num)
|
||||
if (tid >= num) {
|
||||
return;
|
||||
}
|
||||
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
while (tid < num) {
|
||||
int offset = WholeTensorOffset2PartTensorOffset(tid, metaData, nDims);
|
||||
if (isPad)
|
||||
if (offset < 0)
|
||||
whole[tid] = 0;
|
||||
else
|
||||
whole[tid] = part[offset];
|
||||
else if (offset >= 0)
|
||||
part[offset] = whole[tid];
|
||||
if (isPad) {
|
||||
if (offset < 0) {
|
||||
((T *)whole)[tid] = static_cast<T>(0.f);
|
||||
} else {
|
||||
((T *)whole)[tid] = ((T *)part)[offset];
|
||||
}
|
||||
} else if (offset >= 0) {
|
||||
((T *)part)[offset] = ((T *)whole)[tid];
|
||||
}
|
||||
tid += stride;
|
||||
}
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
#define CASE(T) \
|
||||
_pad_slice_kernel<DT_CUDA<T>::t><<<gridSize, blockSize>>>( \
|
||||
partData, wholeData, metadata, nDims, num, isPad);
|
||||
|
||||
#define SWITCH_DTYPE(DTYPE) \
|
||||
switch (DTYPE) { \
|
||||
case 1: \
|
||||
CASE(1) \
|
||||
break; \
|
||||
case 2: \
|
||||
CASE(2) \
|
||||
break; \
|
||||
case 3: \
|
||||
CASE(3) \
|
||||
break; \
|
||||
case 4: \
|
||||
CASE(4) \
|
||||
break; \
|
||||
case 5: \
|
||||
CASE(5) \
|
||||
break; \
|
||||
case 6: \
|
||||
CASE(6) \
|
||||
break; \
|
||||
case 7: \
|
||||
CASE(7) \
|
||||
break; \
|
||||
case 10: \
|
||||
CASE(10) \
|
||||
break; \
|
||||
case 11: \
|
||||
CASE(11) \
|
||||
break; \
|
||||
case 12: \
|
||||
CASE(12) \
|
||||
break; \
|
||||
case 13: \
|
||||
CASE(13) \
|
||||
break; \
|
||||
case 16: \
|
||||
CASE(16) \
|
||||
break; \
|
||||
default: \
|
||||
IT_TODO_HALT(); \
|
||||
}
|
||||
|
||||
void pad_slice_kernel(void *partData, void *wholeData,
|
||||
const TransMetaData &metadata, int nDims, int num,
|
||||
bool isPad) {
|
||||
int blockSize = 32 * 16;
|
||||
int gridSize = (num + blockSize - 1) / blockSize;
|
||||
if (metadata.DType == DataType::Int64.getIndex()) {
|
||||
_pad_slice_kernel<int64_t>
|
||||
<<<gridSize, blockSize>>>((int64_t *)partData, (int64_t *)wholeData,
|
||||
metadata, nDims, num, isPad);
|
||||
} else if (metadata.DType == DataType::Float32.getIndex()) {
|
||||
_pad_slice_kernel<float><<<gridSize, blockSize>>>(
|
||||
(float *)partData, (float *)wholeData, metadata, nDims, num, isPad);
|
||||
}
|
||||
int dType = metadata.DType;
|
||||
SWITCH_DTYPE(dType)
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -8,6 +8,7 @@ class poolingCudnn : public CudaKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PoolingObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
@ -76,8 +77,9 @@ class avgPoolCudnn : public poolingCudnn {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::MaxPool, DataType::Float32, maxPoolCudnn,
|
||||
"MaxPool_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AveragePool, DataType::Float32,
|
||||
avgPoolCudnn, "AvgPool_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::MaxPool, maxPoolCudnn,
|
||||
"MaxPool_cuDNN_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::AveragePool, avgPoolCudnn,
|
||||
"AvgPool_cuDNN_CUDA");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -40,8 +40,7 @@ class RecvNCCL : public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Recv, DataType::Float32, RecvNCCL,
|
||||
"Recv_NCCL_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Recv, RecvNCCL, "Recv_NCCL_CUDA");
|
||||
} // namespace infini
|
||||
|
||||
#endif
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "operators/reduce.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
|
||||
namespace infini {
|
||||
class ReduceCudnnBase : public CudaKernelWithoutConfig {
|
||||
|
@ -46,12 +47,12 @@ class ReduceCudnnBase : public CudaKernelWithoutConfig {
|
|||
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
|
||||
cudnnTensorDescriptor_t outDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
|
||||
auto cudnnDataType = cudnnDataTypeConvert(op->getDType());
|
||||
if (nInDims > 3) {
|
||||
checkCudnnError(cudnnSetTensorNdDescriptor(
|
||||
inDesc, CUDNN_DATA_FLOAT, nInDims, inDimArray, inStrideArray));
|
||||
checkCudnnError(
|
||||
cudnnSetTensorNdDescriptor(outDesc, CUDNN_DATA_FLOAT, nInDims,
|
||||
outDimArray, outStrideArray));
|
||||
inDesc, cudnnDataType, nInDims, inDimArray, inStrideArray));
|
||||
checkCudnnError(cudnnSetTensorNdDescriptor(
|
||||
outDesc, cudnnDataType, nInDims, outDimArray, outStrideArray));
|
||||
} else {
|
||||
int idims[4] = {1, 1, 1, 1}, odims[4] = {1, 1, 1, 1};
|
||||
for (int i = 0; i < nInDims; ++i) {
|
||||
|
@ -62,20 +63,19 @@ class ReduceCudnnBase : public CudaKernelWithoutConfig {
|
|||
}
|
||||
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, idims[0], idims[1],
|
||||
inDesc, CUDNN_TENSOR_NCHW, cudnnDataType, idims[0], idims[1],
|
||||
idims[2], idims[3]));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
outDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, odims[0],
|
||||
odims[1], odims[2], odims[3]));
|
||||
outDesc, CUDNN_TENSOR_NCHW, cudnnDataType, odims[0], odims[1],
|
||||
odims[2], odims[3]));
|
||||
}
|
||||
|
||||
// get reduce descriptor
|
||||
cudnnReduceTensorDescriptor_t reduceDesc;
|
||||
checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduceDesc));
|
||||
checkCudnnError(cudnnSetReduceTensorDescriptor(
|
||||
reduceDesc, getReduceOp(), CUDNN_DATA_FLOAT,
|
||||
CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES,
|
||||
CUDNN_32BIT_INDICES));
|
||||
reduceDesc, getReduceOp(), cudnnDataType, CUDNN_NOT_PROPAGATE_NAN,
|
||||
CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES));
|
||||
|
||||
// get workspace
|
||||
size_t workspaceSize = 0;
|
||||
|
@ -120,8 +120,9 @@ class ReduceSumCudnn : public ReduceCudnnBase {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::ReduceMean, DataType::Float32,
|
||||
ReduceMeanCudnn, "ReduceMean_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::ReduceSum, DataType::Float32,
|
||||
ReduceSumCudnn, "ReduceSum_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::ReduceMean, ReduceMeanCudnn,
|
||||
"ReduceMean_cuDNN_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::ReduceSum, ReduceSumCudnn,
|
||||
"ReduceSum_cuDNN_CUDA");
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -11,19 +11,12 @@ class CopyCuda : public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
// reshape/flatten/identity all act as copying from input to output.
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Float32, CopyCuda,
|
||||
"Reshape_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int64, CopyCuda,
|
||||
"Reshape_CUDA_Int64");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int32, CopyCuda,
|
||||
"Reshape_CUDA_Int32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda,
|
||||
"Flatten_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Squeeze, DataType::Float32, CopyCuda,
|
||||
"Squeeze_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Unsqueeze, DataType::Float32, CopyCuda,
|
||||
"Unsqueeze_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda,
|
||||
"Identity_CUDA_Float32");
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, CopyCuda, "Reshape_CUDA");
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Flatten, CopyCuda, "Flatten_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Identity, CopyCuda, "Identity_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Squeeze, CopyCuda, "Squeeze_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Unsqueeze, CopyCuda, "Unsqueeze_CUDA");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -6,6 +6,7 @@ class ResizeCuda : public CudaKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ResizeObj>(_op);
|
||||
IT_ASSERT(op->getDType() == DataType::Float32);
|
||||
auto in = op->getInputs(0);
|
||||
auto out = op->getOutputs()[0];
|
||||
|
||||
|
@ -48,7 +49,6 @@ class ResizeCuda : public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Resize, DataType::Float32, ResizeCuda,
|
||||
"Resize_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Resize, ResizeCuda, "Resize_CUDA");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -36,8 +36,7 @@ class SendNCCL : public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Send, DataType::Float32, SendNCCL,
|
||||
"Send_NCCL_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Send, SendNCCL, "Send_NCCL_CUDA");
|
||||
} // namespace infini
|
||||
|
||||
#endif
|
||||
|
|
|
@ -20,11 +20,17 @@ class SoftmaxCuda : public CudaKernelWithoutConfig {
|
|||
int stride = op->getInputs(0)->getStride().at(op->getAxis());
|
||||
|
||||
int num_blocks = size / dimsize;
|
||||
if (op->getDType() == DataType::Float32) {
|
||||
softmax_kernel(num_blocks, (float *)input, (float *)output, size,
|
||||
dimsize, stride);
|
||||
} else if (op->getDType() == DataType::Float16) {
|
||||
softmax_kernel(num_blocks, (half *)input, (half *)output, size,
|
||||
dimsize, stride);
|
||||
} else {
|
||||
IT_ASSERT(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, SoftmaxCuda,
|
||||
"Softmax_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Softmax, SoftmaxCuda, "Softmax_CUDA");
|
||||
} // namespace infini
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
struct __align__(8) DataMaxSum { // update the global max and sum, store the
|
||||
// output at max_tmp and sum_tmp
|
||||
float max_tmp; // store max
|
||||
|
@ -16,9 +15,9 @@ __device__ __forceinline__ DataMaxSum reduce_dms_op(DataMaxSum a,
|
|||
|
||||
return bigger;
|
||||
}
|
||||
template <int BLOCK_DIM>
|
||||
template <typename T, int BLOCK_DIM>
|
||||
__launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel(
|
||||
float *__restrict input, float *__restrict output, int size, int dimsize,
|
||||
T *__restrict input, T *__restrict output, int size, int dimsize,
|
||||
int stride) { // if set axis = 1, inputShape=[I,J,K,S]
|
||||
// tid = i(JKS) + j(KS) + k(S) + s
|
||||
|
||||
|
@ -33,15 +32,33 @@ __launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel(
|
|||
dms_partial.max_tmp = -__FLT_MAX__;
|
||||
dms_partial.sum_tmp = 0.0f;
|
||||
DataMaxSum dms_input;
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
int remain = dimsize % BLOCK_DIM;
|
||||
int step = (dimsize - remain) / BLOCK_DIM + 1; // step <= numPerThread
|
||||
|
||||
if (threadIdx.x < remain) {
|
||||
for (int ind = 0; ind < step; ind++) {
|
||||
dms_input.max_tmp =
|
||||
input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride];
|
||||
input[tid + (threadIdx.x * step + ind) * stride];
|
||||
|
||||
dms_input.sum_tmp = 1.0f;
|
||||
dms_partial = reduce_dms_op(dms_partial,
|
||||
dms_partial =
|
||||
reduce_dms_op(dms_partial,
|
||||
dms_input); // reduce the data to one block
|
||||
}
|
||||
} else {
|
||||
for (int ind = 0; ind < step - 1; ind++) {
|
||||
dms_input.max_tmp =
|
||||
input[tid + (remain * step +
|
||||
(threadIdx.x - remain) * (step - 1) + ind) *
|
||||
stride];
|
||||
|
||||
dms_input.sum_tmp = 1.0f;
|
||||
dms_partial =
|
||||
reduce_dms_op(dms_partial,
|
||||
dms_input); // reduce the data to one block
|
||||
}
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<DataMaxSum, BLOCK_DIM> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ DataMaxSum dms_total;
|
||||
|
@ -53,13 +70,103 @@ __launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel(
|
|||
}
|
||||
__syncthreads();
|
||||
//-----------------
|
||||
if (threadIdx.x < remain) {
|
||||
for (int ind = 0; ind < step; ind++) {
|
||||
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) {
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] =
|
||||
__expf(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] -
|
||||
output[tid + (threadIdx.x * step + ind) * stride] =
|
||||
__expf(static_cast<float>(
|
||||
input[tid + (threadIdx.x * step + ind) * stride]) -
|
||||
dms_total.max_tmp) *
|
||||
__fdividef(1.0F, dms_total.sum_tmp);
|
||||
}
|
||||
} else {
|
||||
for (int ind = 0; ind < step - 1; ind++) {
|
||||
|
||||
output[tid +
|
||||
(remain * step + (threadIdx.x - remain) * (step - 1) + ind) *
|
||||
stride] =
|
||||
__expf(static_cast<float>(
|
||||
input[tid +
|
||||
(remain * step +
|
||||
(threadIdx.x - remain) * (step - 1) + ind) *
|
||||
stride]) -
|
||||
dms_total.max_tmp) *
|
||||
__fdividef(1.0F, dms_total.sum_tmp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int BLOCK_DIM, int numPerThread>
|
||||
__global__ void
|
||||
_blockSoftmaxKernel(T *__restrict input, T *__restrict output, int size,
|
||||
int dimsize,
|
||||
int stride) { // if set axis = 1, inputShape=[I,J,K,S]
|
||||
// tid = i(JKS) + j(KS) + k(S) + s
|
||||
|
||||
// blockDim.x = size/dimsize = IKS
|
||||
// blockIdx.x = i(KS) + k(S) + s,blockIdx.x%stride = k(S) + s
|
||||
|
||||
int tid =
|
||||
blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) *
|
||||
dimsize; // now, tid = i(JKS) + k(S) + s;
|
||||
int remain = dimsize % BLOCK_DIM;
|
||||
int step = (dimsize - remain) / BLOCK_DIM + 1; // step <= numPerThread
|
||||
float dataPerThread[numPerThread];
|
||||
|
||||
DataMaxSum dms_partial;
|
||||
dms_partial.max_tmp = -__FLT_MAX__;
|
||||
dms_partial.sum_tmp = 0.0f;
|
||||
DataMaxSum dms_input;
|
||||
if (threadIdx.x < remain) {
|
||||
for (int ind = 0; ind < step; ind++) {
|
||||
dataPerThread[ind] =
|
||||
input[tid + (threadIdx.x * step + ind) * stride];
|
||||
dms_input.max_tmp = dataPerThread[ind];
|
||||
dms_input.sum_tmp = 1.0f;
|
||||
dms_partial =
|
||||
reduce_dms_op(dms_partial,
|
||||
dms_input); // reduce the data to one block
|
||||
}
|
||||
} else {
|
||||
for (int ind = 0; ind < step - 1; ind++) {
|
||||
dataPerThread[ind] =
|
||||
input[tid + (remain * step +
|
||||
(threadIdx.x - remain) * (step - 1) + ind) *
|
||||
stride];
|
||||
dms_input.max_tmp = dataPerThread[ind];
|
||||
dms_input.sum_tmp = 1.0f;
|
||||
dms_partial =
|
||||
reduce_dms_op(dms_partial,
|
||||
dms_input); // reduce the data to one block
|
||||
}
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<DataMaxSum, BLOCK_DIM> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ DataMaxSum dms_total;
|
||||
DataMaxSum dms_block =
|
||||
BlockReduce(temp_storage).Reduce(dms_partial, reduce_dms_op);
|
||||
if (threadIdx.x ==
|
||||
0) { // must set threadIdx.x = 0 write the output to memory
|
||||
dms_total = dms_block;
|
||||
}
|
||||
__syncthreads();
|
||||
//-----------------
|
||||
if (threadIdx.x < remain) {
|
||||
for (int ind = 0; ind < step; ind++) {
|
||||
output[tid + (threadIdx.x * step + ind) * stride] =
|
||||
__expf(dataPerThread[ind] - dms_total.max_tmp) *
|
||||
__fdividef(1.0F, dms_total.sum_tmp);
|
||||
}
|
||||
} else {
|
||||
for (int ind = 0; ind < step - 1; ind++) {
|
||||
output[tid +
|
||||
(remain * step + (threadIdx.x - remain) * (step - 1) + ind) *
|
||||
stride] =
|
||||
__expf(dataPerThread[ind] - dms_total.max_tmp) *
|
||||
__fdividef(1.0F, dms_total.sum_tmp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T> struct SumOp {
|
||||
|
@ -81,14 +188,14 @@ __inline__ __device__ T WarpAllReduce(T val) {
|
|||
}
|
||||
return val;
|
||||
}
|
||||
template <int BLOCK_DIM_x, int BLOCK_DIM_y>
|
||||
__global__ void _warpSoftmaxKernel(float *__restrict input,
|
||||
float *__restrict output, int size,
|
||||
int dimsize, int stride) {
|
||||
|
||||
template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y, int numPerThreadx>
|
||||
__global__ void _warpSoftmaxKernel(T *__restrict input, T *__restrict output,
|
||||
int size, int dimsize, int stride) {
|
||||
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
int otherSize = size / dimsize;
|
||||
int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
|
||||
|
||||
float dataPerThreadx[numPerThreadx];
|
||||
if (otherIdx < otherSize) {
|
||||
|
||||
__shared__ float max_total[BLOCK_DIM_y];
|
||||
|
@ -96,9 +203,9 @@ __global__ void _warpSoftmaxKernel(float *__restrict input,
|
|||
float max_data = -__FLT_MAX__;
|
||||
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||
max_data =
|
||||
max(max_data,
|
||||
input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride]);
|
||||
dataPerThreadx[ph] =
|
||||
input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
|
||||
max_data = max(max_data, dataPerThreadx[ph]);
|
||||
}
|
||||
|
||||
max_data = WarpAllReduce<MaxOp, float, BLOCK_DIM_x>(max_data);
|
||||
|
@ -110,9 +217,9 @@ __global__ void _warpSoftmaxKernel(float *__restrict input,
|
|||
float sum_data = 0.0f;
|
||||
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||
sum_data +=
|
||||
__expf(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
max_total[threadIdx.y]);
|
||||
dataPerThreadx[ph] =
|
||||
__expf(dataPerThreadx[ph] - max_total[threadIdx.y]);
|
||||
sum_data += dataPerThreadx[ph];
|
||||
}
|
||||
|
||||
sum_data = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sum_data);
|
||||
|
@ -124,9 +231,7 @@ __global__ void _warpSoftmaxKernel(float *__restrict input,
|
|||
|
||||
for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) {
|
||||
output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
|
||||
__expf(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] -
|
||||
max_total[threadIdx.y]) *
|
||||
__fdividef(1.0F, sum_total[threadIdx.y]);
|
||||
dataPerThreadx[ph] * __fdividef(1.0F, sum_total[threadIdx.y]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -137,10 +242,35 @@ namespace infini {
|
|||
void softmax_kernel(int num_blocks, float *input, float *output, int size,
|
||||
int dimsize, int stride) {
|
||||
|
||||
if (dimsize > 1024) {
|
||||
if (dimsize > 1024 * 128) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<1024>
|
||||
_blockSoftmaxKernel<float, 1024>
|
||||
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 64) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<float, 1024, 128>
|
||||
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 32) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<float, 1024, 64>
|
||||
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 16) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<float, 1024, 32>
|
||||
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 4) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<float, 1024, 16>
|
||||
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<float, 1024, 4>
|
||||
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 31) {
|
||||
int BLOCK_DIM_x = 32;
|
||||
|
@ -149,7 +279,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
|
|||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<32, 32>
|
||||
_warpSoftmaxKernel<float, 32, 32, 32>
|
||||
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 15) {
|
||||
int BLOCK_DIM_x = 16;
|
||||
|
@ -158,7 +288,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
|
|||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<16, 64>
|
||||
_warpSoftmaxKernel<float, 16, 64, 2>
|
||||
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 7) {
|
||||
int BLOCK_DIM_x = 8;
|
||||
|
@ -167,7 +297,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
|
|||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<8, 128>
|
||||
_warpSoftmaxKernel<float, 8, 128, 2>
|
||||
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
|
||||
} else {
|
||||
int BLOCK_DIM_x = 4;
|
||||
|
@ -176,7 +306,79 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
|
|||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<4, 256>
|
||||
_warpSoftmaxKernel<float, 4, 256, 2>
|
||||
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
|
||||
}
|
||||
}
|
||||
//------------------
|
||||
void softmax_kernel(int num_blocks, half *input, half *output, int size,
|
||||
int dimsize, int stride) {
|
||||
|
||||
if (dimsize > 1024 * 128) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<half, 1024>
|
||||
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 64) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<half, 1024, 128>
|
||||
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 32) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<half, 1024, 64>
|
||||
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 16) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<half, 1024, 32>
|
||||
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 4) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<half, 1024, 16>
|
||||
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<half, 1024, 4>
|
||||
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 31) {
|
||||
int BLOCK_DIM_x = 32;
|
||||
int BLOCK_DIM_y = 32;
|
||||
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<half, 32, 32, 32>
|
||||
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 15) {
|
||||
int BLOCK_DIM_x = 16;
|
||||
int BLOCK_DIM_y = 64;
|
||||
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<half, 16, 64, 2>
|
||||
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 7) {
|
||||
int BLOCK_DIM_x = 8;
|
||||
int BLOCK_DIM_y = 128;
|
||||
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<half, 8, 128, 2>
|
||||
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
|
||||
} else {
|
||||
int BLOCK_DIM_x = 4;
|
||||
int BLOCK_DIM_y = 256;
|
||||
int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<half, 4, 256, 2>
|
||||
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,8 @@
|
|||
namespace infini {
|
||||
|
||||
class CudaCompute {
|
||||
void initComposedTensorMetadata(ComposedTensorMetadata &metadata,
|
||||
template <typename T>
|
||||
void initComposedTensorMetadata(ComposedTensorMetadata<T> &metadata,
|
||||
Tensor tensor) const {
|
||||
int nDims = tensor->getRank();
|
||||
auto strides = tensor->getStride();
|
||||
|
@ -16,10 +17,10 @@ class CudaCompute {
|
|||
metadata.dimSize[i] = tensor->getDims().at(i);
|
||||
metadata.stride[i] = strides.at(i);
|
||||
}
|
||||
metadata.data = tensor->getRawDataPtr<float *>();
|
||||
metadata.data = tensor->getRawDataPtr<T *>();
|
||||
}
|
||||
|
||||
void initElementTensorMetadata(ElementTensorMetadata &metadata,
|
||||
template <typename T>
|
||||
void initElementTensorMetadata(ElementTensorMetadata<T> &metadata,
|
||||
TensorVec tensors, int idx, int dim,
|
||||
int &dimBgIdx, int &batchCounter) const {
|
||||
int nTensors = tensors.size();
|
||||
|
@ -27,7 +28,7 @@ class CudaCompute {
|
|||
++batchCounter) {
|
||||
auto tensor = tensors.at(idx + batchCounter);
|
||||
auto dimSize = tensor->getDims()[dim];
|
||||
metadata.data[batchCounter] = tensor->getRawDataPtr<float *>();
|
||||
metadata.data[batchCounter] = tensor->getRawDataPtr<T *>();
|
||||
metadata.dimBgNo[batchCounter] = dimBgIdx;
|
||||
metadata.dimSize[batchCounter] = dimSize;
|
||||
metadata.nElements[batchCounter] = tensor->size();
|
||||
|
@ -36,17 +37,17 @@ class CudaCompute {
|
|||
}
|
||||
|
||||
public:
|
||||
template <typename T>
|
||||
void do_compute(Tensor composedTensor, TensorVec elementsTensor, int dim,
|
||||
int nDims, bool isSplit) const {
|
||||
IT_ASSERT(nDims <= DIM_MAX_SIZE);
|
||||
|
||||
ComposedTensorMetadata composedMetadata;
|
||||
initComposedTensorMetadata(composedMetadata, composedTensor);
|
||||
ComposedTensorMetadata<T> composedMetadata;
|
||||
initComposedTensorMetadata<T>(composedMetadata, composedTensor);
|
||||
|
||||
int dimBgNo = 0;
|
||||
int nElemets = elementsTensor.size();
|
||||
for (int i = 0; i < nElemets; i += BATCH_SIZE) {
|
||||
ElementTensorMetadata elemMetadata;
|
||||
ElementTensorMetadata<T> elemMetadata;
|
||||
int batchCounter = 0;
|
||||
initElementTensorMetadata(elemMetadata, elementsTensor, i, dim,
|
||||
dimBgNo, batchCounter);
|
||||
|
@ -74,23 +75,38 @@ class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig {
|
|||
}
|
||||
}
|
||||
}
|
||||
do_compute(_op->getOutput(), _op->getInputs(),
|
||||
as<ConcatObj>(_op)->getDim(), _op->getOutput()->getRank(),
|
||||
false);
|
||||
if (_op->getDType() == DataType::Float32) {
|
||||
do_compute<float>(_op->getOutput(), _op->getInputs(),
|
||||
as<ConcatObj>(_op)->getDim(),
|
||||
_op->getOutput()->getRank(), false);
|
||||
} else if (_op->getDType() == DataType::Float16) {
|
||||
do_compute<half>(_op->getOutput(), _op->getInputs(),
|
||||
as<ConcatObj>(_op)->getDim(),
|
||||
_op->getOutput()->getRank(), false);
|
||||
} else {
|
||||
IT_ASSERT(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class SplitCuda : private CudaCompute, public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
do_compute(_op->getInputs(0), _op->getOutputs(),
|
||||
as<SplitObj>(_op)->getDim(), _op->getInputs(0)->getRank(),
|
||||
true);
|
||||
if (_op->getDType() == DataType::Float32) {
|
||||
do_compute<float>(_op->getInputs(0), _op->getOutputs(),
|
||||
as<SplitObj>(_op)->getDim(),
|
||||
_op->getInputs(0)->getRank(), true);
|
||||
} else if (_op->getDType() == DataType::Float16) {
|
||||
do_compute<half>(_op->getInputs(0), _op->getOutputs(),
|
||||
as<SplitObj>(_op)->getDim(),
|
||||
_op->getInputs(0)->getRank(), true);
|
||||
} else {
|
||||
IT_ASSERT(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Concat, DataType::Float32, ConcatCuda,
|
||||
"Concat_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Split, DataType::Float32, SplitCuda,
|
||||
"Split_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Concat, ConcatCuda, "Concat_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Split, SplitCuda, "Split_CUDA");
|
||||
|
||||
} // namespace infini
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue